diff --git a/src/toast/CMakeLists.txt b/src/toast/CMakeLists.txt index 7abe61c42..6a0ccf8fe 100644 --- a/src/toast/CMakeLists.txt +++ b/src/toast/CMakeLists.txt @@ -169,6 +169,7 @@ add_subdirectory(io) add_subdirectory(accelerator) add_subdirectory(tests) add_subdirectory(jax) +add_subdirectory(opencl) add_subdirectory(ops) add_subdirectory(templates) add_subdirectory(scripts) diff --git a/src/toast/__init__.py b/src/toast/__init__.py index 7294042f9..908f18b28 100644 --- a/src/toast/__init__.py +++ b/src/toast/__init__.py @@ -38,6 +38,14 @@ * Values "0", "false", or "no" will disable runtime support for hybrid GPU pipelines. * Requires TOAST_GPU_OPENMP or TOAST_GPU_JAX to be enabled. +TOAST_OPENCL= + * Values "1", "true", or "yes" will enable runtime support for pyopencl. + * Requires pyopencl to be available / importable. + +TOAST_OPENCL_DEFAULT= + * Default OpenCL device type, where supported values are "CPU", "GPU", + and "OCLGRIND". + OMP_NUM_THREADS= * Toast uses OpenMP threading in several places and the concurrency is set by the usual environment variable. 
diff --git a/src/toast/_libtoast/ops_pixels_healpix.cpp b/src/toast/_libtoast/ops_pixels_healpix.cpp index 71dc1e925..0ebd904e9 100644 --- a/src/toast/_libtoast/ops_pixels_healpix.cpp +++ b/src/toast/_libtoast/ops_pixels_healpix.cpp @@ -598,7 +598,8 @@ void pixels_healpix_nest_inner( int64_t n_samp, int64_t idet, uint8_t mask, - bool use_flags + bool use_flags, + bool compute_submaps ) { const double zaxis[3] = {0.0, 0.0, 1.0}; int32_t p_indx = pixel_index[idet]; @@ -618,8 +619,10 @@ void pixels_healpix_nest_inner( if (use_flags && ((flags[isamp] & mask) != 0)) { pixels[poff] = -1; } else { - sub_map = (int64_t)(pixels[poff] / n_pix_submap); - hsub[sub_map] = 1; + if (compute_submaps) { + sub_map = (int64_t)(pixels[poff] / n_pix_submap); + hsub[sub_map] = 1; + } } return; @@ -639,7 +642,9 @@ void pixels_healpix_ring_inner( int64_t n_samp, int64_t idet, uint8_t mask, - bool use_flags) { + bool use_flags, + bool compute_submaps +) { const double zaxis[3] = {0.0, 0.0, 1.0}; int32_t p_indx = pixel_index[idet]; int32_t q_indx = quat_index[idet]; @@ -658,8 +663,10 @@ void pixels_healpix_ring_inner( if (use_flags && ((flags[isamp] & mask) != 0)) { pixels[poff] = -1; } else { - sub_map = (int64_t)(pixels[poff] / n_pix_submap); - hsub[sub_map] = 1; + if (compute_submaps) { + sub_map = (int64_t)(pixels[poff] / n_pix_submap); + hsub[sub_map] = 1; + } } return; @@ -1163,6 +1170,7 @@ void init_ops_pixels_healpix(py::module & m) { int64_t n_pix_submap, int64_t nside, bool nest, + bool compute_submaps, bool use_accel ) { auto & omgr = OmpManager::get(); @@ -1195,10 +1203,14 @@ void init_ops_pixels_healpix(py::module & m) { ); int64_t n_view = temp_shape[0]; + // Optionally compute the hit submaps uint8_t * raw_hsub = extract_buffer ( hit_submaps, "hit_submaps", 1, temp_shape, {-1} ); int64_t n_submap = temp_shape[0]; + if (! 
compute_submaps) { + raw_hsub = omgr.null_ptr (); + } // Optionally use flags bool use_flags = true; @@ -1225,6 +1237,7 @@ void init_ops_pixels_healpix(py::module & m) { int64_t * dev_pixels = omgr.device_ptr(raw_pixels); Interval * dev_intervals = omgr.device_ptr(raw_intervals); uint8_t * dev_flags = omgr.device_ptr(raw_flags); + uint8_t * dev_hsub = omgr.device_ptr(raw_hsub); // Make sure the lookup table exists on device size_t utab_bytes = 0x100 * sizeof(int64_t); @@ -1258,9 +1271,9 @@ void init_ops_pixels_healpix(py::module & m) { n_det, \ n_samp, \ shared_flag_mask, \ + compute_submaps, \ use_flags \ - ) \ - map(tofrom : raw_hsub[0 : n_submap]) + ) { if (nest) { # pragma omp target teams distribute parallel for collapse(3) \ @@ -1269,6 +1282,7 @@ void init_ops_pixels_healpix(py::module & m) { dev_pixels, \ dev_quats, \ dev_flags, \ + dev_hsub, \ dev_intervals, \ dev_utab \ ) @@ -1293,14 +1307,15 @@ void init_ops_pixels_healpix(py::module & m) { raw_pixel_index, dev_quats, dev_flags, - raw_hsub, + dev_hsub, dev_pixels, n_pix_submap, adjusted_isamp, n_samp, idet, shared_flag_mask, - use_flags + use_flags, + compute_submaps ); } } @@ -1312,6 +1327,7 @@ void init_ops_pixels_healpix(py::module & m) { dev_pixels, \ dev_quats, \ dev_flags, \ + dev_hsub, \ dev_intervals, \ dev_utab \ ) @@ -1335,14 +1351,15 @@ void init_ops_pixels_healpix(py::module & m) { raw_pixel_index, dev_quats, dev_flags, - raw_hsub, + dev_hsub, dev_pixels, n_pix_submap, adjusted_isamp, n_samp, idet, shared_flag_mask, - use_flags + use_flags, + compute_submaps ); } } @@ -1376,7 +1393,8 @@ void init_ops_pixels_healpix(py::module & m) { n_samp, idet, shared_flag_mask, - use_flags + use_flags, + compute_submaps ); } } @@ -1404,7 +1422,8 @@ void init_ops_pixels_healpix(py::module & m) { n_samp, idet, shared_flag_mask, - use_flags + use_flags, + compute_submaps ); } } diff --git a/src/toast/accelerator/__init__.py b/src/toast/accelerator/__init__.py index 46f6497ec..134166937 100644 --- 
a/src/toast/accelerator/__init__.py +++ b/src/toast/accelerator/__init__.py @@ -15,9 +15,11 @@ accel_data_update_device, accel_data_update_host, accel_enabled, + accel_wait, accel_get_device, use_accel_jax, use_accel_omp, + use_accel_opencl, use_hybrid_pipelines, ) from .kernel_registry import ImplementationType, kernel diff --git a/src/toast/accelerator/accel.py b/src/toast/accelerator/accel.py index 14462a351..11ebf68b4 100644 --- a/src/toast/accelerator/accel.py +++ b/src/toast/accelerator/accel.py @@ -15,6 +15,7 @@ from .._libtoast import accel_reset as omp_accel_reset from .._libtoast import accel_update_device as omp_accel_update_device from .._libtoast import accel_update_host as omp_accel_update_host +from ..opencl import have_opencl, OpenCL from ..timing import function_timer enable_vals = ["1", "yes", "true"] @@ -50,9 +51,22 @@ log.error(msg) raise RuntimeError(msg) -if use_accel_omp and use_accel_jax: +use_accel_opencl = False +if ("TOAST_OPENCL" in os.environ) and (os.environ["TOAST_OPENCL"] in enable_vals): + if not have_opencl: + log = Logger.get() + msg = "TOAST_OPENCL enabled at runtime, but pyopencl is not " + msg += "importable." + log.error(msg) + raise RuntimeError(msg) + use_accel_opencl = True + ocl = OpenCL() + import pyopencl as cl + +if (use_accel_omp + use_accel_jax + use_accel_opencl) > 1: log = Logger.get() - msg = "OpenMP target offload and JAX cannot both be enabled at runtime." + msg = "Only one of OpenMP target offload, JAX, and OpenCL " + msg += "can be enabled at runtime." 
log.error(msg) raise RuntimeError(msg) @@ -68,7 +82,7 @@ def accel_enabled(): """Returns True if any accelerator support is enabled.""" - return use_accel_jax or use_accel_omp + return use_accel_jax or use_accel_omp or use_accel_opencl def accel_get_device(): @@ -77,6 +91,9 @@ def accel_get_device(): return omp_accel_get_device() elif use_accel_jax: return jax_accel_get_device() + elif use_accel_opencl: + ocl = OpenCL() + return ocl.default_gpu_index() else: log = Logger.get() log.warning("Accelerator support not enabled, returning device -1") @@ -108,6 +125,9 @@ def accel_assign_device(node_procs, node_rank, mem_gb, disabled): omp_accel_assign_device(node_procs, node_rank, mem_gb, disabled) if use_accel_jax: jax_accel_assign_device(node_procs, node_rank, disabled) + elif use_accel_opencl: + ocl = OpenCL() + ocl.assign_default_devices(node_procs, node_rank, disabled) def accel_data_present(data, name="None"): @@ -136,10 +156,20 @@ def accel_data_present(data, name="None"): or isinstance(data, jax.numpy.ndarray) or isinstance(data, INTERVALS_JAX) ) + elif use_accel_opencl: + ocl = OpenCL() + return ocl.mem_present(data, name=name) else: log.warning("Accelerator support not enabled, data not present") return False +def accel_wait(events): + """For some frameworks (OpenCL), wait for events. 
+ """ + if use_accel_opencl: + for ev in events: + if ev is not None: + ev.wait() @function_timer def accel_data_create(data, name="None", zero_out=False): @@ -168,6 +198,12 @@ def accel_data_create(data, name="None", zero_out=False): return MutableJaxArray(cpu_data=data, gpu_data=jax.numpy.zeros_like(data)) else: return MutableJaxArray(data) + elif use_accel_opencl: + ocl = OpenCL() + _ = ocl.mem_create(data, name=name) + if zero_out: + ocl.mem_reset(data, name=name) + return data else: log = Logger.get() log.warning("Accelerator support not enabled, cannot create") @@ -199,6 +235,9 @@ def accel_data_reset(data, name="None"): else: # the data is not on GPU anymore, possibly because it was moved to host data = MutableJaxArray(data, gpu_data=jax.numpy.zeros_like(data)) + elif use_accel_opencl: + ocl = OpenCL() + ocl.mem_reset(data, name=name) else: log = Logger.get() log.warning("Accelerator support not enabled, cannot reset") @@ -218,7 +257,8 @@ def accel_data_update_device(data, name="None"): name (str): The optional name for tracking the array. Returns: - (object): Either the original input (for OpenMP) or a new jax array. + (object): Either the original input (OpenMP), a new jax array, + or a pyopencl Array. """ if use_accel_omp: @@ -226,6 +266,10 @@ def accel_data_update_device(data, name="None"): return data elif use_accel_jax: return MutableJaxArray(data) + elif use_accel_opencl: + ocl = OpenCL() + dev_data = ocl.mem_update_device(data, name=name) + return dev_data else: log = Logger.get() log.warning("Accelerator support not enabled, not updating device") @@ -246,7 +290,8 @@ def accel_data_update_host(data, name="None"): name (str): The optional name for tracking the array. Returns: - (object): Either the updated input (for OpenMP) or a numpy array. + (object): Either the updated input (OpenMP), a numpy array (JAX), + or a pyopencl event. 
""" if use_accel_omp: @@ -254,6 +299,11 @@ def accel_data_update_host(data, name="None"): return data elif use_accel_jax: return data.to_host() + elif use_accel_opencl: + ocl = OpenCL() + evs = list() + evs.append(ocl.mem_update_host(data, name=name, async_=True)) + return evs else: log = Logger.get() log.warning("Accelerator support not enabled, not updating host") @@ -283,6 +333,9 @@ def accel_data_delete(data, name="None"): # if needed, make sure that data is on host if accel_data_present(data): data = data.host_data + elif use_accel_opencl: + ocl = OpenCL() + ocl.mem_remove(data, name=name) else: log = Logger.get() log.warning("Accelerator support not enabled, cannot delete device data") @@ -301,6 +354,9 @@ def accel_data_table(): omp_accel_dump() elif use_accel_jax: log.debug("Using Jax, skipping dump of OpenMP target device table") + elif use_accel_opencl: + ocl = OpenCL() + ocl.mem_dump() else: log.warning("Accelerator support not enabled, cannot print device table") @@ -314,15 +370,33 @@ class AcceleratorObject(object): add boilerplate and checks in a single place in the code. The internal methods should be overloaded by descendent classes. + This base class internally tracks whether the accelerator or host copy + of the object is "in use". This affects calls to update the device or host + copy. By default, objects are assumed to be in use either on the host or + on the accelerator (in case they are being modified on one or the other). + + Some objects might be used in a read-only way on both the host and device + after they are instantiated. Derived objects like this should set the + "constant" constructor argument to True. For constant objects, update + host / device methods are a no-op. The same is true for accel_delete. + When a constant object is deleted / garbage collected, the derived class + _accel_delete() method is call to actually do the deletion. 
+ Args: - None + constant (bool): If True, the host and device copies of the object + can be safely used simultaneously. """ - def __init__(self): + def __init__(self, constant=False): # Data always starts off on host self._accel_used = False self._accel_name = "(blank)" + self._constant = constant + + def __del__(self): + if self.accel_exists(): + self._accel_delete() def _accel_exists(self): return False @@ -395,10 +469,12 @@ def accel_update_device(self): """Copy the data to the accelerator. Returns: - None + (events): backend-specific events or None + """ + ret = None if not accel_enabled(): - return + return ret if (not self.accel_exists()) and (not use_accel_jax): # There is no data on device # NOTE: this does no apply to JAX as JAX will allocate on the fly @@ -412,8 +488,20 @@ def accel_update_device(self): msg = f"Active data is already on device, cannot update" log.error(msg) raise RuntimeError(msg) - self._accel_update_device() + if use_accel_opencl: + if not self._constant: + evs = self._accel_update_device() + if isinstance(evs, list): + ret = evs + else: + ret = [evs] + else: + ret = list() + else: + if not self._constant: + _ = self._accel_update_device() self.accel_used(True) + return ret def _accel_update_host(self): msg = f"The _accel_update_host function was not defined for this class." @@ -423,10 +511,12 @@ def accel_update_host(self): """Copy the data to the host from the accelerator. 
Returns: - None + (events): backend-specific events or None + """ + ret = None if not accel_enabled(): - return + return ret if not self.accel_exists(): log = Logger.get() msg = f"Data does not exist on device, cannot update host" @@ -438,8 +528,20 @@ def accel_update_host(self): msg = f"Active data is already on host, cannot update" log.error(msg) raise RuntimeError(msg) - self._accel_update_host() + if use_accel_opencl: + if not self._constant: + evs = self._accel_update_host() + if isinstance(evs, list): + ret = evs + else: + ret = [evs] + else: + ret = list() + else: + if not self._constant: + _ = self._accel_update_host() self.accel_used(False) + return ret def _accel_delete(self): msg = f"The _accel_delete function was not defined for this class." @@ -460,7 +562,8 @@ def accel_delete(self): msg = f"Data does not exist on device, cannot delete" log.error(msg) raise RuntimeError(msg) - self._accel_delete() + if not self._constant: + self._accel_delete() self._accel_used = False def _accel_reset(self): diff --git a/src/toast/accelerator/data_localization.py b/src/toast/accelerator/data_localization.py index 28ee8b2a4..003221029 100644 --- a/src/toast/accelerator/data_localization.py +++ b/src/toast/accelerator/data_localization.py @@ -9,9 +9,10 @@ # --------------------------------------------------------------------------- # RECORDING CLASS -use_debug_assert = ("TOAST_LOGLEVEL" in os.environ) and ( - os.environ["TOAST_LOGLEVEL"] in ["DEBUG", "VERBOSE"] -) +use_debug_assert = False +# use_debug_assert = ("TOAST_LOGLEVEL" in os.environ) and ( +# os.environ["TOAST_LOGLEVEL"] in ["DEBUG", "VERBOSE"] +# ) """ Assert is used only if `TOAST_LOGLEVEL` is set to `DEBUG`. 
""" diff --git a/src/toast/accelerator/kernel_registry.py b/src/toast/accelerator/kernel_registry.py index c4a4e7ea0..6ae77d6cb 100644 --- a/src/toast/accelerator/kernel_registry.py +++ b/src/toast/accelerator/kernel_registry.py @@ -18,6 +18,7 @@ class ImplementationType(IntEnum): COMPILED = 1 NUMPY = 2 JAX = 3 + OPENCL = 4 registry = dict() diff --git a/src/toast/data.py b/src/toast/data.py index a33771d39..d8253bc1e 100644 --- a/src/toast/data.py +++ b/src/toast/data.py @@ -464,24 +464,25 @@ def accel_create(self, names): return for ob in self.obs: - for objname, objmgr in [ - ("detdata", ob.detdata), - ("shared", ob.shared), - ("intervals", ob.intervals), - ]: - for key in names[objname]: - if key not in objmgr: - msg = f"ob {ob.name} {objname} accel_create '{key}' " - msg += f"not present, ignoring" - log.verbose(msg) - continue - if objmgr.accel_exists(key): - msg = f"ob {ob.name} {objname}: accel_create '{key}'" - msg += f" already exists" - log.verbose(msg) - else: - log.verbose(f"ob {ob.name} {objname}: accel_create '{key}'") - objmgr.accel_create(key) + ob.accel_create(names) + # for objname, objmgr in [ + # ("detdata", ob.detdata), + # ("shared", ob.shared), + # ("intervals", ob.intervals), + # ]: + # for key in names[objname]: + # if key not in objmgr: + # msg = f"ob {ob.name} {objname} accel_create '{key}' " + # msg += f"not present, ignoring" + # log.verbose(msg) + # continue + # if objmgr.accel_exists(key): + # msg = f"ob {ob.name} {objname}: accel_create '{key}'" + # msg += f" already exists" + # log.verbose(msg) + # else: + # log.verbose(f"ob {ob.name} {objname}: accel_create '{key}'") + # objmgr.accel_create(key) for key in names["global"]: val = self._internal.get(key, None) @@ -506,37 +507,45 @@ def accel_update_device(self, names): names (dict): Dictionary of lists. Returns: - None + (dict): The optional list of events generated for each observation. 
""" if not accel_enabled(): return log = Logger.get() + events = dict() + first_ob = None for ob in self.obs: - for objname, objmgr in [ - ("detdata", ob.detdata), - ("shared", ob.shared), - ("intervals", ob.intervals), - ]: - for key in names[objname]: - if key not in objmgr: - msg = f"ob {ob.name} {objname} update_device key '{key}'" - msg += f" not present, ignoring" - log.verbose(msg) - continue - if not objmgr.accel_exists(key): - msg = f"ob {ob.name} {objname} update_device key '{key}'" - msg += f" does not exist on accelerator" - log.error(msg) - raise RuntimeError(msg) - if objmgr.accel_in_use(key): - msg = f"ob {ob.name} {objname}: skip update_device '{key}'" - msg += f" already in use" - log.verbose(msg) - else: - log.verbose(f"ob {ob.name} {objname}: update_device '{key}'") - objmgr.accel_update_device(key) + if first_ob is None: + first_ob = ob.name + events[ob.name] = ob.accel_update_device(names) + # events[ob.name] = list() + # for objname, objmgr in [ + # ("detdata", ob.detdata), + # ("shared", ob.shared), + # ("intervals", ob.intervals), + # ]: + # for key in names[objname]: + # if key not in objmgr: + # msg = f"ob {ob.name} {objname} update_device key '{key}'" + # msg += f" not present, ignoring" + # log.verbose(msg) + # continue + # if not objmgr.accel_exists(key): + # msg = f"ob {ob.name} {objname} update_device key '{key}'" + # msg += f" does not exist on accelerator" + # log.error(msg) + # raise RuntimeError(msg) + # if objmgr.accel_in_use(key): + # msg = f"ob {ob.name} {objname}: skip update_device '{key}'" + # msg += f" already in use" + # log.verbose(msg) + # else: + # log.verbose(f"ob {ob.name} {objname}: update_device '{key}'") + # ev = objmgr.accel_update_device(key) + # print(f"DATA extend with {objname}:{key} = {ev}") + # events[ob.name].extend(ev) for key in names["global"]: val = self._internal.get(key, None) @@ -547,11 +556,14 @@ def accel_update_device(self, names): log.verbose(msg) else: log.verbose(f"Calling Data update_device 
for '{key}'") - val.accel_update_device() + ev = val.accel_update_device() + # print(f"DATA extend with global:{key} = {ev}") + events[first_ob].extend(ev) else: msg = f"Data accel_update_device: '{key}' ({type(val)}) " msg += "is not an AcceleratorObject" log.verbose(msg) + return events def accel_update_host(self, names): """Copy a set of data objects to the host. @@ -563,37 +575,44 @@ def accel_update_host(self, names): names (dict): Dictionary of lists. Returns: - None + (dict): The optional list of events generated for each observation. """ if not accel_enabled(): return log = Logger.get() + events = dict() + first_ob = None for ob in self.obs: - for objname, objmgr in [ - ("detdata", ob.detdata), - ("shared", ob.shared), - ("intervals", ob.intervals), - ]: - for key in names[objname]: - if key not in objmgr: - msg = f"ob {ob.name} {objname} update_host key '{key}'" - msg += f" not present, ignoring" - log.verbose(msg) - continue - if not objmgr.accel_exists(key): - msg = f"ob {ob.name} {objname} update_host key '{key}'" - msg += f" does not exist on accelerator, ignoring" - log.verbose(msg) - continue - if not objmgr.accel_in_use(key): - msg = f"ob {ob.name} {objname}: skip update_host, '{key}'" - msg += f" already on host" - log.verbose(msg) - else: - log.verbose(f"ob {ob.name} {objname}: update_host '{key}'") - objmgr.accel_update_host(key) + if first_ob is None: + first_ob = ob.name + events[ob.name] = ob.accel_update_host(names) + # events[ob.name] = list() + # for objname, objmgr in [ + # ("detdata", ob.detdata), + # ("shared", ob.shared), + # ("intervals", ob.intervals), + # ]: + # for key in names[objname]: + # if key not in objmgr: + # msg = f"ob {ob.name} {objname} update_host key '{key}'" + # msg += f" not present, ignoring" + # log.verbose(msg) + # continue + # if not objmgr.accel_exists(key): + # msg = f"ob {ob.name} {objname} update_host key '{key}'" + # msg += f" does not exist on accelerator, ignoring" + # log.verbose(msg) + # continue + # if 
not objmgr.accel_in_use(key): + # msg = f"ob {ob.name} {objname}: skip update_host, '{key}'" + # msg += f" already on host" + # log.verbose(msg) + # else: + # log.verbose(f"ob {ob.name} {objname}: update_host '{key}'") + # ev = objmgr.accel_update_host(key) + # events[ob.name].extend(ev) for key in names["global"]: val = self._internal.get(key, None) @@ -604,11 +623,13 @@ def accel_update_host(self, names): log.verbose(msg) else: log.verbose(f"Calling Data update_host for '{key}'") - val.accel_update_host() + ev = val.accel_update_host() + events[first_ob].extend(ev) else: msg = f"Data accel_update_host: '{key}' ({type(val)}) " msg += "is not an AcceleratorObject" log.verbose(msg) + return events def accel_delete(self, names): """Delete a specific set of device objects @@ -628,24 +649,25 @@ def accel_delete(self, names): log = Logger.get() for ob in self.obs: - for objname, objmgr in [ - ("detdata", ob.detdata), - ("shared", ob.shared), - ("intervals", ob.intervals), - ]: - for key in names[objname]: - if key not in objmgr: - msg = f"ob {ob.name} {objname} accel_delete key '{key}'" - msg += f" not present, ignoring" - log.verbose(msg) - continue - if objmgr.accel_exists(key): - log.verbose(f"ob {ob.name} {objname}: accel_delete '{key}'") - objmgr.accel_delete(key) - else: - msg = f"ob {ob.name} {objname}: accel_delete '{key}'" - msg += f" not present on device" - log.verbose(msg) + ob.accel_delete(names) + # for objname, objmgr in [ + # ("detdata", ob.detdata), + # ("shared", ob.shared), + # ("intervals", ob.intervals), + # ]: + # for key in names[objname]: + # if key not in objmgr: + # msg = f"ob {ob.name} {objname} accel_delete key '{key}'" + # msg += f" not present, ignoring" + # log.verbose(msg) + # continue + # if objmgr.accel_exists(key): + # log.verbose(f"ob {ob.name} {objname}: accel_delete '{key}'") + # objmgr.accel_delete(key) + # else: + # msg = f"ob {ob.name} {objname}: accel_delete '{key}'" + # msg += f" not present on device" + # log.verbose(msg) for 
key in names["global"]: val = self._internal.get(key, None) diff --git a/src/toast/intervals.py b/src/toast/intervals.py index 4b54f2727..aef676909 100644 --- a/src/toast/intervals.py +++ b/src/toast/intervals.py @@ -15,6 +15,7 @@ accel_data_update_host, use_accel_jax, use_accel_omp, + use_accel_opencl, ) from .timing import function_timer from .utils import Logger @@ -41,8 +42,11 @@ def build_interval_dtype(): } ) - -interval_dtype = build_interval_dtype() +if use_accel_opencl: + import pyopencl as cl + interval_dtype = cl.tools.get_or_register_dtype("Interval", build_interval_dtype()) +else: + interval_dtype = build_interval_dtype() class IntervalList(Sequence, AcceleratorObject): @@ -354,7 +358,7 @@ def __or__(self, other): ) def _accel_exists(self): - if use_accel_omp: + if use_accel_omp or use_accel_opencl: return accel_data_present(self.data, self._accel_name) elif use_accel_jax: # specialised for the INTERVALS_JAX dtype @@ -363,7 +367,7 @@ def _accel_exists(self): return False def _accel_create(self): - if use_accel_omp: + if use_accel_omp or use_accel_opencl: self.data = accel_data_create(self.data, self._accel_name) elif use_accel_jax: # specialised for the INTERVALS_JAX dtype @@ -371,24 +375,33 @@ def _accel_create(self): self.data = INTERVALS_JAX(self.data) def _accel_update_device(self): + ret = None if use_accel_omp: - self.data = accel_data_update_device(self.data, self._accel_name) + _ = accel_data_update_device(self.data, self._accel_name) elif use_accel_jax: # specialised for the INTERVALS_JAX dtype # NOTE: this call is timed at the INTERVALS_JAX level self.data = INTERVALS_JAX(self.data) + elif use_accel_opencl: + dev_data = accel_data_update_device(self.data, self._accel_name) + ret = dev_data.events + return ret def _accel_update_host(self): + ret = None if use_accel_omp: - self.data = accel_data_update_host(self.data, self._accel_name) + _ = accel_data_update_host(self.data, self._accel_name) elif use_accel_jax: # specialised for the 
INTERVALS_JAX dtype # this moves the data back into a numpy array # NOTE: this call is timed at the INTERVALS_JAX level self.data = self.data.to_host() + elif use_accel_opencl: + ret = accel_data_update_host(self.data, self._accel_name) + return ret def _accel_delete(self): - if use_accel_omp: + if use_accel_omp or use_accel_opencl: self.data = accel_data_delete(self.data, self._accel_name) elif use_accel_jax and self._accel_exists(): # Ensures data has been properly reset diff --git a/src/toast/mpi.py b/src/toast/mpi.py index 426d1830a..38ae7ad39 100644 --- a/src/toast/mpi.py +++ b/src/toast/mpi.py @@ -53,9 +53,11 @@ class MPI_Comm: use_mpi = False # Assign each process to an accelerator device -from .accelerator import accel_assign_device, use_accel_jax, use_accel_omp +from .accelerator import ( + accel_assign_device, use_accel_jax, use_accel_omp, use_accel_opencl +) -if use_accel_omp or use_accel_jax: +if use_accel_omp or use_accel_jax or use_accel_opencl: node_procs = 1 node_rank = 0 accel_gb = 2.0 diff --git a/src/toast/observation.py b/src/toast/observation.py index a29ed4036..3626e78a8 100644 --- a/src/toast/observation.py +++ b/src/toast/observation.py @@ -864,17 +864,55 @@ def accel_create(self, names): None """ + log = Logger.get() for key in names["detdata"]: self.detdata.accel_create(key) for key in names["shared"]: self.shared.accel_create(key) for key in names["intervals"]: self.intervals.accel_create(key) - for key, val in self._internal.items(): + for key in names["meta"]: + if key not in self._internal: + msg = f"{self.name}['{key}'] does not exist, not creating " + msg += "on accelerator" + log.debug(msg) + continue + val = self._internal[key] if isinstance(val, AcceleratorObject): if not val.accel_exists(): val.accel_create() + def accel_delete(self, names): + """Delete a set of data objects on the device. + + This takes a dictionary with the same format as those used by the Operator + provides() and requires() methods. 
+ + Args: + names (dict): Dictionary of lists. + + Returns: + None + + """ + log = Logger.get() + for key in names["detdata"]: + self.detdata.accel_delete(key) + for key in names["shared"]: + self.shared.accel_delete(key) + for key in names["intervals"]: + self.intervals.accel_delete(key) + for key in names["meta"]: + if key not in self._internal: + msg = f"{self.name}['{key}'] does not exist, not deleting " + msg += "on accelerator" + log.debug(msg) + continue + val = self._internal[key] + if isinstance(val, AcceleratorObject): + if val.accel_exists(): + val.accel_delete() + def accel_update_device(self, names): """Copy data objects to the device. @@ -885,19 +923,36 @@ def accel_update_device(self, names): names (dict): Dictionary of lists. Returns: - None + (list): The events """ + log = Logger.get() + events = list() for key in names["detdata"]: - self.detdata.accel_update_device(key) + evs = self.detdata.accel_update_device(key) + if evs is not None: + events.extend(evs) for key in names["shared"]: - self.shared.accel_update_device(key) + evs = self.shared.accel_update_device(key) + if evs is not None: + events.extend(evs) for key in names["intervals"]: - self.intervals.accel_update_device(key) - for key, val in self._internal.items(): + evs = self.intervals.accel_update_device(key) + if evs is not None: + events.extend(evs) + for key in names["meta"]: + if key not in self._internal: + msg = f"{self.name}['{key}'] does not exist, not updating " + msg += "accelerator copy" + log.debug(msg) + continue + val = self._internal[key] if isinstance(val, AcceleratorObject): if not val.accel_in_use(): - val.accel_update_device() + evs = val.accel_update_device() + if evs is not None: + events.extend(evs) + return events def accel_update_host(self, names): """Copy data objects from the device. @@ -909,19 +964,36 @@ def accel_update_host(self, names): names (dict): Dictionary of lists. 
Returns: - None + (list): The events """ + log = Logger.get() + events = list() for key in names["detdata"]: - self.detdata.accel_update_host(key) + evs = self.detdata.accel_update_host(key) + if evs is not None: + events.extend(evs) for key in names["shared"]: - self.shared.accel_update_host(key) + evs = self.shared.accel_update_host(key) + if evs is not None: + events.extend(evs) for key in names["intervals"]: - self.intervals.accel_update_host(key) - for key, val in self._internal.items(): + evs = self.intervals.accel_update_host(key) + if evs is not None: + events.extend(evs) + for key in names["meta"]: + if key not in self._internal: + msg = f"{self.name}['{key}'] does not exist, not updating " + msg += "host copy" + log.debug(msg) + continue + val = self._internal[key] if isinstance(val, AcceleratorObject): if val.accel_in_use(): - val.accel_update_host() + evs = val.accel_update_host() + if evs is not None: + events.extend(evs) + return events def accel_clear(self): self.detdata.accel_clear() diff --git a/src/toast/observation_data.py b/src/toast/observation_data.py index e1398b9b0..a4f772888 100644 --- a/src/toast/observation_data.py +++ b/src/toast/observation_data.py @@ -20,6 +20,7 @@ accel_enabled, use_accel_jax, use_accel_omp, + use_accel_opencl, ) from .intervals import IntervalList from .mpi import MPI, comm_equivalent @@ -324,14 +325,6 @@ def clear(self): are no longer being used and you are about to delete the object. """ - # first delete potential GPU data - if self.accel_exists(): - log = Logger.get() - msg = "clear() of DetectorData which is staged to accelerator- " - msg += "Deleting device copy." - log.verbose(msg) - self.accel_delete() - # then apply clear if hasattr(self, "_data"): del self._data self._data = None @@ -367,6 +360,7 @@ def reset(self, dets=None): self[d, :] = 0 def __del__(self): + super().__del__() self.clear() def _det_axis_view(self, key): @@ -556,7 +550,7 @@ def _accel_exists(self): # object and use that. 
return False else: - if use_accel_omp: + if use_accel_omp or use_accel_opencl: return accel_data_present(self._raw, self._accel_name) elif use_accel_jax: return accel_data_present(self._data) @@ -564,7 +558,8 @@ def _accel_exists(self): return False def _accel_create(self, zero_out=False): - if use_accel_omp: + if use_accel_omp or use_accel_opencl: + print(f"DD create {self._accel_name}") self._raw = accel_data_create( self._raw, self._accel_name, zero_out=zero_out ) @@ -573,24 +568,32 @@ def _accel_create(self, zero_out=False): def _accel_update_device(self): if use_accel_omp: - self._raw = accel_data_update_device(self._raw, self._accel_name) + _ = accel_data_update_device(self._raw, self._accel_name) elif use_accel_jax: self._data = accel_data_update_device(self._data) + elif use_accel_opencl: + dev_data = accel_data_update_device(self._raw, self._accel_name) + print(f"DD update device {self._accel_name}, evs={dev_data.events}") + return dev_data.events def _accel_update_host(self): if use_accel_omp: - self._raw = accel_data_update_host(self._raw, self._accel_name) + _ = accel_data_update_host(self._raw, self._accel_name) elif use_accel_jax: self._data = accel_data_update_host(self._data) + elif use_accel_opencl: + evs = accel_data_update_host(self._raw, self._accel_name) + print(f"DD update host {self._accel_name}, evs={evs}") + return evs def _accel_delete(self): - if use_accel_omp: + if use_accel_omp or use_accel_opencl: self._raw = accel_data_delete(self._raw, self._accel_name) elif use_accel_jax: self._data = accel_data_delete(self._data) def _accel_reset(self): - if use_accel_omp: + if use_accel_omp or use_accel_opencl: accel_data_reset(self._raw, self._accel_name) elif use_accel_jax: self._data = accel_data_reset(self._data) @@ -930,7 +933,7 @@ def accel_update_device(self, key): log.verbose( f"DetDataMgr {key} type = {type(self._internal[key])} accel_update_device" ) - self._internal[key].accel_update_device() + return 
self._internal[key].accel_update_device() def accel_update_host(self, key): """Copy the named detector data from the accelerator. @@ -948,7 +951,7 @@ def accel_update_host(self, key): log.verbose( f"DetDataMgr {key} type = {type(self._internal[key])} accel_update_host" ) - self._internal[key].accel_update_host() + return self._internal[key].accel_update_host() def accel_delete(self, key): """Delete the named detector data from the accelerator. @@ -1528,7 +1531,7 @@ def accel_exists(self, key): log.error(msg) raise RuntimeError(msg) - if use_accel_omp: + if use_accel_omp or use_accel_opencl: result = accel_data_present(self._internal[key].shdata._flat, key) elif use_accel_jax: result = accel_data_present(self._internal[key].shdata.data) @@ -1595,7 +1598,7 @@ def accel_create(self, key): raise RuntimeError(msg) log.verbose(f"SharedDataMgr {key} accel_create") - if use_accel_omp: + if use_accel_omp or use_accel_opencl: _ = accel_data_create(self._internal[key].shdata._flat, key) elif use_accel_jax: self._internal[key].shdata.data = MutableJaxArray( @@ -1631,14 +1634,18 @@ def accel_update_device(self, key): raise RuntimeError(msg) log.verbose(f"SharedDataMgr {key} accel_update_device") + ret = None if use_accel_omp: _ = accel_data_update_device(self._internal[key].shdata._flat, key) elif use_accel_jax: self._internal[key].shdata.data = MutableJaxArray( self._internal[key].shdata.data ) - + elif use_accel_opencl: + dev_data = accel_data_update_device(self._internal[key].shdata._flat, key) + ret = dev_data.events self._accel_used[key] = True + return ret def accel_update_host(self, key): """Copy the named shared data from the accelerator to the host. 
@@ -1669,14 +1676,17 @@ def accel_update_host(self, key): raise RuntimeError(msg) log.verbose(f"SharedDataMgr {key} accel_update_host") - if use_accel_omp: + ret = None + if use_accel_omp or use_accel_opencl: _ = accel_data_update_host(self._internal[key].shdata._flat, key) elif use_accel_jax: self._internal[key].shdata.data = accel_data_update_host( self._internal[key].shdata.data ) - + elif use_accel_opencl: + ret = accel_data_update_host(self._internal[key].shdata._flat, key) self._accel_used[key] = False + return ret def accel_delete(self, key): """Delete the named data object on the device @@ -1701,7 +1711,7 @@ def accel_delete(self, key): raise RuntimeError(msg) log.verbose(f"SharedDataMgr {key} accel_delete") - if use_accel_omp: + if use_accel_omp or use_accel_opencl: _ = accel_data_delete(self._internal[key].shdata._flat, key) elif use_accel_jax: self._internal[key].shdata.data = accel_data_delete( @@ -1723,7 +1733,7 @@ def accel_clear(self): for key in self._internal: if self.accel_exists(key): log.verbose(f"SharedDataMgr {key} accel_delete") - if use_accel_omp: + if use_accel_omp or use_accel_opencl: _ = accel_data_delete(self._internal[key].shdata._flat, key) elif use_accel_jax: self._internal[key].shdata.data = accel_data_delete( @@ -2243,7 +2253,7 @@ def accel_update_device(self, key): return log = Logger.get() log.verbose(f"IntervalsManager {key} accel_update_device") - self[key].accel_update_device() + return self[key].accel_update_device() def accel_update_host(self, key): """Copy the named interval list from the accelerator. @@ -2259,7 +2269,7 @@ def accel_update_host(self, key): return log = Logger.get() log.verbose(f"IntervalsManager {key} accel_update_host") - self[key].accel_update_host() + return self[key].accel_update_host() def accel_delete(self, key): """Delete the named interval list from the accelerator. 
diff --git a/src/toast/opencl/CMakeLists.txt b/src/toast/opencl/CMakeLists.txt new file mode 100644 index 000000000..ce5396777 --- /dev/null +++ b/src/toast/opencl/CMakeLists.txt @@ -0,0 +1,9 @@ + +# Install the python files + +install(FILES + __init__.py + utils.py + platform.py + DESTINATION ${PYTHON_SITE}/toast/opencl +) diff --git a/src/toast/opencl/__init__.py b/src/toast/opencl/__init__.py new file mode 100644 index 000000000..45e37c8a4 --- /dev/null +++ b/src/toast/opencl/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. +"""OpenCL tools. +""" + +from .utils import ( + have_opencl, + find_source, + get_kernel_deps, + add_kernel_deps, + replace_kernel_deps, + clear_kernel_deps, +) +from .platform import OpenCL diff --git a/src/toast/opencl/platform.py b/src/toast/opencl/platform.py new file mode 100644 index 000000000..6ddc52cbb --- /dev/null +++ b/src/toast/opencl/platform.py @@ -0,0 +1,788 @@ +# Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. +"""OpenCL platform tools. +""" +import os +import ctypes +import numpy as np + +from ..utils import Environment, Logger +from .utils import have_opencl, aligned_to_dtype, find_source + +if have_opencl: + import pyopencl as cl + from pyopencl.array import Array + + +class OpenCL: + """Singleton class to manage OpenCL usage. + + This class provides a global interface to the underlying OpenCL platforms, + compiled programs, and memory management. + + The default device type can be controlled with environment variables: + + TOAST_OPENCL_DEFAULT= + + Where supported values are "CPU", "GPU", and "OCLGRIND". 
+ + """ + + instance = None + log_prefix = "OpenCL: " + + def __new__(cls): + if cls.instance is None: + cls.instance = super().__new__(cls) + cls.instance._initialized = False + return cls.instance + + def __init__(self): + if self._initialized: + return + log = Logger.get() + if not have_opencl: + log.error("pyopencl is not available!") + self._platforms = cl.get_platforms() + + self._cpus = list() + self._cpu_names = dict() + self._cpu_index = dict() + self._gpus = list() + self._gpu_names = dict() + self._gpu_index = dict() + self._grind = list() + self._grind_names = dict() + self._grind_index = dict() + + # The paradigm we use for multiple platforms is: + # - For the first platform with CPUs we index all CPU devices from there, + # in case they are duplicated in other platforms. + # - We index all GPUs across all platforms. + + found_cpus = False + cpu_offset = 0 + gpu_offset = 0 + grind_offset = 0 + for iplat, plat in enumerate(self._platforms): + devices = plat.get_devices() + if not found_cpus: + use_cpus = True + else: + use_cpus = False + for dev in devices: + if dev.type == 2: + # This is a CPU + found_cpus = True + if use_cpus: + self._cpus.append( + { + "name": dev.name, + "device": dev, + "platform": plat, + "type": "cpu", + "index": cpu_offset, + } + ) + self._cpu_index[dev.name] = cpu_offset + self._cpu_names[cpu_offset] = dev.name + cpu_offset += 1 + elif dev.type == 4: + # This is a GPU + self._gpus.append( + { + "name": dev.name, + "device": dev, + "platform": plat, + "type": "gpu", + "index": gpu_offset, + } + ) + self._gpu_index[dev.name] = gpu_offset + self._gpu_names[gpu_offset] = dev.name + gpu_offset += 1 + else: + # This is something else (OCLGRIND) + self._grind.append( + { + "name": dev.name, + "device": dev, + "platform": plat, + "type": "oclgrind", + "index": grind_offset, + } + ) + self._grind_index[dev.name] = grind_offset + self._grind_names[grind_offset] = dev.name + grind_offset += 1 + + self._n_cpu = cpu_offset + self._n_gpu = 
gpu_offset + self._n_grind = grind_offset + if self._n_gpu == 0: + self._default_gpu = -1 + else: + self._default_gpu = 0 + if self._n_cpu == 0: + self._default_cpu = -1 + else: + self._default_cpu = 0 + if self._n_grind == 0: + self._default_grind = -1 + else: + self._default_grind = 0 + + # Create contexts, queues and memory managers. We split these + # lists of devices in case we want to configure things differently + # in the future. + for d in self._cpus: + d["context"] = cl.Context( + devices=[d["device"]], + properties=[(cl.context_properties.PLATFORM, d["platform"])], + ) + # Default queue + d["queue"] = cl.CommandQueue(d["context"], d["device"]) + d["allocator"] = cl.tools.ImmediateAllocator( + d["queue"], mem_flags=cl.mem_flags.READ_WRITE + ) + d["mempool"] = cl.tools.MemoryPool(d["allocator"]) + for d in self._gpus: + d["context"] = cl.Context( + devices=[d["device"]], + properties=[(cl.context_properties.PLATFORM, d["platform"])], + ) + # Default queue + d["queue"] = cl.CommandQueue(d["context"], d["device"]) + d["allocator"] = cl.tools.ImmediateAllocator( + d["queue"], mem_flags=cl.mem_flags.READ_WRITE + ) + d["mempool"] = cl.tools.MemoryPool(d["allocator"]) + for d in self._grind: + d["context"] = cl.Context( + devices=[d["device"]], + properties=[(cl.context_properties.PLATFORM, d["platform"])], + ) + # Default queue + d["queue"] = cl.CommandQueue(d["context"], d["device"]) + d["allocator"] = cl.tools.ImmediateAllocator( + d["queue"], mem_flags=cl.mem_flags.READ_WRITE + ) + d["mempool"] = cl.tools.MemoryPool(d["allocator"]) + + # Indexed by program name and then device + self._programs = dict() + + # Indexed by host buffer address and then device + self._buffers = dict() + + # User override of default device type + if "TOAST_OPENCL_DEFAULT" in os.environ: + self._default_dev_type = os.environ["TOAST_OPENCL_DEFAULT"] + if self._default_dev_type not in ["CPU", "GPU", "OCLGRIND"]: + msg = "Unknown user specified default device type " + msg += 
"'{self._default_dev_type}'" + raise RuntimeError(msg) + self._default_dev_type = self._default_dev_type.lower() + else: + if self._n_gpu > 0: + # We have some GPUs, that is probably the intended default + self._default_dev_type = "gpu" + else: + self._default_dev_type = "cpu" + + self._initialized = True + + # Create "Null" buffers. Often we need to pass an optional device + # array to kernels, and if they are not going to be used we still want + # to pass a valid pointer. So we pre-create some fake buffers for that + # purpose. + self._mem_null_host = dict() + self._mem_null_dev = dict() + for dt in [ + np.dtype(np.uint8), + np.dtype(np.int8), + np.dtype(np.uint16), + np.dtype(np.int16), + np.dtype(np.uint32), + np.dtype(np.int32), + np.dtype(np.uint64), + np.dtype(np.int64), + np.dtype(np.float32), + np.dtype(np.float64), + ]: + self._mem_null_host[dt] = np.zeros(1, dtype=dt) + self._mem_null_dev[dt] = dict() + for dev_type in ["cpu", "gpu", "oclgrind"]: + self._mem_null_dev[dt][dev_type] = dict() + for idx, d in enumerate(self._cpus): + self._mem_null_dev[dt]["cpu"][idx] = self.mem_to_device( + self._mem_null_host[dt], + name=f"NULL_{dt}", + device_type="cpu", + device_index=idx, + async_=False, + ) + for idx, d in enumerate(self._gpus): + self._mem_null_dev[dt]["gpu"][idx] = self.mem_to_device( + self._mem_null_host[dt], + name=f"NULL_{dt}", + device_type="gpu", + device_index=idx, + async_=False, + ) + for idx, d in enumerate(self._grind): + self._mem_null_dev[dt]["oclgrind"][idx] = self.mem_to_device( + self._mem_null_host[dt], + name=f"NULL_{dt}", + device_type="oclgrind", + device_index=idx, + async_=False, + ) + + def __del__(self): + # Free buffers + if hasattr(self, "_buffers"): + for haddr, devs in self._buffers.items(): + for dname, (dbuf, dsize, bname) in devs.items(): + dbuf.finish() + del dbuf + devs.clear() + self._buffers.clear() + + # Free kernels and programs + if hasattr(self, "_programs"): + self._programs.clear() + + # Free memory pools, 
queues and contexts + if ( + hasattr(self, "_cpus") + and hasattr(self, "_gpus") + and hasattr(self, "_grind") + ): + for devs in [self._cpus, self._gpus, self._grind]: + for d in devs: + if "mempool" in d: + d["mempool"].free_held() + d["mempool"].stop_holding() + del d["mempool"] + if "allocator" in d: + del d["allocator"] + if "queue" in d: + d["queue"].flush() + d["queue"].finish() + del d["queue"] + + def get_device(self, device_type=None, device_index=None, device_name=None): + """Lookup a device by type and either index or name. + + If the name and index are not specified, the default device of the + selected type is returned. + + Args: + device_type (str): "cpu" or "gpu" or "oclgrind" + device_index (int): The index within either the CPU or GPU list. + device_name (str): The specific (and usually very long) device name. + + Returns: + (dict): The selected device properties. + + """ + if device_name is not None and device_index is not None: + msg = "At most, one of device_name or device_index may be specified" + raise RuntimeError(msg) + if device_type is None: + device_type = self._default_dev_type + if device_type == "cpu": + if device_name is not None: + if device_name not in self._cpu_index: + msg = f"CPU device '{device_name}' does not exist" + raise RuntimeError(msg) + device_index = self._cpu_index[device_name] + else: + if device_index is None: + device_index = self._default_cpu + return self._cpus[device_index] + elif device_type == "gpu": + if device_name is not None: + if device_name not in self._gpu_index: + msg = f"GPU device '{device_name}' does not exist" + raise RuntimeError(msg) + device_index = self._gpu_index[device_name] + else: + if device_index is None: + device_index = self._default_gpu + return self._gpus[device_index] + elif device_type == "oclgrind": + if device_name is not None: + if device_name not in self._grind_index: + msg = f"OCLGRIND device '{device_name}' does not exist" + raise RuntimeError(msg) + device_index = 
self._grind_index[device_name] + else: + if device_index is None: + device_index = self._default_grind + return self._grind[device_index] + else: + msg = f"Unknown device type '{device_type}'" + raise RuntimeError(msg) + + @property + def n_cpu(self): + return self._n_cpu + + @property + def n_gpu(self): + return self._n_gpu + + @property + def n_oclgrind(self): + return self._n_grind + + @property + def default_device_type(self): + return self._default_dev_type + + @property + def default_gpu_index(self): + return self._default_gpu + + @property + def default_cpu_index(self): + return self._default_cpu + + @property + def default_oclgrind_index(self): + return self._default_grind + + def info(self): + """Print information about the general status of the OpenCL layer.""" + env = Environment.get() + level = env.log_level() + + msg = "" + for idev, dev in enumerate(self._cpus): + msg += f"{self.log_prefix} CPU {idev} ({dev['name']})\n" + msg += f"{self.log_prefix} platform {dev['platform']}\n" + for idev, dev in enumerate(self._gpus): + msg += f"{self.log_prefix} GPU {idev} ({dev['name']})\n" + msg += f"{self.log_prefix} platform {dev['platform']}\n" + for idev, dev in enumerate(self._grind): + msg += f"{self.log_prefix} OCLGRIND {idev} ({dev['name']})\n" + msg += f"{self.log_prefix} platform {dev['platform']}\n" + msg += f"{self.log_prefix} Default CPU = {self._default_cpu}\n" + msg += f"{self.log_prefix} Default GPU = {self._default_gpu}\n" + msg += f"{self.log_prefix} Default OCLGRIND = {self._default_grind}\n" + print(msg, flush=True) + + def set_default_gpu(self, device_name=None, device_index=None): + if self._n_gpu == 0: + msg = "Cannot set default GPU, none are detected!" 
+ raise RuntimeError(msg) + dev = self.get_device( + device_name=device_name, device_index=device_index, device_type="gpu" + ) + self._default_gpu = dev["index"] + + def set_default_cpu(self, device_name=None, device_index=None): + if self._n_cpu == 0: + msg = "Cannot set default CPU, none are detected!" + raise RuntimeError(msg) + dev = self.get_device( + device_name=device_name, device_index=device_index, device_type="cpu" + ) + self._default_cpu = dev["index"] + + def set_default_oclgrind(self, device_name=None, device_index=None): + if self._n_grind == 0: + msg = "Cannot set default OCLGRIND, none are detected!" + raise RuntimeError(msg) + dev = self.get_device( + device_name=device_name, device_index=device_index, device_type="oclgrind" + ) + self._default_grind = dev["index"] + + def assign_default_devices(self, node_procs, node_rank, disabled): + if self.n_cpu > 0: + # Our platforms support CPUs + proc_per_cpu = node_procs // self._n_cpu + if self._n_cpu * proc_per_cpu < node_procs: + proc_per_cpu += 1 + target = node_rank // proc_per_cpu + self.set_default_cpu(device_index=target) + if self.n_gpu > 0: + # Our platforms support GPUs + proc_per_gpu = node_procs // self._n_gpu + if self._n_gpu * proc_per_gpu < node_procs: + proc_per_gpu += 1 + target = node_rank // proc_per_gpu + self.set_default_gpu(device_index=target) + env = Environment.get() + env.set_acc(self.n_gpu, proc_per_gpu, target) + if self.n_oclgrind > 0: + # Our platforms support OCLGRIND fake devices + proc_per_grind = node_procs // self._n_grind + if self._n_grind * proc_per_grind < node_procs: + proc_per_grind += 1 + target = node_rank // proc_per_grind + self.set_default_oclgrind(device_index=target) + + def build_program(self, program_name, source): + """Load the program source and build for all devices.""" + with open(source, "r") as f: + clstr = f.read() + self._programs[program_name] = dict() + for d in self._cpus: + dname = d["name"] + self._programs[program_name][dname] = 
cl.Program(d["context"], clstr) + self._programs[program_name][dname].build() + build_status = self._programs[program_name][dname].get_build_info( + d["device"], cl.program_build_info.STATUS + ) + build_log = self._programs[program_name][dname].get_build_info( + d["device"], cl.program_build_info.LOG + ) + for d in self._gpus: + dname = d["name"] + self._programs[program_name][dname] = cl.Program(d["context"], clstr) + try: + self._programs[program_name][dname].build() + except: + pass + build_status = self._programs[program_name][dname].get_build_info( + d["device"], cl.program_build_info.STATUS + ) + build_log = self._programs[program_name][dname].get_build_info( + d["device"], cl.program_build_info.LOG + ) + for d in self._grind: + dname = d["name"] + self._programs[program_name][dname] = cl.Program(d["context"], clstr) + self._programs[program_name][dname].build() + build_status = self._programs[program_name][dname].get_build_info( + d["device"], cl.program_build_info.STATUS + ) + build_log = self._programs[program_name][dname].get_build_info( + d["device"], cl.program_build_info.LOG + ) + + def context(self, device_name=None, device_index=None, device_type=None): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device( + device_name=device_name, device_index=device_index, device_type=device_type + ) + return dev["context"] + + def queue(self, device_name=None, device_index=None, device_type=None): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device( + device_name=device_name, device_index=device_index, device_type=device_type + ) + return dev["queue"] + + def has_kernel( + self, + program_name, + kernel_name, + device_name=None, + device_index=None, + device_type=None, + ): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device( + device_name=device_name, device_index=device_index, device_type=device_type + ) + if program_name not in self._programs: + return 
False + devname = dev["name"] + if devname not in self._programs[program_name]: + return False + prog = self._programs[program_name][devname] + if not hasattr(prog, kernel_name): + return False + return True + + def kernel( + self, + program_name, + kernel_name, + device_name=None, + device_index=None, + device_type=None, + ): + if device_type is None: + device_type = self._default_dev_type + exists = self.has_kernel( + program_name, + kernel_name, + device_name=device_name, + device_index=device_index, + device_type=device_type, + ) + if not exists: + msg = f"kernel {kernel_name} in program {program_name} does not exist" + raise RuntimeError(msg) + dev = self.get_device( + device_name=device_name, device_index=device_index, device_type=device_type + ) + devname = dev["name"] + prog = self._programs[program_name][devname] + return getattr(prog, kernel_name) + + def get_or_build_kernel( + self, + program_name, + kernel_name, + device_name=None, + device_index=None, + device_type=None, + source=None, + ): + if device_type is None: + device_type = self._default_dev_type + if not self.has_kernel( + program_name, + kernel_name, + device_name=device_name, + device_index=device_index, + device_type=device_type, + ): + if source is None: + # Look for a source file named after the program in our + # common directory + source = find_source(os.path.dirname(__file__), f"{program_name}.cl") + self.build_program(program_name, source) + return self.kernel( + program_name, + kernel_name, + device_name=device_name, + device_index=device_index, + device_type=device_type, + ) + + def _mem_host_props(self, host_buffer): + if hasattr(host_buffer, "address"): + # This is a C-allocated buffer + haddr = host_buffer.address() + hsize = host_buffer.size() + # These are always 1D + hshape = (hsize,) + htype = aligned_to_dtype(host_buffer) + harray = host_buffer.array() + else: + haddr = host_buffer.ctypes.data + hsize = host_buffer.size + hshape = host_buffer.shape + htype = 
host_buffer.dtype + harray = host_buffer + return haddr, hsize, hshape, htype, harray + + def _mem_check_props(self, haddr, hsize, hname, dbuf, dsize, dname, err=False): + log = Logger.get() + # if dname != hname: + # msg = f"Host buffer {haddr} has device name '{dname}' not '{hname}'" + # if err: + # raise RuntimeError(msg) + # else: + # log.warning(msg) + if dsize != hsize: + msg = f"Host buffer {haddr} has device size {dsize} not {hsize}" + if err: + raise RuntimeError(msg) + else: + log.warning(msg) + + def _mem_check_exists(self, haddr, hsize, name, dev, err=False, warn=False): + log = Logger.get() + if haddr not in self._buffers: + msg = f"Host buffer {haddr} (name={name}) not registered" + if err: + raise RuntimeError(msg) + elif warn: + log.warning(msg) + return (None, None) + dprops = self._buffers[haddr] + dev_name = dev["name"] + if dev_name not in dprops: + msg = f"Host buffer {haddr} (name={name}) not on device '{dev_name}'" + if err: + raise RuntimeError(msg) + elif warn: + log.warning(msg) + return (None, None) + dbuf, dsize, dname = dprops[dev_name] + self._mem_check_props(haddr, hsize, name, dbuf, dsize, dname, err=err) + return dbuf, dsize + + def mem_create(self, host_buffer, name=None, device_type=None, device_index=None): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device(device_type=device_type, device_index=device_index) + haddr, hsize, hshape, htype, harray = self._mem_host_props(host_buffer) + if haddr not in self._buffers: + self._buffers[haddr] = dict() + dprops = self._buffers[haddr] + dev_name = dev["name"] + if dev_name in dprops: + dbuf, dsize, dname = dprops[dev_name] + msg = f"Host buffer {haddr} already registered on device " + msg += f"'{dev_name}' at address {dbuf.base_data} with size {dsize}" + msg += f" and name {dname}" + raise RuntimeError(msg) + dbuf = Array( + dev["queue"], + hshape, + htype, + order="C", + allocator=dev["mempool"], + ) + self._buffers[haddr][dev_name] = ( + dbuf, + 
hsize, + name, + ) + return dbuf + + def mem_to_device( + self, + host_buffer, + name=None, + device_type=None, + device_index=None, + async_=False, + ): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device(device_type=device_type, device_index=device_index) + haddr, hsize, hshape, htype, harray = self._mem_host_props(host_buffer) + if haddr not in self._buffers: + self._buffers[haddr] = dict() + dprops = self._buffers[haddr] + dev_name = dev["name"] + if dev_name in dprops: + # Check that the existing array has correct properties + dbuf, dsize, dname = dprops[dev_name] + self._mem_check_props(haddr, hsize, name, dbuf, dsize, dname, err=True) + else: + dbuf = Array( + dev["queue"], + hshape, + htype, + order="C", + allocator=dev["mempool"], + ) + self._buffers[haddr][dev_name] = ( + dbuf, + hsize, + name, + ) + dbuf.set(harray, async_=async_) + return dbuf + + def mem_remove(self, host_buffer, name=None, device_type=None, device_index=None): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device(device_type=device_type, device_index=device_index) + haddr, hsize, hshape, htype, harray = self._mem_host_props(host_buffer) + dbuf, dsize = self._mem_check_exists(haddr, hsize, name, dev, warn=True) + if dbuf is None: + return + dbuf.finish() + del dbuf + del self._buffers[haddr][dev["name"]] + if len(self._buffers[haddr]) == 0: + del self._buffers[haddr] + + def mem_update_device( + self, + host_buffer, + name=None, + device_type=None, + device_index=None, + async_=False, + ): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device(device_type=device_type, device_index=device_index) + haddr, hsize, hshape, htype, harray = self._mem_host_props(host_buffer) + dbuf, dsize = self._mem_check_exists(haddr, hsize, name, dev, err=True) + dbuf.set(harray, async_=async_) + return dbuf + + def mem_update_host( + self, + host_buffer, + name=None, + device_type=None, + 
device_index=None, + async_=False, + ): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device(device_type=device_type, device_index=device_index) + haddr, hsize, hshape, htype, harray = self._mem_host_props(host_buffer) + dbuf, dsize = self._mem_check_exists(haddr, hsize, name, dev, err=True) + if async_: + _, ev = dbuf.get_async(ary=harray) + else: + dbuf.get(ary=harray) + ev = None + return ev + + def mem_reset( + self, + host_buffer, + name=None, + device_type=None, + device_index=None, + wait_for=None, + ): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device(device_type=device_type, device_index=device_index) + haddr, hsize, hshape, htype, harray = self._mem_host_props(host_buffer) + dbuf, dsize = self._mem_check_exists(haddr, hsize, name, dev, err=True) + dbuf.fill(0, wait_for=wait_for) + + def mem_present(self, host_buffer, name=None, device_type=None, device_index=None): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device(device_type=device_type, device_index=device_index) + haddr, hsize, hshape, htype, harray = self._mem_host_props(host_buffer) + dbuf, dsize = self._mem_check_exists(haddr, hsize, name, dev) + if dbuf is None: + return False + else: + return True + + def mem(self, host_buffer, name=None, device_type=None, device_index=None): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device(device_type=device_type, device_index=device_index) + haddr, hsize, hshape, htype, harray = self._mem_host_props(host_buffer) + dbuf, dsize = self._mem_check_exists(haddr, hsize, name, dev, err=True) + return dbuf + + def mem_null(self, host_buffer, device_type=None, device_index=None): + if device_type is None: + device_type = self._default_dev_type + dev = self.get_device(device_type=device_type, device_index=device_index) + haddr, hsize, hshape, htype, harray = self._mem_host_props(host_buffer) + return 
self._mem_null_dev[htype][dev["type"]][dev["index"]] + + def mem_dump(self): + msg = "" + prefix = f"{self.log_prefix}MEM:" + for haddr, devs in self._buffers.items(): + for dname, (dbuf, dsize, bname) in devs.items(): + if dname in self._cpu_index: + dstr = f"CPU[{self._cpu_index[dname]}]" + elif dname in self._gpu_index: + dstr = f"GPU[{self._gpu_index[dname]}]" + elif dname in self._grind_index: + dstr = f"OCLGRIND[{self._grind_index[dname]}]" + msg += f"{prefix} H[{haddr}] -> {dstr} size={dsize} ({bname})\n" + print(msg, flush=True) diff --git a/src/toast/opencl/utils.py b/src/toast/opencl/utils.py new file mode 100644 index 000000000..0d1ab30b2 --- /dev/null +++ b/src/toast/opencl/utils.py @@ -0,0 +1,215 @@ +# Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. +"""OpenCL utilities. +""" +import os +import numpy as np + +from ..utils import Logger + +# FIXME: remove this once we no longer need a compiled extension. +from .._libtoast import ( + AlignedF32, + AlignedF64, + AlignedI8, + AlignedI16, + AlignedI32, + AlignedI64, + AlignedU8, + AlignedU16, + AlignedU32, + AlignedU64, +) + +# Check if pyopencl is importable + +try: + import pyopencl as cl + + have_opencl = True +except Exception: + # There could be several possible exceptions... + have_opencl = False + log = Logger.get() + msg = "pyopencl is not importable- disabling" + log.debug(msg) + + +def find_source(calling_file, rel_path): + """Locate an OpenCL source file relative to another file. + + Args: + calling_file (str): The __FILE__ of the caller. + rel_path (str): The relative path. 
+ + Returns: + (str): The path to the file (or None) + + """ + path = os.path.join(calling_file, rel_path) + apath = os.path.abspath(path) + if not os.path.isfile(apath): + msg = f"OpenCL source file '{apath}' does not exist" + raise RuntimeError(msg) + return apath + + +def aligned_to_dtype(aligned): + """Return the dtype for an internal Aligned class. + + Args: + aligned (class): The Aligned class. + + Returns: + (dtype): The equivalent dtype. + + """ + log = Logger.get() + if isinstance(aligned, AlignedI8): + return np.dtype(np.int8) + elif isinstance(aligned, AlignedU8): + return np.dtype(np.uint8) + elif isinstance(aligned, AlignedI16): + return np.dtype(np.int16) + elif isinstance(aligned, AlignedU16): + return np.dtype(np.uint16) + elif isinstance(aligned, AlignedI32): + return np.dtype(np.int32) + elif isinstance(aligned, AlignedU32): + return np.dtype(np.uint32) + elif isinstance(aligned, AlignedI64): + return np.dtype(np.int64) + elif isinstance(aligned, AlignedU64): + return np.dtype(np.uint64) + elif isinstance(aligned, AlignedF32): + return np.dtype(np.float32) + elif isinstance(aligned, AlignedF64): + return np.dtype(np.float64) + else: + msg = f"Unsupported Aligned data class '{aligned}'" + log.error(msg) + raise ValueError(msg) + + +def get_kernel_deps(state, obs_name): + """Extract kernel wait_for events for the current observation. + + Args: + state (dict): The state dictionary + obs_name (str): The observation name + + Returns: + (list): The list of events to wait on. 
+ + """ + if obs_name is None: + msg = "Observation name cannot be None" + raise RuntimeError(msg) + if state is None: + # No dependencies + # print(f"GET {obs_name}: state is None", flush=True) + return list() + if not isinstance(state, dict): + msg = "kernel state should be a dictionary keyed on observation name" + raise RuntimeError(msg) + if obs_name not in state: + # No dependencies for this observation + # print(f"GET {obs_name}: obs_name not in state", flush=True) + return list() + # Return events + return state[obs_name] + + +def clear_kernel_deps(state, obs_name): + """Clear kernel events for a given observation. + + This should be done **after** the events are completed. + + Args: + state (dict): The state dictionary + obs_name (str): The observation name + + Returns: + None + + """ + if obs_name is None: + msg = "Observation name cannot be None" + raise RuntimeError(msg) + if state is None: + # No dependencies + return + if not isinstance(state, dict): + msg = "kernel state should be a dictionary keyed on observation name" + raise RuntimeError(msg) + if obs_name not in state: + # No dependencies for this observation + return + # Clear + state[obs_name].clear() + + +def replace_kernel_deps(state, obs_name, events): + """Clear the events for a given observation and replace. + + The event list for the specified observation is created if needed. + + Args: + state (dict): The state dictionary + obs_name (str): The observation name + events (Event, list): pyopencl event or list of events. 
+ + Returns: + None + + """ + if obs_name is None: + msg = "Observation name cannot be None" + raise RuntimeError(msg) + if state is None: + msg = "State dictionary cannot be None" + raise RuntimeError(msg) + if not isinstance(state, dict): + msg = "kernel state should be a dictionary keyed on observation name" + raise RuntimeError(msg) + if obs_name in state: + state[obs_name].clear() + else: + state[obs_name] = list() + if events is None: + return + if isinstance(events, list): + state[obs_name].extend(events) + else: + state[obs_name].append(events) + + +def add_kernel_deps(state, obs_name, events): + """Append event(s) to the current observation state. + + The event list for the specified observation is created if needed. + + Args: + state (dict): The state dictionary + obs_name (str): The observation name + events (Event, list): pyopencl event or list of events. + + Returns: + None + + """ + if obs_name is None: + msg = "Observation name cannot be None" + raise RuntimeError(msg) + if state is None: + msg = "State dictionary cannot be None" + raise RuntimeError(msg) + if obs_name not in state: + state[obs_name] = list() + if events is None: + return + if isinstance(events, list): + state[obs_name].extend(events) + else: + state[obs_name].append(events) diff --git a/src/toast/ops/mapmaker_binning.py b/src/toast/ops/mapmaker_binning.py index 6b4107edd..e6a0ccee3 100644 --- a/src/toast/ops/mapmaker_binning.py +++ b/src/toast/ops/mapmaker_binning.py @@ -276,6 +276,9 @@ def _exec(self, data, detectors=None, **kwargs): log.verbose(" BinMap running pipeline") pipe_out = accum.apply(data, detectors=detectors) + good_pix = data[self.binned].data != 0 + print(f"Binned zmap = {data[self.binned].data[good_pix]}") + # print("Binned zmap = ", data[self.binned].data) # Optionally, store the noise-weighted map diff --git a/src/toast/ops/mapmaker_solve.py b/src/toast/ops/mapmaker_solve.py index f75dfcc1f..17b137716 100644 --- a/src/toast/ops/mapmaker_solve.py +++ 
b/src/toast/ops/mapmaker_solve.py @@ -127,6 +127,9 @@ def _exec(self, data, detectors=None, **kwargs): self.binning.det_data_units = self.det_data_units self.binning.apply(data, detectors=detectors) + good_pix = data[self.binning.binned].data != 0 + print(f"RHS binned = {data[self.binning.binned].data[good_pix]}") + log.debug_rank("MapMaker RHS binned map finished in", comm=comm, timer=timer) # Build a pipeline for the projection and template matrix application. @@ -220,6 +223,10 @@ def _exec(self, data, detectors=None, **kwargs): "MapMaker RHS begin cleanup temporary detector data", comm=comm ) + for tkey in data[self.template_matrix.amplitudes].keys(): + good_amps = data[self.template_matrix.amplitudes][tkey].local != 0 + print(f"RHS {tkey}: {data[self.template_matrix.amplitudes][tkey].local[good_amps]}") + # Clean up our temp buffer delete_temp = Delete(detdata=[det_temp]) delete_temp.apply(data) @@ -370,10 +377,37 @@ def _exec(self, data, detectors=None, **kwargs): timer.start() log.debug_rank("MapMaker LHS begin project amplitudes and binning", comm=comm) + # for tkey in data[self.template_matrix.amplitudes].keys(): + # good_amps = data[self.template_matrix.amplitudes][tkey].local != 0 + # print(f"LHS IN {tkey}: {data[self.template_matrix.amplitudes][tkey].local[good_amps]}") + self.template_matrix.transpose = False self.template_matrix.det_data = self.det_temp self.template_matrix.det_data_units = self.det_data_units + # Pre-create the temporary LHS detector data if it does not exist + for ob in data.obs: + if self.binning.full_pointing: + if detectors is None: + dets = ob.local_detectors + else: + dets = detectors + exists = ob.detdata.ensure( + self.template_matrix.det_data, + detectors=dets, + create_units=self.template_matrix.det_data_units, + ) + else: + if detectors is None: + first_det = ob.local_detectors[0] + else: + first_det = detectors[0] + exists = ob.detdata.ensure( + self.template_matrix.det_data, + detectors=[first_det], + 
create_units=self.template_matrix.det_data_units, + ) + self.binning.det_data = self.det_temp self.binning.det_data_units = self.det_data_units @@ -497,6 +531,10 @@ def _exec(self, data, detectors=None, **kwargs): proj_pipe.apply(data, detectors=detectors) + # for tkey in data[self.out].keys(): + # good_amps = data[self.out][tkey].local != 0 + # print(f"LHS OUT {tkey}: {data[self.out][tkey].local[good_amps]}") + log.debug_rank( "MapMaker LHS map scan and amplitude accumulate finished in", comm=comm, diff --git a/src/toast/ops/mapmaker_templates.py b/src/toast/ops/mapmaker_templates.py index adaef5a8a..626969cb7 100644 --- a/src/toast/ops/mapmaker_templates.py +++ b/src/toast/ops/mapmaker_templates.py @@ -8,7 +8,7 @@ import traitlets from astropy import units as u -from ..accelerator import ImplementationType +from ..accelerator import ImplementationType, accel_wait from ..mpi import MPI from ..observation import default_values as defaults from ..pixels import PixelData @@ -287,7 +287,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): # The output template amplitudes exist on host, but are not yet # staged to the accelerator. 
data[self.amplitudes].accel_create(self.name) - data[self.amplitudes].accel_update_device() + events = data[self.amplitudes].accel_update_device() + accel_wait(events) for d in all_dets: for tmpl in self.templates: @@ -353,14 +354,16 @@ def _finalize(self, data, use_accel=None, **kwargs): if self.transpose: # move amplitudes to host as sync is CPU only if use_accel: - data[self.amplitudes].accel_update_host() + events = data[self.amplitudes].accel_update_host() + accel_wait(events) # Synchronize the result for tmpl in self.templates: if tmpl.enabled: data[self.amplitudes][tmpl.name].sync() # move amplitudes back to GPU as it is NOT finalize's job to move data to host if use_accel: - data[self.amplitudes].accel_update_device() + events = data[self.amplitudes].accel_update_device() + accel_wait(events) # Set the internal initialization to False, so that we are ready to process # completely new data sets. return @@ -899,6 +902,12 @@ def _get_pixel_covariance(self, solve_pixels, solve_weights): solver_cov.apply(self._data, detectors=self._detectors) + good_hits = self._data[self.solver_hits_name].data != 0 + print(f"Solve hits = {self._data[self.solver_hits_name].data[good_hits]}") + + good_pix = self._data[self.solver_cov_name].data != 0 + print(f"Solve covariance = {self._data[self.solver_cov_name].data[good_pix]}") + self._memreport.prefix = "After constructing covariance and hits" self._memreport.apply(self._data) diff --git a/src/toast/ops/mapmaker_utils/CMakeLists.txt b/src/toast/ops/mapmaker_utils/CMakeLists.txt index 0c44b10b0..48548973d 100644 --- a/src/toast/ops/mapmaker_utils/CMakeLists.txt +++ b/src/toast/ops/mapmaker_utils/CMakeLists.txt @@ -7,5 +7,7 @@ install(FILES kernels.py kernels_numpy.py kernels_jax.py + kernels_opencl.py + kernels_opencl.cl DESTINATION ${PYTHON_SITE}/toast/ops/mapmaker_utils ) diff --git a/src/toast/ops/mapmaker_utils/kernels.py b/src/toast/ops/mapmaker_utils/kernels.py index abbe5a1d5..5a00c2a24 100644 --- 
a/src/toast/ops/mapmaker_utils/kernels.py +++ b/src/toast/ops/mapmaker_utils/kernels.py @@ -8,7 +8,7 @@ from ..._libtoast import build_noise_weighted as libtoast_build_noise_weighted from ..._libtoast import cov_accum_diag_hits as libtoast_cov_accum_diag_hits from ..._libtoast import cov_accum_diag_invnpp as libtoast_cov_accum_diag_invnpp -from ...accelerator import ImplementationType, kernel, use_accel_jax +from ...accelerator import ImplementationType, kernel, use_accel_jax, use_accel_opencl from .kernels_numpy import ( build_noise_weighted_numpy, cov_accum_diag_hits_numpy, @@ -22,19 +22,22 @@ cov_accum_diag_invnpp_jax, ) +if use_accel_opencl: + from .kernels_opencl import build_noise_weighted_opencl + @kernel(impl=ImplementationType.COMPILED, name="build_noise_weighted") -def build_noise_weighted_compiled(*args, use_accel=False): +def build_noise_weighted_compiled(*args, use_accel=False, **kwargs): return libtoast_build_noise_weighted(*args, use_accel) @kernel(impl=ImplementationType.COMPILED, name="cov_accum_diag_hits") -def cov_accum_diag_hits_compiled(*args, use_accel=False): +def cov_accum_diag_hits_compiled(*args, use_accel=False, **kwargs): return libtoast_cov_accum_diag_hits(*args, use_accel) @kernel(impl=ImplementationType.COMPILED, name="cov_accum_diag_invnpp") -def cov_accum_diag_invnpp_compiled(*args, use_accel=False): +def cov_accum_diag_invnpp_compiled(*args, use_accel=False, **kwargs): return libtoast_cov_accum_diag_invnpp(*args, use_accel) @@ -56,6 +59,7 @@ def build_noise_weighted( shared_flags, shared_flag_mask, use_accel=False, + **kwargs, ): """Kernel for accumulating the noise weighted map. @@ -115,6 +119,7 @@ def cov_accum_diag_invnpp( scale, invnpp, use_accel=False, + **kwargs, ): """Kernel for accumulating the inverse diagonal pixel noise covariance. @@ -158,6 +163,7 @@ def cov_accum_diag_hits( subpix, hits, use_accel=False, + **kwargs, ): """Kernel for accumulating the hits. 
diff --git a/src/toast/ops/mapmaker_utils/kernels_jax.py b/src/toast/ops/mapmaker_utils/kernels_jax.py index 64d7a9e1d..ecd191d83 100644 --- a/src/toast/ops/mapmaker_utils/kernels_jax.py +++ b/src/toast/ops/mapmaker_utils/kernels_jax.py @@ -11,7 +11,7 @@ from ...jax.mutableArray import MutableJaxArray from ...utils import AlignedF64, AlignedI64, Logger -# ---------------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------- # build_noise_weighted @@ -209,6 +209,7 @@ def build_noise_weighted_jax( shared_flags, shared_flag_mask, use_accel, + **kwargs, ): """ Args: @@ -302,7 +303,9 @@ def cov_accum_diag_hits_inner(nsubpix, submap, subpix, hits): @kernel(impl=ImplementationType.JAX, name="cov_accum_diag_hits") -def cov_accum_diag_hits_jax(nsub, nsubpix, nnz, submap, subpix, hits, use_accel): +def cov_accum_diag_hits_jax( + nsub, nsubpix, nnz, submap, subpix, hits, use_accel, **kwargs +): """ Accumulate hit map. This uses a pointing matrix to accumulate the local pieces of the hit map. @@ -389,7 +392,7 @@ def cov_accum_diag_invnpp_inner(nsubpix, nnz, submap, subpix, weights, scale, in @kernel(impl=ImplementationType.JAX, name="cov_accum_diag_invnpp") def cov_accum_diag_invnpp_jax( - nsub, nsubpix, nnz, submap, subpix, weights, scale, invnpp, use_accel + nsub, nsubpix, nnz, submap, subpix, weights, scale, invnpp, use_accel, **kwargs ): """ Accumulate block diagonal noise covariance. 
diff --git a/src/toast/ops/mapmaker_utils/kernels_numpy.py b/src/toast/ops/mapmaker_utils/kernels_numpy.py index e8be04252..3f22af722 100644 --- a/src/toast/ops/mapmaker_utils/kernels_numpy.py +++ b/src/toast/ops/mapmaker_utils/kernels_numpy.py @@ -25,6 +25,7 @@ def build_noise_weighted_numpy( shared_flags, shared_flag_mask, use_accel, + **kwargs, ): """ Args: @@ -94,7 +95,9 @@ def build_noise_weighted_numpy( @kernel(impl=ImplementationType.NUMPY, name="cov_accum_diag_hits") -def cov_accum_diag_hits_numpy(nsub, nsubpix, nnz, submap, subpix, hits, use_accel): +def cov_accum_diag_hits_numpy( + nsub, nsubpix, nnz, submap, subpix, hits, use_accel, **kwargs +): """ Accumulate hit map. This uses a pointing matrix to accumulate the local pieces of the hit map. @@ -123,7 +126,7 @@ def cov_accum_diag_hits_numpy(nsub, nsubpix, nnz, submap, subpix, hits, use_acce @kernel(impl=ImplementationType.NUMPY, name="cov_accum_diag_invnpp") def cov_accum_diag_invnpp_numpy( - nsub, nsubpix, nnz, submap, subpix, weights, scale, invnpp, use_accel + nsub, nsubpix, nnz, submap, subpix, weights, scale, invnpp, use_accel, **kwargs ): """ Accumulate block diagonal noise covariance. diff --git a/src/toast/ops/mapmaker_utils/kernels_opencl.cl b/src/toast/ops/mapmaker_utils/kernels_opencl.cl new file mode 100644 index 000000000..5e8cced59 --- /dev/null +++ b/src/toast/ops/mapmaker_utils/kernels_opencl.cl @@ -0,0 +1,146 @@ +// Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +// All rights reserved. Use of this source code is governed by +// a BSD-style license that can be found in the LICENSE file. 
+ +// https://github.com/KhronosGroup/OpenCL-Docs/blob/main/extensions/cl_ext_float_atomics.asciidoc + +// https://stackoverflow.com/questions/73838432/looking-for-examples-for-atomic-fetch-add-for-float32-in-opencl-3-0 + +// #pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable + +// #if __OPENCL_C_VERSION__ >= CL_VERSION_3_0 +// #pragma OPENCL EXTENSION cl_ext_float_atomics : enable +// #pragma OPENCL EXTENSION cl_ext_double_atomics : enable +// #define atomic_add_double(a,b) atomic_fetch_add((volatile atomic_float *)(a),(b)) + +// atomic_fetch_add(volatile __global A *object, M operand) +// #else + +// inline float atomic_add_double(volatile __global float* address, const float value) { +// float old = value, orig; +// while ((old = atomic_xchg(address, (orig = atomic_xchg(address, 0.0f)) + old)) != 0.0f); +// return orig; +// } +// #endif + +// Based on: +// https://streamhpc.com/blog/2016-02-09/atomic-operations-for-floats-in-opencl-improved/ + +void __attribute__((always_inline)) atomic_add_float( + volatile global float* addr, const float val +) { + union { + uint u32; + float f32; + } next, expected, current; + current.f32 = *addr; + do { + expected.f32 = current.f32; + next.f32 = expected.f32 + val; + current.u32 = atomic_cmpxchg( + (volatile global uint*)addr, expected.u32, next.u32 + ); + } while(current.u32 != expected.u32); +} + +#ifdef cl_khr_int64_base_atomics +#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable +void __attribute__((always_inline)) atomic_add_double( + volatile global double* addr, const double val +) { + union { + ulong u64; + double f64; + } next, expected, current; + current.f64 = *addr; + do { + expected.f64 = current.f64; + next.f64 = expected.f64 + val; + current.u64 = atom_cmpxchg( + (volatile global ulong*)addr, expected.u64, next.u64 + ); + } while(current.u64 != expected.u64); +} +#endif + + + +// Kernels + +__kernel void build_noise_weighted( + int n_det, + long n_sample, + long first_sample, + __global int 
const * pixels_index, + __global long const * pixels, + __global int const * weights_index, + __global double const * weights, + __global int const * det_data_index, + __global double const * det_data, + __global int const * det_flags_index, + __global unsigned char const * det_flags, + __global unsigned char const * shared_flags, + __global double * zmap, + __global long const * global2local, + __global double const * det_scale, + long nnz, + long npix_submap, + unsigned char det_flag_mask, + unsigned char shared_flag_mask, + unsigned char use_shared_flags, + unsigned char use_det_flags +) { + // Get the global index of this work element + int idet = get_global_id(0); + long isamp = first_sample + get_global_id(1); + + int d_indx = det_data_index[idet]; + int p_indx = pixels_index[idet]; + int w_indx = weights_index[idet]; + int f_indx = det_flags_index[idet]; + + size_t woff = (w_indx * nnz * n_sample) + nnz * isamp; + size_t poff = p_indx * n_sample + isamp; + size_t doff = d_indx * n_sample + isamp; + size_t foff = f_indx * n_sample + isamp; + + long global_submap; + long local_submap_pix; + long local_submap; + long local_pix; + long zoff; + + double scaled_data; + double map_val; + + unsigned char det_check = 0; + if (use_det_flags) { + det_check = det_flags[foff] & det_flag_mask; + } + unsigned char shared_check = 0; + if (use_shared_flags) { + shared_check = shared_flags[isamp] & shared_flag_mask; + } + + if ( + (pixels[poff] >= 0) && + (det_check == 0) && + (shared_check == 0) + ) { + global_submap = (long)(pixels[poff] / npix_submap); + local_submap_pix = pixels[poff] - global_submap * npix_submap; + local_submap = global2local[global_submap]; + local_pix = local_submap * npix_submap + local_submap_pix; + zoff = nnz * local_pix; + + scaled_data = det_scale[d_indx] * det_data[doff]; + for (long i = 0; i < nnz; i++) { + atomic_add_double(&(zmap[zoff + i]), scaled_data * weights[woff + i]); + } + } + + return; +} + + + diff --git 
a/src/toast/ops/mapmaker_utils/kernels_opencl.py b/src/toast/ops/mapmaker_utils/kernels_opencl.py new file mode 100644 index 000000000..14a64234e --- /dev/null +++ b/src/toast/ops/mapmaker_utils/kernels_opencl.py @@ -0,0 +1,142 @@ +# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. + +import os +import numpy as np +import pyopencl as cl + +from ...accelerator import ImplementationType, kernel +from ...opencl import ( + find_source, + OpenCL, + add_kernel_deps, + get_kernel_deps, + replace_kernel_deps, + clear_kernel_deps, +) + + +@kernel(impl=ImplementationType.OPENCL, name="build_noise_weighted") +def build_noise_weighted_opencl( + global2local, + zmap, + pixels_index, + pixels, + weight_index, + weights, + det_data_index, + det_data, + flag_index, + det_flags, + det_scale, + det_flag_mask, + intervals, + shared_flags, + shared_flag_mask, + use_accel=False, + obs_name=None, + state=None, + **kwargs, +): + program_file = find_source(os.path.dirname(__file__), "kernels_opencl.cl") + + if len(shared_flags) == det_data.shape[1]: + use_shared_flags = np.uint8(1) + else: + use_shared_flags = np.uint8(0) + + if len(det_flags) == det_data.shape[1]: + use_det_flags = np.uint8(1) + else: + use_det_flags = np.uint8(0) + + ocl = OpenCL() + queue = ocl.queue() + devtype = ocl.default_device_type + + kernel = ocl.get_or_build_kernel( + "mapmaker_utils", + "build_noise_weighted", + device_type=devtype, + source=program_file, + ) + + # Get our device arrays + dev_global2local = ocl.mem(global2local, device_type=devtype) + dev_zmap = ocl.mem(zmap, device_type=devtype) + dev_pixels = ocl.mem(pixels, device_type=devtype) + dev_weights = ocl.mem(weights, device_type=devtype) + dev_det_data = ocl.mem(det_data, device_type=devtype) + dev_det_flags = ocl.mem(det_flags, device_type=devtype) + dev_shared_flags = ocl.mem(shared_flags, 
device_type=devtype) + + # Allocate temporary device arrays + + dev_pixels_index = ocl.mem_to_device(pixels_index, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_pixels_index.events) + + dev_weight_index = ocl.mem_to_device(weight_index, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_weight_index.events) + + dev_det_data_index = ocl.mem_to_device( + det_data_index, device_type=devtype, async_=True + ) + add_kernel_deps(state, obs_name, dev_det_data_index.events) + + dev_flag_index = ocl.mem_to_device(flag_index, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_flag_index.events) + + dev_det_scale = ocl.mem_to_device(det_scale, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_det_scale.events) + + # All of the events that our kernels depend on + wait_for = get_kernel_deps(state, obs_name) + print(f"BLDNSEW: {obs_name} got wait_for = {wait_for}", flush=True) + print(f"BLDNSEW: {obs_name} pixels={dev_pixels}, weights={dev_weights}, zmap={dev_zmap}", flush=True) + + n_det = len(det_data_index) + n_samp = weights.shape[1] + nnz = weights.shape[2] + n_pix_submap = zmap.shape[1] + + for intr in intervals: + first_sample = intr.first + n_intr = intr.last - intr.first + 1 + ev = kernel( + ocl.queue(device_type=devtype), + (n_det, n_intr), + None, + np.int32(n_det), + np.int64(n_samp), + np.int64(first_sample), + dev_pixels_index.data, + dev_pixels.data, + dev_weight_index.data, + dev_weights.data, + dev_det_data_index.data, + dev_det_data.data, + dev_flag_index.data, + dev_det_flags.data, + dev_shared_flags.data, + dev_zmap.data, + dev_global2local.data, + dev_det_scale.data, + np.int64(nnz), + np.int64(n_pix_submap), + np.uint8(det_flag_mask), + np.uint8(shared_flag_mask), + use_shared_flags, + use_det_flags, + wait_for=wait_for, + ) + wait_for = [ev] + clear_kernel_deps(state, obs_name) + add_kernel_deps(state, obs_name, wait_for) + + # Free temporaries + 
ocl.mem_remove(pixels_index, device_type=devtype) + ocl.mem_remove(weight_index, device_type=devtype) + ocl.mem_remove(det_data_index, device_type=devtype) + ocl.mem_remove(flag_index, device_type=devtype) + ocl.mem_remove(det_scale, device_type=devtype) diff --git a/src/toast/ops/mapmaker_utils/mapmaker_utils.py b/src/toast/ops/mapmaker_utils/mapmaker_utils.py index 42f38e90f..0fbf6bcf3 100644 --- a/src/toast/ops/mapmaker_utils/mapmaker_utils.py +++ b/src/toast/ops/mapmaker_utils/mapmaker_utils.py @@ -6,7 +6,7 @@ import traitlets from astropy import units as u -from ...accelerator import ImplementationType +from ...accelerator import ImplementationType, accel_wait from ...covariance import covariance_invert from ...mpi import MPI from ...observation import default_values as defaults @@ -201,6 +201,7 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): hits.raw, impl=implementation, use_accel=use_accel, + **kwargs, ) return @@ -510,6 +511,7 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): invcov.raw, impl=implementation, use_accel=use_accel, + **kwargs, ) return @@ -760,7 +762,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): f"Operator {self.name} zmap: copy host to device", comm=data.comm.comm_group, ) - zmap.accel_update_device() + events = zmap.accel_update_device() + accel_wait(events) else: log.verbose_rank( f"Operator {self.name} zmap: already in use on device", @@ -773,7 +776,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): f"Operator {self.name} zmap: update host from device", comm=data.comm.comm_group, ) - zmap.accel_update_host() + events = zmap.accel_update_host() + accel_wait(events) # # DEBUGGING # restore_dev = False @@ -844,6 +848,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): else: shared_flag_data = np.zeros(1, dtype=np.uint8) + print(f"BLD {ob.name} {self.pixels}={ob.detdata[self.pixels].data}, {self.weights}={ob.detdata[self.weights].data}, 
{self.det_data}={ob.detdata[self.det_data].data}, {self.view}={ob.intervals[self.view].data}") + build_noise_weighted( zmap.distribution.global_submap_to_local, zmap.data, @@ -862,6 +868,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): self.shared_flag_mask, impl=implementation, use_accel=use_accel, + obs_name=ob.name, + **kwargs, ) # # DEBUGGING @@ -895,7 +903,8 @@ def _finalize(self, data, use_accel=None, **kwargs): comm=data.comm.comm_group, ) restore_device = True - data[self.zmap].accel_update_host() + events = data[self.zmap].accel_update_host() + accel_wait(events) if self.sync_type == "alltoallv": data[self.zmap].sync_alltoallv() else: @@ -923,7 +932,8 @@ def _finalize(self, data, use_accel=None, **kwargs): f"Operator {self.name} finalize calling zmap update device", comm=data.comm.comm_group, ) - data[self.zmap].accel_update_device() + events = data[self.zmap].accel_update_device() + accel_wait(events) return def _requires(self): @@ -954,6 +964,7 @@ def _implementations(self): ImplementationType.COMPILED, ImplementationType.NUMPY, ImplementationType.JAX, + ImplementationType.OPENCL, ] def _supports_accel(self): @@ -1153,7 +1164,7 @@ def _exec(self, data, detectors=None, **kwargs): pixel_pointing=self.pixel_pointing, save_pointing=self.save_pointing, ) - pix_dist.apply(data) + pix_dist.apply(data, **kwargs) # Check if map domain products exist and are consistent. 
The hits # and inverse covariance accumulation operators support multiple @@ -1227,7 +1238,7 @@ def _exec(self, data, detectors=None, **kwargs): build_invcov, ] - pipe_out = accum.apply(data, detectors=detectors) + pipe_out = accum.apply(data, detectors=detectors, **kwargs) # Optionally, store the inverse covariance if self.inverse_covariance is not None: diff --git a/src/toast/ops/noise_weight/CMakeLists.txt b/src/toast/ops/noise_weight/CMakeLists.txt index 6dcd12e08..5961e448c 100644 --- a/src/toast/ops/noise_weight/CMakeLists.txt +++ b/src/toast/ops/noise_weight/CMakeLists.txt @@ -7,5 +7,7 @@ install(FILES kernels.py kernels_numpy.py kernels_jax.py + kernels_opencl.py + kernels_opencl.cl DESTINATION ${PYTHON_SITE}/toast/ops/noise_weight ) diff --git a/src/toast/ops/noise_weight/kernels.py b/src/toast/ops/noise_weight/kernels.py index 682e88377..fb69fcf34 100644 --- a/src/toast/ops/noise_weight/kernels.py +++ b/src/toast/ops/noise_weight/kernels.py @@ -5,15 +5,18 @@ import numpy as np from ..._libtoast import noise_weight as libtoast_noise_weight -from ...accelerator import ImplementationType, kernel, use_accel_jax +from ...accelerator import ImplementationType, kernel, use_accel_jax, use_accel_opencl from .kernels_numpy import noise_weight_numpy if use_accel_jax: from .kernels_jax import noise_weight_jax +if use_accel_opencl: + from .kernels_opencl import noise_weight_opencl + @kernel(impl=ImplementationType.COMPILED, name="noise_weight") -def noise_weight_compiled(*args, use_accel=False): +def noise_weight_compiled(*args, use_accel=False, **kwargs): return libtoast_noise_weight(*args, use_accel) @@ -24,6 +27,7 @@ def noise_weight( intervals, detector_weights, use_accel=False, + **kwargs, ): """Kernel for applying noise weights to detector timestreams. 
diff --git a/src/toast/ops/noise_weight/kernels_jax.py b/src/toast/ops/noise_weight/kernels_jax.py index 893c04491..dd6fb38a4 100644 --- a/src/toast/ops/noise_weight/kernels_jax.py +++ b/src/toast/ops/noise_weight/kernels_jax.py @@ -97,7 +97,9 @@ def noise_weight_interval( @kernel(impl=ImplementationType.JAX, name="noise_weight") -def noise_weight_jax(det_data, det_data_index, intervals, detector_weights, use_accel): +def noise_weight_jax( + det_data, det_data_index, intervals, detector_weights, use_accel, **kwargs +): """ multiplies det_data by the weighs in detector_weights diff --git a/src/toast/ops/noise_weight/kernels_numpy.py b/src/toast/ops/noise_weight/kernels_numpy.py index 88949a064..78e96f054 100644 --- a/src/toast/ops/noise_weight/kernels_numpy.py +++ b/src/toast/ops/noise_weight/kernels_numpy.py @@ -14,6 +14,7 @@ def noise_weight_numpy( intervals, detector_weights, use_accel, + **kwargs, ): # Iterates over detectors and intervals n_det = det_data_index.size diff --git a/src/toast/ops/noise_weight/kernels_opencl.cl b/src/toast/ops/noise_weight/kernels_opencl.cl new file mode 100644 index 000000000..3dae1e1f1 --- /dev/null +++ b/src/toast/ops/noise_weight/kernels_opencl.cl @@ -0,0 +1,21 @@ +// Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +// All rights reserved. Use of this source code is governed by +// a BSD-style license that can be found in the LICENSE file. 
+ +__kernel void noise_weight( + int n_det, + long n_sample, + long first_sample, + __global double const * weights, + __global int const * det_data_index, + __global double * det_data +) { + // Get the global index of this work element + int idet = get_global_id(0); + long isamp = first_sample + get_global_id(1); + + int didx = det_data_index[idet]; + + det_data[didx * n_sample + isamp] *= weights[idet]; +} + diff --git a/src/toast/ops/noise_weight/kernels_opencl.py b/src/toast/ops/noise_weight/kernels_opencl.py new file mode 100644 index 000000000..e50c6925e --- /dev/null +++ b/src/toast/ops/noise_weight/kernels_opencl.py @@ -0,0 +1,84 @@ +# Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. + +import os +import numpy as np +import pyopencl as cl + +from ...accelerator import ImplementationType, kernel +from ...opencl import ( + find_source, + OpenCL, + add_kernel_deps, + get_kernel_deps, + clear_kernel_deps, +) + + +@kernel(impl=ImplementationType.OPENCL, name="noise_weight") +def noise_weight_opencl( + det_data, + det_data_index, + intervals, + detector_weights, + use_accel=False, + obs_name=None, + state=None, + **kwargs, +): + program_file = find_source(os.path.dirname(__file__), "kernels_opencl.cl") + + ocl = OpenCL() + queue = ocl.queue() + devtype = ocl.default_device_type + + # Get our kernel + noise_weight = ocl.get_or_build_kernel( + "noise_weight", + "noise_weight", + device_type=devtype, + source=program_file, + ) + + # Get our device arrays + dev_det_data = ocl.mem(det_data, device_type=devtype) + + # Allocate temporary arrays and copy to device + dev_det_data_index = ocl.mem_to_device( + det_data_index, device_type=devtype, async_=True + ) + add_kernel_deps(state, obs_name, dev_det_data_index.events) + dev_det_weights = ocl.mem_to_device( + detector_weights, device_type=devtype, async_=True + ) + 
add_kernel_deps(state, obs_name, dev_det_weights.events) + + # All of the events that our kernels depend on + wait_for = get_kernel_deps(state, obs_name) + # print(f"NSEWEIGHT: {obs_name} got wait_for = {wait_for}", flush=True) + + n_det = len(det_data_index) + n_samp = det_data.shape[1] + for intr in intervals: + first_sample = intr.first + n_intr = intr.last - intr.first + 1 + ev = noise_weight( + ocl.queue(device_type=devtype), + (n_det, n_intr), + None, + np.int32(n_det), + np.int64(n_samp), + np.int64(first_sample), + dev_det_weights.data, + dev_det_data_index.data, + dev_det_data.data, + wait_for=wait_for, + ) + wait_for = [ev] + clear_kernel_deps(state, obs_name) + add_kernel_deps(state, obs_name, wait_for) + + # Free temporaries + ocl.mem_remove(det_data_index, device_type=devtype) + ocl.mem_remove(detector_weights, device_type=devtype) diff --git a/src/toast/ops/noise_weight/noise_weight.py b/src/toast/ops/noise_weight/noise_weight.py index 9b08c0cd1..6020913a5 100644 --- a/src/toast/ops/noise_weight/noise_weight.py +++ b/src/toast/ops/noise_weight/noise_weight.py @@ -128,6 +128,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): detector_weights, impl=implementation, use_accel=use_accel, + obs_name=ob.name, + **kwargs, ) # Update the units of the output @@ -157,6 +159,7 @@ def _implementations(self): ImplementationType.COMPILED, ImplementationType.NUMPY, ImplementationType.JAX, + ImplementationType.OPENCL, ] def _supports_accel(self): diff --git a/src/toast/ops/pipeline.py b/src/toast/ops/pipeline.py index a4f942c35..d58b850a0 100644 --- a/src/toast/ops/pipeline.py +++ b/src/toast/ops/pipeline.py @@ -4,7 +4,22 @@ import traitlets -from ..accelerator import ImplementationType, accel_enabled, use_hybrid_pipelines +from ..accelerator import ( + ImplementationType, + accel_enabled, + accel_wait, + use_hybrid_pipelines, + use_accel_jax, + use_accel_omp, + use_accel_opencl, +) + +if use_accel_opencl: + import pyopencl as cl + from ..opencl 
import ( + OpenCL, add_kernel_deps, get_kernel_deps, replace_kernel_deps, clear_kernel_deps, + ) + from ..data import Data from ..timing import function_timer from ..traits import Bool, Int, List, trait_docs @@ -138,6 +153,10 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): if det_mask is None: det_mask = 0 + # Allow different accelerator implementations to accept and return + # state information. + kernel_state = dict() + if len(self.detector_sets) == 1 and self.detector_sets[0] == "ALL": # Run the operators with all detectors at once for op in self.operators: @@ -146,6 +165,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): data, detectors=None, pipe_accel=pipe_accel, + state=kernel_state, + **kwargs, ) elif len(self.detector_sets) == 1 and self.detector_sets[0] == "SINGLE": # Get superset of detectors across all observations @@ -168,6 +189,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): data, detectors=dets, pipe_accel=pipe_accel, + state=kernel_state, + **kwargs, ) else: # We have explicit detector sets @@ -190,8 +213,25 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): data, detectors=selected_set, pipe_accel=pipe_accel, + state=kernel_state, + **kwargs, ) + # Post-processing depending on accelerator technique in use. Here we + # have access to the kernel state dictionary that has been updated + # along the way by our chain of operators. 
+ if use_accel_opencl: + # Wait for all kernels to complete + for obname, events in kernel_state.items(): + for ev in events: + # print(f"DBG: wait for ev {ev}", flush=True) + ev.wait() + elif use_accel_jax: + pass + elif use_accel_omp: + pass + del kernel_state + # notify user of device->host data movements introduced by CPU operators if (self._unstaged_data is not None) and (not self._unstaged_data.is_empty()): cpu_ops = { @@ -204,7 +244,9 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): ) @function_timer - def _exec_operator(self, op, data, detectors, pipe_accel): + def _exec_operator( + self, op, data, detectors, pipe_accel=False, state=None, **kwargs + ): """Runs an operator, dealing with data movement to/from device if needed.""" # For this operator, we run on the accelerator if the pipeline has some # operators enabled and if this operator supports it. @@ -217,7 +259,14 @@ def _exec_operator(self, op, data, detectors, pipe_accel): msg += " with ALL dets" log.verbose(msg) - # Ensures data is where it should be for this operator + print(f"PIPE {op.name} begin state = {state}", flush=True) + + # Get the queue for this process on the default device + if use_accel_opencl: + ocl = OpenCL() + queue = ocl.queue() + + # Ensure data is where it should be for this operator if self._staged_data is not None: requires = SetDict(op.requires()) if run_accel: @@ -230,7 +279,19 @@ def _exec_operator(self, op, data, detectors, pipe_accel): msg += f"Staging objects {requires}" log.verbose(msg) data.accel_create(requires) - data.accel_update_device(requires) + + stage_events = data.accel_update_device(requires) + print(f"PIPE stage data events = {stage_events}") + if use_accel_opencl: + # Create a marker event that depends on the per-observation + # data transfer. 
+ for obs_name, obs_ev in stage_events.items(): + add_kernel_deps( + state, + obs_name, + cl.enqueue_marker(queue, obs_ev), + ) + # Update our record of data on device self._unstaged_data -= requires self._staged_data |= requires @@ -248,7 +309,23 @@ def _exec_operator(self, op, data, detectors, pipe_accel): msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} " msg += f"Un-staging objects {requires}" log.verbose(msg) - data.accel_update_host(requires) + + unstage_events = data.accel_update_host(requires) + print(f"PIPE unstage data events = {unstage_events}") + if use_accel_opencl: + # Create a marker event that depends on the per-observation + # data transfer. + for obs_name, obs_ev in unstage_events.items(): + add_kernel_deps( + state, + obs_name, + cl.enqueue_marker(queue, obs_ev), + ) + # We have to wait for all the data to get back to the host + # before running our operator on the host. + for obs_name, obs_evs in state.items(): + accel_wait(obs_evs) + # Update our record of data on the device self._staged_data -= requires self._unstaged_data |= requires # union @@ -258,8 +335,9 @@ def _exec_operator(self, op, data, detectors, pipe_accel): msg = f"Proc ({data.comm.world_rank}, {data.comm.group_rank}) {self} " msg += f"AFTER staged = {self._staged_data}, unstaged = {self._unstaged_data}" log.verbose(msg) - # runs operator - op.exec(data, detectors=detectors, use_accel=run_accel) + + print(f"PIPE {op.name} state = {state}", flush=True) + op.exec(data, detectors=detectors, use_accel=run_accel, state=state, **kwargs) @function_timer def _finalize(self, data, use_accel=None, **kwargs): @@ -292,7 +370,9 @@ def _finalize(self, data, use_accel=None, **kwargs): provides = SetDict(self.provides()) provides &= self._staged_data # intersection log.verbose(f"{pstr} {self} copying out accel data outputs: {provides}") - data.accel_update_host(provides) + unstage_events = data.accel_update_host(provides) + for obs_name, obs_ev in unstage_events.items(): + 
accel_wait(obs_ev) # deleting all data on device log.verbose(f"{pstr} {self} deleting accel data: {self._staged_data}") data.accel_delete(self._staged_data) @@ -303,10 +383,10 @@ def _pipe_accel(self, use_accel): if (use_accel is None) and accel_enabled(): - # Only allows hybrid pipelines if the environement variable and pipeline agree to it - # (they both default to True) + # Only allows hybrid pipelines if the environment variable and pipeline + # support it (they both default to True) use_hybrid = self.use_hybrid and use_hybrid_pipelines - # can we run this pipelines on accelerator + # can we run this pipeline on accelerator supports_accel = ( self._supports_accel_partial() if use_hybrid else self._supports_accel() ) @@ -358,6 +438,7 @@ def _implementations(self): ImplementationType.COMPILED, ImplementationType.NUMPY, ImplementationType.JAX, + ImplementationType.OPENCL, } for op in self.operators: implementations.intersection_update(op.implementations()) diff --git a/src/toast/ops/pixels_healpix/CMakeLists.txt b/src/toast/ops/pixels_healpix/CMakeLists.txt index 953ab69c4..d2d31314c 100644 --- a/src/toast/ops/pixels_healpix/CMakeLists.txt +++ b/src/toast/ops/pixels_healpix/CMakeLists.txt @@ -7,5 +7,7 @@ install(FILES kernels.py kernels_numpy.py kernels_jax.py + kernels_opencl.py + kernels_opencl.cl DESTINATION ${PYTHON_SITE}/toast/ops/pixels_healpix ) diff --git a/src/toast/ops/pixels_healpix/kernels.py b/src/toast/ops/pixels_healpix/kernels.py index f17a26896..b09692832 100644 --- a/src/toast/ops/pixels_healpix/kernels.py +++ b/src/toast/ops/pixels_healpix/kernels.py @@ -7,16 +7,50 @@ from ...
import qarray as qa from ..._libtoast import pixels_healpix as libtoast_pixels_healpix -from ...accelerator import ImplementationType, kernel, use_accel_jax +from ...accelerator import ImplementationType, kernel, use_accel_jax, use_accel_opencl from .kernels_numpy import pixels_healpix_numpy if use_accel_jax: from .kernels_jax import pixels_healpix_jax +if use_accel_opencl: + from .kernels_opencl import pixels_healpix_opencl + @kernel(impl=ImplementationType.COMPILED, name="pixels_healpix") -def pixels_healpix_compiled(*args, use_accel=False): - return libtoast_pixels_healpix(*args, use_accel) +def pixels_healpix_compiled( + quat_index, + quats, + shared_flags, + shared_flag_mask, + pixel_index, + pixels, + intervals, + hit_submaps, + n_pix_submap, + nside, + nest, + compute_submaps, + use_accel=False, + **kwargs, +): + if hit_submaps is None: + hit_submaps = np.zeros(1, dtype=np.uint8) + return libtoast_pixels_healpix( + quat_index, + quats, + shared_flags, + shared_flag_mask, + pixel_index, + pixels, + intervals, + hit_submaps, + n_pix_submap, + nside, + nest, + compute_submaps, + use_accel, + ) @kernel(impl=ImplementationType.DEFAULT) @@ -32,7 +66,9 @@ def pixels_healpix( n_pix_submap, nside, nest, + compute_submaps, use_accel=False, + **kwargs, ): """Kernel for computing healpix pixelization. @@ -46,17 +82,20 @@ def pixels_healpix( detector. pixels (array): The array of detector pixels for each sample. intervals (array): The array of sample intervals. - hit_submaps (array): Array of bytes to set to 1 if the submap is hit + hit_submaps (device array): Array of bytes to set to 1 if the submap is hit and zero if not hit. n_pix_submap (int): The number of pixels in a submap. nside (int): The Healpix NSIDE of the pixelization. nest (bool): If true, use NESTED ordering, else use RING. + compute_submaps (bool): If True, compute the hit submaps. use_accel (bool): Whether to use the accelerator for this call (if supported). 
Returns: None """ + if hit_submaps is None: + hit_submaps = np.zeros(1, dtype=np.uint8) return libtoast_pixels_healpix( quat_index, quats, @@ -69,5 +108,6 @@ def pixels_healpix( n_pix_submap, nside, nest, + compute_submaps, use_accel, ) diff --git a/src/toast/ops/pixels_healpix/kernels_jax.py b/src/toast/ops/pixels_healpix/kernels_jax.py index 8b06f5506..a0209dc88 100644 --- a/src/toast/ops/pixels_healpix/kernels_jax.py +++ b/src/toast/ops/pixels_healpix/kernels_jax.py @@ -14,7 +14,7 @@ def pixels_healpix_inner( - quats, use_flags, flag, flag_mask, hit_submaps, n_pix_submap, hpix, nest + quats, use_flags, flag, flag_mask, hit_submaps, n_pix_submap, hpix, nest, compute_submaps, ): """ Compute the healpix pixel indices for the detectors. @@ -41,19 +41,30 @@ def pixels_healpix_inner( pixel = healpix.zphi2ring(hpix, phi, region, z, rtz) # compute sub map - sub_map = pixel // n_pix_submap + if compute_submaps: + sub_map = pixel // n_pix_submap # applies the flags if use_flags: is_flagged = (flag & flag_mask) != 0 pixel = jnp.where(is_flagged, -1, pixel) - hit_submap = jnp.where(is_flagged, hit_submaps[sub_map], 1) + if compute_submaps: + hit_submap = jnp.where(is_flagged, hit_submaps[sub_map], 1) + else: + hit_submap = 0 else: - hit_submap = 1 + if compute_submaps: + hit_submap = 1 + else: + hit_submap = 0 return pixel, sub_map, hit_submap +# FIXME: Do we need two different imaps here for the case of +# with / without hit submap calculation? 
+ + # maps over samples and detectors pixels_healpix_inner = imap( pixels_healpix_inner, @@ -66,6 +77,7 @@ def pixels_healpix_inner( "n_pix_submap": int, "hpix": healpix.HPIX_JAX, "nest": bool, + "compute_submaps": bool, "interval_starts": ["n_intervals"], "interval_ends": ["n_intervals"], "intervals_max_length": int, @@ -90,6 +102,7 @@ def pixels_healpix_interval( n_pix_submap, nside, nest, + compute_submaps, interval_starts, interval_ends, intervals_max_length, @@ -122,12 +135,6 @@ def pixels_healpix_interval( # create hpix object hpix = healpix.HPIX_JAX(nside) - # extract indexes - quats_indexed = quats[quat_index, :, :] - pixels_indexed = pixels[pixel_index, :] - dummy_sub_map = jnp.zeros_like(pixels_indexed) - dummy_hit_submaps = hit_submaps[dummy_sub_map] - # should we use flags? use_flags = flag_mask != 0 n_samp = pixels.shape[1] @@ -135,8 +142,18 @@ def pixels_healpix_interval( flags = jnp.empty(shape=(n_samp,)) use_flags = False + # extract indexes + quats_indexed = quats[quat_index, :, :] + pixels_indexed = pixels[pixel_index, :] + dummy_sub_map = jnp.zeros_like(pixels_indexed) + if hit_submaps is None: + dummy_hit_submaps = None + outputs = (pixels_indexed, dummy_sub_map) + else: + dummy_hit_submaps = hit_submaps[dummy_sub_map] + outputs = (pixels_indexed, dummy_sub_map, dummy_hit_submaps) + # does the computation - outputs = (pixels_indexed, dummy_sub_map, dummy_hit_submaps) new_pixels_indexed, sub_map, new_hit_submaps = pixels_healpix_inner( quats_indexed, use_flags, @@ -146,6 +163,7 @@ def pixels_healpix_interval( n_pix_submap, hpix, nest, + compute_submaps, interval_starts, interval_ends, intervals_max_length, @@ -154,7 +172,8 @@ def pixels_healpix_interval( # updates results and returns pixels = pixels.at[pixel_index, :].set(new_pixels_indexed) - hit_submaps = hit_submaps.at[sub_map].set(new_hit_submaps) + if hit_submaps is not None: + hit_submaps = hit_submaps.at[sub_map].set(new_hit_submaps) return pixels, hit_submaps @@ -166,6 +185,7 @@ def 
pixels_healpix_interval( "n_pix_submap", "nside", "nest", + "compute_submaps", "intervals_max_length", ], donate_argnums=[5, 6], @@ -185,7 +205,9 @@ def pixels_healpix_jax( n_pix_submap, nside, nest, + compute_submaps, use_accel, + **kwargs, ): """ Compute the healpix pixel indices for the detectors. @@ -230,6 +252,7 @@ def pixels_healpix_jax( n_pix_submap, nside, nest, + compute_submaps, intervals.first, intervals.last, intervals_max_length, diff --git a/src/toast/ops/pixels_healpix/kernels_numpy.py b/src/toast/ops/pixels_healpix/kernels_numpy.py index 6861ed910..e8c830855 100644 --- a/src/toast/ops/pixels_healpix/kernels_numpy.py +++ b/src/toast/ops/pixels_healpix/kernels_numpy.py @@ -22,7 +22,9 @@ def pixels_healpix_numpy( n_pix_submap, nside, nest, + compute_submaps, use_accel=False, + **kwargs, ): zaxis = np.array([0, 0, 1], dtype=np.float64) for idet in range(len(quat_index)): @@ -44,6 +46,7 @@ def pixels_healpix_numpy( else: good = (shared_flags[samples] & shared_flag_mask) == 0 bad = np.logical_not(good) - sub_maps = pixels[pidx][samples][good] // n_pix_submap - hit_submaps[sub_maps] = 1 + if compute_submaps: + sub_maps = pixels[pidx][samples][good] // n_pix_submap + hit_submaps[sub_maps] = 1 pixels[pidx][samples][bad] = -1 diff --git a/src/toast/ops/pixels_healpix/kernels_opencl.cl b/src/toast/ops/pixels_healpix/kernels_opencl.cl new file mode 100644 index 000000000..22e1a9548 --- /dev/null +++ b/src/toast/ops/pixels_healpix/kernels_opencl.cl @@ -0,0 +1,637 @@ +// Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +// All rights reserved. Use of this source code is governed by +// a BSD-style license that can be found in the LICENSE file. + +// 2/3 +#define TWOTHIRDS 0.66666666666666666667 + +// Double precision machine epsilon +#define EPS 2.220446e-16 + +// Internal functions working with a single sample. We use hpix_* prefix for these. 
+ +double hpix_fmod(double in, double mod) { + double div = in / mod; + return mod * (div - (double)((long)div)); +} + +void hpix_qa_rotate( + double const * q_in, + double const * v_in, + double * v_out +) { + // The input quaternion has already been normalized on the host. + double xw = q_in[3] * q_in[0]; + double yw = q_in[3] * q_in[1]; + double zw = q_in[3] * q_in[2]; + double x2 = -q_in[0] * q_in[0]; + double xy = q_in[0] * q_in[1]; + double xz = q_in[0] * q_in[2]; + double y2 = -q_in[1] * q_in[1]; + double yz = q_in[1] * q_in[2]; + double z2 = -q_in[2] * q_in[2]; + + v_out[0] = 2 * ((y2 + z2) * v_in[0] + (xy - zw) * v_in[1] + + (yw + xz) * v_in[2]) + v_in[0]; + + v_out[1] = 2 * ((zw + xy) * v_in[0] + (x2 + z2) * v_in[1] + + (yz - xw) * v_in[2]) + v_in[1]; + + v_out[2] = 2 * ((xz - yw) * v_in[0] + (xw + yz) * v_in[1] + + (x2 + y2) * v_in[2]) + v_in[2]; + + return; +} + +long hpix_xy2pix(__global long const * utab, long x, long y) { + return utab[x & 0xff] | (utab[(x >> 8) & 0xff] << 16) | + (utab[(x >> 16) & 0xff] << 32) | + (utab[(x >> 24) & 0xff] << 48) | + (utab[y & 0xff] << 1) | (utab[(y >> 8) & 0xff] << 17) | + (utab[(y >> 16) & 0xff] << 33) | + (utab[(y >> 24) & 0xff] << 49); +} + +void hpix_pix2xy(__global long const * ctab, long pix, long * x, long * y) { + long raw; + raw = (pix & 0x5555ull) | ((pix & 0x55550000ull) >> 15) | + ((pix & 0x555500000000ull) >> 16) | + ((pix & 0x5555000000000000ull) >> 31); + (*x) = ctab[raw & 0xff] | (ctab[(raw >> 8) & 0xff] << 4) | + (ctab[(raw >> 16) & 0xff] << 16) | + (ctab[(raw >> 24) & 0xff] << 20); + raw = ((pix & 0xaaaaull) >> 1) | ((pix & 0xaaaa0000ull) >> 16) | + ((pix & 0xaaaa00000000ull) >> 17) | + ((pix & 0xaaaa000000000000ull) >> 32); + (*y) = ctab[raw & 0xff] | (ctab[(raw >> 8) & 0xff] << 4) | + (ctab[(raw >> 16) & 0xff] << 16) | + (ctab[(raw >> 24) & 0xff] << 20); + return; +} + +void hpix_vec2zphi( + double const * vec, + double * phi, + int * region, + double * z, + double * rtz +) { + // region encodes 
BOTH the sign of Z and whether its + // absolute value is greater than 2/3. + (*z) = vec[2]; + double za = fabs((*z)); + int itemp = ((*z) > 0.0) ? 1 : -1; + (*region) = (za <= TWOTHIRDS) ? itemp : itemp + itemp; + (*rtz) = sqrt(3.0 * (1.0 - za)); + (*phi) = atan2(vec[1], vec[0]); + return; +} + +long hpix_zphi2nest( + long nside, + long factor, + __global long const * utab, + double phi, + int region, + double z, + double rtz +) { + double tol = 10.0 * EPS; + double phi_mod = hpix_fmod(phi, 2 * M_PI); + if ((phi_mod < tol) && (phi_mod > -tol)) { + phi_mod = 0.0; + } + double tt = (phi_mod >= 0.0) ? phi_mod * M_2_PI : phi_mod * M_2_PI + 4.0; + long x; + long y; + double temp1; + double temp2; + long jp; + long jm; + long ifp; + long ifm; + long face; + long ntt; + double tp; + + double dnside = (double)nside; + double halfnside = 0.5 * dnside; + double tqnside = 0.75 * dnside; + long nsideminusone = nside - 1; + + if ((region == 1) || (region == -1)) { + temp1 = halfnside + dnside * tt; + temp2 = tqnside * z; + + jp = (long)(temp1 - temp2); + jm = (long)(temp1 + temp2); + + ifp = jp >> factor; + ifm = jm >> factor; + + if (ifp == ifm) { + face = (ifp == 4) ? 
(long)4 : ifp + 4; + } else if (ifp < ifm) { + face = ifp; + } else { + face = ifm + 8; + } + + x = jm & nsideminusone; + y = nsideminusone - (jp & nsideminusone); + } else { + ntt = (long)tt; + + tp = tt - (double)ntt; + + temp1 = dnside * rtz; + + jp = (long)(tp * temp1); + jm = (long)((1.0 - tp) * temp1); + + if (jp >= nside) { + jp = nsideminusone; + } + if (jm >= nside) { + jm = nsideminusone; + } + + if (z >= 0) { + face = ntt; + x = nsideminusone - jm; + y = nsideminusone - jp; + } else { + face = ntt + 8; + x = jp; + y = jm; + } + } + long sipf = hpix_xy2pix(utab, x, y); + long pix = sipf + (face << (2 * factor)); + return pix; +} + +long hpix_zphi2ring( + long nside, + long factor, + double phi, + int region, + double z, + double rtz +) { + double tol = 10.0 * EPS; + double phi_mod = hpix_fmod(phi, 2 * M_PI); + if ((phi_mod < tol) && (phi_mod > -tol)) { + phi_mod = 0.0; + } + double tt = (phi_mod >= 0.0) ? phi_mod * M_2_PI : phi_mod * M_2_PI + 4.0; + double tp; + long longpart; + double temp1; + double temp2; + long jp; + long jm; + long ip; + long ir; + long kshift; + long pix; + + double dnside = (double)nside; + long fournside = 4 * nside; + double halfnside = 0.5 * dnside; + double tqnside = 0.75 * dnside; + long nsideplusone = nside + 1; + long ncap = 2 * (nside * nside - nside); + long npix = 12 * nside * nside; + + if ((region == 1) || (region == -1)) { + temp1 = halfnside + dnside * tt; + temp2 = tqnside * z; + + jp = (long)(temp1 - temp2); + jm = (long)(temp1 + temp2); + + ir = nsideplusone + jp - jm; + kshift = 1 - (ir & 1); + + ip = (jp + jm - nside + kshift + 1) >> 1; + ip = ip % fournside; + + pix = ncap + ((ir - 1) * fournside + ip); + } else { + tp = tt - floor(tt); + + temp1 = dnside * rtz; + + jp = (long)(tp * temp1); + jm = (long)((1.0 - tp) * temp1); + ir = jp + jm + 1; + ip = (long)(tt * (double)ir); + longpart = (long)(ip / (4 * ir)); + ip -= longpart; + + pix = (region > 0) ? 
(2 * ir * (ir - 1) + ip) + : (npix - 2 * ir * (ir + 1) + ip); + } + return pix; +} + +void hpix_ang2vec(double theta, double phi, double * vec) { + double sintheta = sin(theta); + vec[0] = sintheta * cos(phi); + vec[1] = sintheta * sin(phi); + vec[2] = cos(theta); + return; +} + +void hpix_vec2ang( + double const * vec, + double * theta, + double * phi +) { + double norm = 1.0 / sqrt( + vec[0] * vec[0] + vec[1] * vec[1] + vec[2] * vec[2] + ); + (*theta) = acos(vec[2] * norm); + unsigned char small_theta = (fabs((*theta)) <= EPS) ? 1 : 0; + unsigned char big_theta = (fabs(M_PI - (*theta)) <= EPS) ? 1 : 0; + double phitemp = atan2(vec[1], vec[0]); + (*phi) = (phitemp < 0) ? phitemp + 2 * M_PI : phitemp; + (*phi) = (small_theta || big_theta) ? 0.0 : (*phi); + return; +} + +void hpix_theta2z( + double theta, + int * region, + double * z, + double * rtz +) { + (*z) = cos(theta); + double za = fabs((*z)); + int itemp = ((*z) > 0.0) ? 1 : -1; + (*region) = (za <= TWOTHIRDS) ? itemp : itemp + itemp; + double work = 3.0 * (1.0 - za); + (*rtz) = sqrt(work); + return; +} + +long hpix_ang2nest( + long nside, + long factor, + __global long const * utab, + double theta, + double phi +) { + double z; + double rtz; + int region; + long pix; + hpix_theta2z(theta, ®ion, &z, &rtz); + return hpix_zphi2nest(nside, factor, utab, phi, region, z, rtz); +} + +long hpix_ang2ring( + long nside, + long factor, + __global long const * utab, + double theta, + double phi +) { + double z; + double rtz; + int region; + hpix_theta2z(theta, ®ion, &z, &rtz); + return hpix_zphi2ring(nside, factor, phi, region, z, rtz); +} + +long hpix_vec2nest( + long nside, + long factor, + __global long const * utab, + double const * vec +) { + double z; + double phi; + double rtz; + int region; + hpix_vec2zphi(vec, &phi, ®ion, &z, &rtz); + return hpix_zphi2nest(nside, factor, utab, phi, region, z, rtz); +} + +long hpix_vec2ring( + long nside, + long factor, + __global long const * utab, + double const * vec +) { + 
double z; + double phi; + double rtz; + int region; + hpix_vec2zphi(vec, &phi, ®ion, &z, &rtz); + return hpix_zphi2ring(nside, factor, phi, region, z, rtz); +} + +long hpix_ring2nest( + long nside, + long factor, + long npix, + long ncap, + __global long const * utab, + long ringpix +) { + long fc; + long x, y; + long nr; + long kshift; + long iring; + long iphi; + long tmp; + long ip; + long ire, irm; + long ifm, ifp; + long irt, ipt; + const long hpix_jr[] = {2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4}; + const long hpix_jp[] = {1, 3, 5, 7, 0, 2, 4, 6, 1, 3, 5, 7}; + long nestpix; + + if (ringpix < ncap) { + iring = (long)(0.5 * (1.0 + sqrt((double)(1 + 2 * ringpix)))); + iphi = (ringpix + 1) - 2 * iring * (iring - 1); + kshift = 0; + nr = iring; + fc = 0; + tmp = iphi - 1; + if (tmp >= (2 * iring)) { + fc = 2; + tmp -= 2 * iring; + } + if (tmp >= iring) { + ++fc; + } + } else if (ringpix < (npix - ncap)) { + ip = ringpix - ncap; + iring = (ip >> (factor + 2)) + nside; + iphi = (ip & (4 * nside - 1)) + 1; + kshift = (iring + nside) & 1; + nr = nside; + ire = iring - nside + 1; + irm = 2 * nside + 2 - ire; + ifm = (iphi - (ire / 2) + nside - 1) >> factor; + ifp = (iphi - (irm / 2) + nside - 1) >> factor; + if (ifp == ifm) { + // faces 4 to 7 + fc = (ifp == 4) ? 
4 : ifp + 4; + } else if (ifp < ifm) { + // (half-)faces 0 to 3 + fc = ifp; + } else { + // (half-)faces 8 to 11 + fc = ifm + 8; + } + } else { + ip = npix - ringpix; + iring = (long)(0.5 * (1.0 + sqrt((double)(2 * ip - 1)))); + iphi = 4 * iring + 1 - (ip - 2 * iring * (iring - 1)); + kshift = 0; + nr = iring; + iring = 4 * nside - iring; + fc = 8; + tmp = iphi - 1; + if (tmp >= (2 * nr)) { + fc = 10; + tmp -= 2 * nr; + } + if (tmp >= nr) { + ++fc; + } + } + irt = iring - hpix_jr[fc] * nside + 1; + ipt = 2 * iphi - hpix_jp[fc] * nr - kshift - 1; + if (ipt >= 2 * nside) { + ipt -= 8 * nside; + } + x = (ipt - irt) >> 1; + y = (-(ipt + irt)) >> 1; + nestpix = hpix_xy2pix(utab, x, y); + nestpix += (fc << (2 * factor)); + return nestpix; +} + +long hpix_nest2ring( + long nside, + long factor, + long npix, + long ncap, + __global long const * ctab, + long nestpix +) { + long fc; + long x, y; + long jr; + long jp; + long nr; + long kshift; + long n_before; + const long hpix_jr[] = {2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4}; + const long hpix_jp[] = {1, 3, 5, 7, 0, 2, 4, 6, 1, 3, 5, 7}; + long ringpix; + + fc = nestpix >> (2 * factor); + hpix_pix2xy(ctab, nestpix & (nside * nside - 1), &x, &y); + jr = (hpix_jr[fc] * nside) - x - y - 1; + if (jr < nside) { + nr = jr; + n_before = 2 * nr * (nr - 1); + kshift = 0; + } else if (jr > (3 * nside)) { + nr = 4 * nside - jr; + n_before = npix - 2 * (nr + 1) * nr; + kshift = 0; + } else { + nr = nside; + n_before = ncap + (jr - nside) * 4 * nside; + kshift = (jr - nside) & 1; + } + jp = (hpix_jp[fc] * nr + x - y + 1 + kshift) / 2; + if (jp > 4 * nside) { + jp -= 4 * nside; + } else { + if (jp < 1) { + jp += 4 * nside; + } + } + ringpix = n_before + jp - 1; + return ringpix; +} + +long hpix_degrade_nest( + long degrade_levels, + long in_pix +) { + return in_pix >> (2 * degrade_levels); +} + +long hpix_degrade_ring( + long in_nside, + long in_factor, + long in_npix, + long in_ncap, + long out_nside, + long out_factor, + long out_npix, + 
long out_ncap, + __global long const * utab, + __global long const * ctab, + long degrade_levels, + long in_pix +) { + long in_nest = hpix_ring2nest(in_nside, in_factor, in_npix, in_ncap, utab, in_pix); + long out_nest = hpix_degrade_nest(degrade_levels, in_nest); + return hpix_nest2ring(out_nside, out_factor, out_npix, out_ncap, ctab, out_nest); +} + +long hpix_upgrade_nest( + long upgrade_levels, + long in_pix +) { + return in_pix << (2 * upgrade_levels); +} + +long hpix_upgrade_ring( + long in_nside, + long in_factor, + long in_npix, + long in_ncap, + long out_nside, + long out_factor, + long out_npix, + long out_ncap, + __global long const * utab, + __global long const * ctab, + long upgrade_levels, + long in_pix +) { + long in_nest = hpix_ring2nest(in_nside, in_factor, in_npix, in_ncap, utab, in_pix); + long out_nest = hpix_upgrade_nest(upgrade_levels, in_nest); + return hpix_nest2ring(out_nside, out_factor, out_npix, out_ncap, ctab, out_nest); +} + +// Kernels + +__kernel void pixels_healpix_nest( + int n_det, + long n_sample, + long first_sample, + long nside, + long factor, + long n_pix_submap, + __global long const * utab, + __global int const * quat_index, + __global double const * quats, + __global int const * pixel_index, + __global long * pixels, + __global unsigned char * hsub, + __global unsigned char const * shared_flags, + unsigned char shared_flag_mask, + unsigned char use_flags, + unsigned char compute_submaps +) { + // Get the global index of this work element + int idet = get_global_id(0); + long isamp = first_sample + get_global_id(1); + + const double zaxis[3] = {0.0, 0.0, 1.0}; + int p_indx = pixel_index[idet]; + int q_indx = quat_index[idet]; + double dir[3]; + double z; + double rtz; + double phi; + int region; + size_t qoff = (q_indx * 4 * n_sample) + 4 * isamp; + size_t poff = p_indx * n_sample + isamp; + long sub_map; + + // Copy to private variable in order to pass to subroutines. 
+ double temp_quat[4]; + temp_quat[0] = quats[qoff]; + temp_quat[1] = quats[qoff + 1]; + temp_quat[2] = quats[qoff + 2]; + temp_quat[3] = quats[qoff + 3]; + + hpix_qa_rotate(temp_quat, zaxis, dir); + hpix_vec2zphi(dir, &phi, ®ion, &z, &rtz); + pixels[poff] = hpix_zphi2nest(nside, factor, utab, phi, region, z, rtz); + if (use_flags && ((shared_flags[isamp] & shared_flag_mask) != 0)) { + pixels[poff] = -1; + } else { + if (compute_submaps) { + sub_map = (long)(pixels[poff] / n_pix_submap); + hsub[sub_map] = 1; + } + } + + return; +} + +// Note: Although utab is not needed for ang to ring pix, +// we keep it in the argument list to simplify the calling +// code. + +__kernel void pixels_healpix_ring( + int n_det, + long n_sample, + long first_sample, + long nside, + long factor, + long n_pix_submap, + __global long const * utab, + __global int const * quat_index, + __global double const * quats, + __global int const * pixel_index, + __global long * pixels, + __global unsigned char * hsub, + __global unsigned char const * shared_flags, + unsigned char shared_flag_mask, + unsigned char use_flags, + unsigned char compute_submaps +) { + // Get the global index of this work element + int idet = get_global_id(0); + long isamp = first_sample + get_global_id(1); + + const double zaxis[3] = {0.0, 0.0, 1.0}; + int p_indx = pixel_index[idet]; + int q_indx = quat_index[idet]; + double dir[3]; + double z; + double rtz; + double phi; + int region; + size_t qoff = (q_indx * 4 * n_sample) + 4 * isamp; + size_t poff = p_indx * n_sample + isamp; + long sub_map; + + // Copy to private variable in order to pass to subroutines. 
+ double temp_quat[4]; + temp_quat[0] = quats[qoff]; + temp_quat[1] = quats[qoff + 1]; + temp_quat[2] = quats[qoff + 2]; + temp_quat[3] = quats[qoff + 3]; + + hpix_qa_rotate(temp_quat, zaxis, dir); + hpix_vec2zphi(dir, &phi, ®ion, &z, &rtz); + pixels[poff] = hpix_zphi2ring(nside, factor, phi, region, z, rtz); + if (use_flags && ((shared_flags[isamp] & shared_flag_mask) != 0)) { + pixels[poff] = -1; + } else { + if (compute_submaps) { + sub_map = (long)(pixels[poff] / n_pix_submap); + hsub[sub_map] = 1; + } + } + + return; +} diff --git a/src/toast/ops/pixels_healpix/kernels_opencl.py b/src/toast/ops/pixels_healpix/kernels_opencl.py new file mode 100644 index 000000000..e018731dc --- /dev/null +++ b/src/toast/ops/pixels_healpix/kernels_opencl.py @@ -0,0 +1,170 @@ +# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. + +import os +import healpy as hp +import numpy as np +import pyopencl as cl + +from ... 
import qarray as qa +from ...accelerator import ImplementationType, kernel +from ...opencl import ( + find_source, + OpenCL, + add_kernel_deps, + get_kernel_deps, + clear_kernel_deps, +) + + +utab = None +ctab = None + +# Initialize these constant arrays on host +if utab is None: + utab = np.zeros(0x100, dtype=np.int64) + ctab = np.zeros(0x100, dtype=np.int64) + for m in range(0x100): + utab[m] = ( + (m & 0x1) + | ((m & 0x2) << 1) + | ((m & 0x4) << 2) + | ((m & 0x8) << 3) + | ((m & 0x10) << 4) + | ((m & 0x20) << 5) + | ((m & 0x40) << 6) + | ((m & 0x80) << 7) + ) + ctab[m] = ( + (m & 0x1) + | ((m & 0x2) << 7) + | ((m & 0x4) >> 1) + | ((m & 0x8) << 6) + | ((m & 0x10) >> 2) + | ((m & 0x20) << 5) + | ((m & 0x40) >> 3) + | ((m & 0x80) << 4) + ) + + +@kernel(impl=ImplementationType.OPENCL, name="pixels_healpix") +def pixels_healpix_opencl( + quat_index, + quats, + shared_flags, + shared_flag_mask, + pixel_index, + pixels, + intervals, + hit_submaps, + n_pix_submap, + nside, + nest, + compute_submaps, + use_accel=False, + obs_name=None, + state=None, + **kwargs, +): + program_file = find_source(os.path.dirname(__file__), "kernels_opencl.cl") + + if len(shared_flags) == quats.shape[1]: + use_flags = np.uint8(1) + else: + use_flags = np.uint8(0) + + ocl = OpenCL() + queue = ocl.queue() + devtype = ocl.default_device_type + + # Make sure that our small helper arrays are staged. These are persistent + # and we do not delete them until the program terminates. 
+ if not ocl.mem_present(utab, name="utab", device_type=devtype): + dev_utab = ocl.mem_create(utab, name="utab", device_type=devtype) + ocl.mem_update_device(utab, name="utab", device_type=devtype) + else: + dev_utab = ocl.mem(utab, name="utab", device_type=devtype) + + if not ocl.mem_present(ctab, name="ctab", device_type=devtype): + dev_ctab = ocl.mem_create(ctab, name="ctab", device_type=devtype) + ocl.mem_update_device(ctab, name="ctab", device_type=devtype) + else: + dev_ctab = ocl.mem(ctab, name="ctab", device_type=devtype) + + factor = 0 + while nside != 1 << factor: + factor += 1 + + if nest: + pixels_healpix = ocl.get_or_build_kernel( + "pixels_healpix", + "pixels_healpix_nest", + device_type=devtype, + source=program_file, + ) + else: + pixels_healpix = ocl.get_or_build_kernel( + "pixels_healpix", + "pixels_healpix_ring", + device_type=devtype, + source=program_file, + ) + + # Get our device arrays + dev_pixels = ocl.mem(pixels, device_type=devtype) + dev_quats = ocl.mem(quats, device_type=devtype) + if compute_submaps: + dev_hit_submaps = ocl.mem(hit_submaps, device_type=devtype) + else: + dev_hit_submaps = ocl.mem_null(hit_submaps, device_type=devtype) + if use_flags: + dev_flags = ocl.mem(shared_flags, device_type=devtype) + else: + dev_flags = ocl.mem_null(shared_flags, device_type=devtype) + + # Allocate temporary device arrays + dev_quat_index = ocl.mem_to_device(quat_index, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_quat_index.events) + dev_pixel_index = ocl.mem_to_device(pixel_index, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_pixel_index.events) + + # All of the events that our kernels depend on + wait_for = get_kernel_deps(state, obs_name) + print(f"PIXHPX: {obs_name} got wait_for = {wait_for}", flush=True) + + n_det = len(pixel_index) + n_samp = quats.shape[1] + for intr in intervals: + first_sample = intr.first + n_intr = intr.last - intr.first + 1 + ev = pixels_healpix( + 
ocl.queue(device_type=devtype), + (n_det, n_intr), + None, + np.int32(n_det), + np.int64(n_samp), + np.int64(first_sample), + np.int64(nside), + np.int64(factor), + np.int64(n_pix_submap), + dev_utab.data, + dev_quat_index.data, + dev_quats.data, + dev_pixel_index.data, + dev_pixels.data, + dev_hit_submaps.data, + dev_flags.data, + np.uint8(shared_flag_mask), + use_flags, + np.uint8(compute_submaps), + wait_for=wait_for, + ) + wait_for = [ev] + clear_kernel_deps(state, obs_name) + print(f"PIXHPX: {obs_name} export wait_for = {wait_for}", flush=True) + add_kernel_deps(state, obs_name, wait_for) + + # Free temporaries + ocl.mem_remove(quat_index, device_type=devtype) + ocl.mem_remove(pixel_index, device_type=devtype) diff --git a/src/toast/ops/pixels_healpix/pixels_healpix.py b/src/toast/ops/pixels_healpix/pixels_healpix.py index 0cdd248dc..eb8b920c2 100644 --- a/src/toast/ops/pixels_healpix/pixels_healpix.py +++ b/src/toast/ops/pixels_healpix/pixels_healpix.py @@ -5,7 +5,15 @@ import numpy as np import traitlets -from ...accelerator import ImplementationType +from ...accelerator import ( + ImplementationType, + accel_data_create, + accel_data_delete, + accel_data_update_device, + accel_data_update_host, + accel_data_present, + accel_wait, +) from ...observation import default_values as defaults from ...pixels import PixelDistribution from ...timing import function_timer @@ -134,7 +142,7 @@ def _set_hpix(self, nside, nside_submap): self._n_pix = 12 * nside**2 self._n_pix_submap = 12 * nside_submap**2 self._n_submap = (nside // nside_submap) ** 2 - self._local_submaps = None + self._local_submaps = np.zeros(self._n_submap, dtype=np.uint8) @function_timer def _exec(self, data, detectors=None, use_accel=None, **kwargs): @@ -147,8 +155,16 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): if self.detector_pointing is None: raise RuntimeError("The detector_pointing trait must be set") - if self._local_submaps is None and self.create_dist is not None: - 
self._local_submaps = np.zeros(self._n_submap, dtype=np.uint8) + if self.create_dist is not None: + # We are computing the pixel distribution + if use_accel: + if not accel_data_present(self._local_submaps, name="hit_submaps"): + # First call, create the data on the device + accel_data_create( + self._local_submaps, + name="hit_submaps", + zero_out=True, + ) # Expand detector pointing quats_name = self.detector_pointing.quats @@ -159,7 +175,9 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): view = self.detector_pointing.view # Expand detector pointing - self.detector_pointing.apply(data, detectors=detectors, use_accel=use_accel) + self.detector_pointing.apply( + data, detectors=detectors, use_accel=use_accel, **kwargs + ) for ob in data.obs: # Get the detectors we are using for this observation @@ -205,10 +223,6 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): accel=use_accel, ) - hit_submaps = self._local_submaps - if hit_submaps is None: - hit_submaps = np.zeros(self._n_submap, dtype=np.uint8) - quat_indx = ob.detdata[quats_name].indices(dets) pix_indx = ob.detdata[self.pixels].indices(dets) @@ -223,7 +237,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): if ob.detdata[self.pixels].accel_in_use(): # The data is on the accelerator- copy back to host for # this calculation. This could eventually be a kernel. 
- ob.detdata[self.pixels].accel_update_host() + events = ob.detdata[self.pixels].accel_update_host() + accel_wait(events) restore_dev = True for det in ob.select_local_detectors( detectors, flagmask=self.detector_pointing.det_mask @@ -235,7 +250,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): // self._n_pix_submap ] = 1 if restore_dev: - ob.detdata[self.pixels].accel_update_device() + events = ob.detdata[self.pixels].accel_update_device() + accel_wait(events) if data.comm.group_rank == 0: msg = ( @@ -260,22 +276,23 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): pix_indx, ob.detdata[self.pixels].data, ob.intervals[self.view].data, - hit_submaps, + self._local_submaps, self._n_pix_submap, self.nside, self.nest, + (self.create_dist is not None), impl=implementation, use_accel=use_accel, + obs_name=ob.name, + **kwargs, ) - if self._local_submaps is not None: - self._local_submaps[:] |= hit_submaps - - return - def _finalize(self, data, use_accel=None, **kwargs): if self.create_dist is not None: - submaps = None + _, use_accel = self.select_kernels(use_accel=use_accel) + if use_accel: + # The locally hit submaps is on the device, copy back + accel_data_update_host(self._local_submaps, name="hit_submaps") if self.single_precision: submaps = np.arange(self._n_submap, dtype=np.int32)[ self._local_submaps == 1 @@ -315,6 +332,7 @@ def _implementations(self): ImplementationType.COMPILED, ImplementationType.NUMPY, ImplementationType.JAX, + ImplementationType.OPENCL, ] def _supports_accel(self): diff --git a/src/toast/ops/pixels_wcs.py b/src/toast/ops/pixels_wcs.py index bcd0cd922..cc47a19c1 100644 --- a/src/toast/ops/pixels_wcs.py +++ b/src/toast/ops/pixels_wcs.py @@ -11,6 +11,7 @@ from astropy.wcs import WCS from .. 
import qarray as qa +from ..accelerator import accel_wait from ..instrument_coords import quat_to_xieta from ..mpi import MPI from ..observation import default_values as defaults @@ -510,7 +511,8 @@ def _exec(self, data, detectors=None, **kwargs): if ob.detdata[self.pixels].accel_in_use(): # The data is on the accelerator- copy back to host for # this calculation. This could eventually be a kernel. - ob.detdata[self.pixels].accel_update_host() + events = ob.detdata[self.pixels].accel_update_host() + accel_wait(events) restore_dev = True for det in dets: for vslice in view_slices: @@ -520,7 +522,8 @@ def _exec(self, data, detectors=None, **kwargs): // self._n_pix_submap ] = 1 if restore_dev: - ob.detdata[self.pixels].accel_update_device() + events = ob.detdata[self.pixels].accel_update_device() + accel_wait(events) if data.comm.group_rank == 0: msg = ( diff --git a/src/toast/ops/pointing_detector/CMakeLists.txt b/src/toast/ops/pointing_detector/CMakeLists.txt index 09e4fc5ac..3a8f21fcf 100644 --- a/src/toast/ops/pointing_detector/CMakeLists.txt +++ b/src/toast/ops/pointing_detector/CMakeLists.txt @@ -7,5 +7,7 @@ install(FILES kernels.py kernels_numpy.py kernels_jax.py + kernels_opencl.py + kernels_opencl.cl DESTINATION ${PYTHON_SITE}/toast/ops/pointing_detector ) diff --git a/src/toast/ops/pointing_detector/kernels.py b/src/toast/ops/pointing_detector/kernels.py index dc743d42a..4804ab3ba 100644 --- a/src/toast/ops/pointing_detector/kernels.py +++ b/src/toast/ops/pointing_detector/kernels.py @@ -6,12 +6,15 @@ from ... 
import qarray as qa from ..._libtoast import pointing_detector as libtoast_pointing_detector -from ...accelerator import ImplementationType, kernel, use_accel_jax +from ...accelerator import ImplementationType, kernel, use_accel_jax, use_accel_opencl from .kernels_numpy import pointing_detector_numpy if use_accel_jax: from .kernels_jax import pointing_detector_jax +if use_accel_opencl: + from .kernels_opencl import pointing_detector_opencl + @kernel(impl=ImplementationType.DEFAULT) def pointing_detector( @@ -23,6 +26,7 @@ def pointing_detector( shared_flags, shared_flag_mask, use_accel=False, + **kwargs, ): """Kernel for computing detector quaternion pointing. @@ -54,5 +58,5 @@ def pointing_detector( @kernel(impl=ImplementationType.COMPILED, name="pointing_detector") -def pointing_detector_compiled(*args, use_accel=False): +def pointing_detector_compiled(*args, use_accel=False, **kwargs): return libtoast_pointing_detector(*args, use_accel) diff --git a/src/toast/ops/pointing_detector/kernels_jax.py b/src/toast/ops/pointing_detector/kernels_jax.py index c594c5d8b..7b1383a57 100644 --- a/src/toast/ops/pointing_detector/kernels_jax.py +++ b/src/toast/ops/pointing_detector/kernels_jax.py @@ -123,6 +123,7 @@ def pointing_detector_jax( shared_flags, shared_flag_mask, use_accel=False, + **kwargs, ): """ Args: diff --git a/src/toast/ops/pointing_detector/kernels_numpy.py b/src/toast/ops/pointing_detector/kernels_numpy.py index 60f281d45..48e205dc6 100644 --- a/src/toast/ops/pointing_detector/kernels_numpy.py +++ b/src/toast/ops/pointing_detector/kernels_numpy.py @@ -18,6 +18,7 @@ def pointing_detector_numpy( shared_flags, shared_flag_mask, use_accel=False, + **kwargs, ): for idet in range(len(quat_index)): qidx = quat_index[idet] diff --git a/src/toast/ops/pointing_detector/kernels_opencl.cl b/src/toast/ops/pointing_detector/kernels_opencl.cl new file mode 100644 index 000000000..fc1d32325 --- /dev/null +++ b/src/toast/ops/pointing_detector/kernels_opencl.cl @@ -0,0 +1,68 
@@ +// Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +// All rights reserved. Use of this source code is governed by +// a BSD-style license that can be found in the LICENSE file. + +void pointing_detector_qa_mult(double const * p, double const * q, double * r) { + r[0] = p[0] * q[3] + p[1] * q[2] - + p[2] * q[1] + p[3] * q[0]; + r[1] = -p[0] * q[2] + p[1] * q[3] + + p[2] * q[0] + p[3] * q[1]; + r[2] = p[0] * q[1] - p[1] * q[0] + + p[2] * q[3] + p[3] * q[2]; + r[3] = -p[0] * q[0] - p[1] * q[1] - + p[2] * q[2] + p[3] * q[3]; + return; +} + +__kernel void pointing_detector( + int n_det, + long n_sample, + long first_sample, + __global double const * focalplane, + __global double const * boresight, + __global int const * quat_index, + __global double * quats, + __global unsigned char const * shared_flags, + unsigned char shared_flag_mask, + unsigned char use_flags +) { + // Get the global index of this work element + int idet = get_global_id(0); + long isamp = first_sample + get_global_id(1); + + int qidx = quat_index[idet]; + + // Copy to private variables in order to pass to subroutines. 
+ double temp_bore[4]; + double temp_fp[4]; + double temp_quat[4]; + + unsigned char check = 0; + if (use_flags != 0) { + check = shared_flags[isamp] & shared_flag_mask; + } + + if (check == 0) { + temp_bore[0] = boresight[4 * isamp]; + temp_bore[1] = boresight[4 * isamp + 1]; + temp_bore[2] = boresight[4 * isamp + 2]; + temp_bore[3] = boresight[4 * isamp + 3]; + } else { + temp_bore[0] = 0.0; + temp_bore[1] = 0.0; + temp_bore[2] = 0.0; + temp_bore[3] = 1.0; + } + temp_fp[0] = focalplane[4 * idet]; + temp_fp[1] = focalplane[4 * idet + 1]; + temp_fp[2] = focalplane[4 * idet + 2]; + temp_fp[3] = focalplane[4 * idet + 3]; + + pointing_detector_qa_mult(temp_bore, temp_fp, temp_quat); + + quats[(qidx * 4 * n_sample) + 4 * isamp] = temp_quat[0]; + quats[(qidx * 4 * n_sample) + 4 * isamp + 1] = temp_quat[1]; + quats[(qidx * 4 * n_sample) + 4 * isamp + 2] = temp_quat[2]; + quats[(qidx * 4 * n_sample) + 4 * isamp + 3] = temp_quat[3]; +} + diff --git a/src/toast/ops/pointing_detector/kernels_opencl.py b/src/toast/ops/pointing_detector/kernels_opencl.py new file mode 100644 index 000000000..400b92cad --- /dev/null +++ b/src/toast/ops/pointing_detector/kernels_opencl.py @@ -0,0 +1,97 @@ +# Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. 
+ +import os +import numpy as np +import pyopencl as cl + +from ...accelerator import ImplementationType, kernel +from ...opencl import ( + find_source, + OpenCL, + add_kernel_deps, + get_kernel_deps, + clear_kernel_deps, +) + + +@kernel(impl=ImplementationType.OPENCL, name="pointing_detector") +def pointing_detector_opencl( + focalplane, + boresight, + quat_index, + quats, + intervals, + shared_flags, + shared_flag_mask, + use_accel=False, + obs_name=None, + state=None, + **kwargs, +): + program_file = find_source(os.path.dirname(__file__), "kernels_opencl.cl") + + if len(shared_flags) == len(boresight): + use_flags = np.uint8(1) + else: + use_flags = np.uint8(0) + + ocl = OpenCL() + queue = ocl.queue() + devtype = ocl.default_device_type + + # Get our kernel + pointing_detector = ocl.get_or_build_kernel( + "pointing_detector", + "pointing_detector", + device_type=devtype, + source=program_file, + ) + + # Get our device arrays + dev_boresight = ocl.mem(boresight, device_type=devtype) + dev_quats = ocl.mem(quats, device_type=devtype) + if use_flags: + dev_flags = ocl.mem(shared_flags, device_type=devtype) + else: + dev_flags = ocl.mem_null(shared_flags, device_type=devtype) + + # Allocate temporary arrays and copy to device + dev_quat_index = ocl.mem_to_device(quat_index, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_quat_index.events) + dev_fp = ocl.mem_to_device(focalplane, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_fp.events) + + # All of the events that our kernels depend on + wait_for = get_kernel_deps(state, obs_name) + # print(f"PNTDET: {obs_name} got wait_for = {wait_for}", flush=True) + + n_det = len(quat_index) + n_samp = quats.shape[1] + for intr in intervals: + first_sample = intr.first + n_intr = intr.last - intr.first + 1 + ev = pointing_detector( + ocl.queue(device_type=devtype), + (n_det, n_intr), + None, + np.int32(n_det), + np.int64(n_samp), + np.int64(first_sample), + dev_fp.data, + 
dev_boresight.data, + dev_quat_index.data, + dev_quats.data, + dev_flags.data, + np.uint8(shared_flag_mask), + use_flags, + wait_for=wait_for, + ) + wait_for = [ev] + clear_kernel_deps(state, obs_name) + add_kernel_deps(state, obs_name, wait_for) + + # Free temporaries + ocl.mem_remove(quat_index, device_type=devtype) + ocl.mem_remove(focalplane, device_type=devtype) diff --git a/src/toast/ops/pointing_detector/pointing_detector.py b/src/toast/ops/pointing_detector/pointing_detector.py index a4be394b1..aad9cd1c3 100644 --- a/src/toast/ops/pointing_detector/pointing_detector.py +++ b/src/toast/ops/pointing_detector/pointing_detector.py @@ -7,7 +7,7 @@ import traitlets from ... import qarray as qa -from ...accelerator import ImplementationType +from ...accelerator import ImplementationType, accel_wait from ...observation import default_values as defaults from ...timing import function_timer from ...traits import Bool, Int, Quantity, Unicode, UseEnum, trait_docs @@ -182,12 +182,15 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): if not ob.shared.accel_exists(bore_name): # Does not even exist yet on the device ob.shared.accel_create(bore_name) - ob.shared.accel_update_device(bore_name) + events = ob.shared.accel_update_device(bore_name) + accel_wait(events) else: if ob.shared.accel_in_use(bore_name): # Back to host - ob.shared.accel_update_host(bore_name) + events = ob.shared.accel_update_host(bore_name) + accel_wait(events) + kret = dict() for ob in data.obs: # Get the detectors we are using for this observation dets = ob.select_local_detectors(detectors, flagmask=self.det_mask) @@ -284,9 +287,11 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): self.shared_flag_mask, impl=implementation, use_accel=use_accel, + obs_name=ob.name, + **kwargs, ) - return + return kret def _finalize(self, data, **kwargs): return @@ -318,6 +323,7 @@ def _implementations(self): ImplementationType.COMPILED, ImplementationType.NUMPY, 
ImplementationType.JAX, + ImplementationType.OPENCL, ] def _supports_accel(self): diff --git a/src/toast/ops/polyfilter/kernels.py b/src/toast/ops/polyfilter/kernels.py index 5333907a0..f3a8fd759 100644 --- a/src/toast/ops/polyfilter/kernels.py +++ b/src/toast/ops/polyfilter/kernels.py @@ -21,6 +21,7 @@ def filter_polynomial( starts, stops, use_accel=False, + **kwargs, ): """Kernel to fit and subtract a polynomial from one or more signals. @@ -53,6 +54,7 @@ def filter_poly2D( masks, coeff, use_accel=False, + **kwargs, ): """Kernel for solving 2D polynomial coefficients at each sample. @@ -78,10 +80,10 @@ def filter_poly2D( @kernel(impl=ImplementationType.COMPILED, name="filter_polynomial") -def filter_polynomial_compiled(*args, use_accel=False): +def filter_polynomial_compiled(*args, use_accel=False, **kwargs): return libtoast_filter_polynomial(*args, use_accel) @kernel(impl=ImplementationType.COMPILED, name="filter_poly2D") -def filter_poly2D_compiled(*args, use_accel=False): +def filter_poly2D_compiled(*args, use_accel=False, **kwargs): return libtoast_filter_poly2D(*args, use_accel) diff --git a/src/toast/ops/polyfilter/kernels_jax.py b/src/toast/ops/polyfilter/kernels_jax.py index aca07e2c2..7990c4233 100644 --- a/src/toast/ops/polyfilter/kernels_jax.py +++ b/src/toast/ops/polyfilter/kernels_jax.py @@ -85,7 +85,9 @@ def filter_poly2D_coeffs(ngroup, det_groups, templates, signals, masks): @kernel(impl=ImplementationType.JAX, name="filter_poly2D") -def filter_poly2D_jax(det_groups, templates, signals, masks, coeff, use_accel): +def filter_poly2D_jax( + det_groups, templates, signals, masks, coeff, use_accel, **kwargs +): """ Solves for 2D polynomial coefficients at each sample. 
@@ -171,7 +173,9 @@ def filter_polynomial_interval(flags_interval, signals_interval, order): @kernel(impl=ImplementationType.JAX, name="filter_polynomial") -def filter_polynomial_jax(order, flags, signals_list, starts, stops, use_accel): +def filter_polynomial_jax( + order, flags, signals_list, starts, stops, use_accel, **kwargs +): """ Fit and subtract a polynomial from one or more signals. diff --git a/src/toast/ops/polyfilter/kernels_numpy.py b/src/toast/ops/polyfilter/kernels_numpy.py index a15836b09..9622db5e1 100644 --- a/src/toast/ops/polyfilter/kernels_numpy.py +++ b/src/toast/ops/polyfilter/kernels_numpy.py @@ -8,7 +8,9 @@ @kernel(impl=ImplementationType.NUMPY, name="filter_polynomial") -def filter_polynomial_numpy(order, flags, signals_list, starts, stops, use_accel=False): +def filter_polynomial_numpy( + order, flags, signals_list, starts, stops, use_accel=False, **kwargs +): # validate order if order < 0: return @@ -84,7 +86,9 @@ def filter_polynomial_numpy(order, flags, signals_list, starts, stops, use_accel @kernel(impl=ImplementationType.NUMPY, name="filter_poly2D") -def filter_poly2D_numpy(det_groups, templates, signals, masks, coeff, use_accel=False): +def filter_poly2D_numpy( + det_groups, templates, signals, masks, coeff, use_accel=False, **kwargs +): ngroup = coeff.shape[1] nsample = signals.shape[0] for isample in range(nsample): diff --git a/src/toast/ops/polyfilter/polyfilter.py b/src/toast/ops/polyfilter/polyfilter.py index ff3078ce7..7c9dd7e4b 100644 --- a/src/toast/ops/polyfilter/polyfilter.py +++ b/src/toast/ops/polyfilter/polyfilter.py @@ -337,6 +337,7 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): coeff, impl=implementation, use_accel=use_accel, + **kwargs, ) gt.stop("Poly2D: Solve templates") @@ -584,6 +585,7 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): local_stops, impl=implementation, use_accel=use_accel, + **kwargs, ) if not in_place: for fdet, x in zip(filter_dets, signals): @@ -601,6 
+603,7 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): local_stops, impl=implementation, use_accel=use_accel, + **kwargs, ) if not in_place: for fdet, x in zip(filter_dets, signals): diff --git a/src/toast/ops/scan_map/CMakeLists.txt b/src/toast/ops/scan_map/CMakeLists.txt index 6826bbb66..4df52ed57 100644 --- a/src/toast/ops/scan_map/CMakeLists.txt +++ b/src/toast/ops/scan_map/CMakeLists.txt @@ -7,5 +7,7 @@ install(FILES kernels.py kernels_numpy.py kernels_jax.py + kernels_opencl.py + kernels_opencl.cl DESTINATION ${PYTHON_SITE}/toast/ops/scan_map ) diff --git a/src/toast/ops/scan_map/kernels.py b/src/toast/ops/scan_map/kernels.py index a9a720e8a..fe068e102 100644 --- a/src/toast/ops/scan_map/kernels.py +++ b/src/toast/ops/scan_map/kernels.py @@ -10,12 +10,15 @@ from ..._libtoast import ops_scan_map_float64 as libtoast_scan_map_float64 from ..._libtoast import ops_scan_map_int32 as libtoast_scan_map_int32 from ..._libtoast import ops_scan_map_int64 as libtoast_scan_map_int64 -from ...accelerator import ImplementationType, kernel, use_accel_jax +from ...accelerator import ImplementationType, kernel, use_accel_jax, use_accel_opencl from .kernels_numpy import scan_map_numpy if use_accel_jax: from .kernels_jax import scan_map_jax +if use_accel_opencl: + from .kernels_opencl import scan_map_opencl + @kernel(impl=ImplementationType.DEFAULT) def scan_map( @@ -34,6 +37,7 @@ def scan_map( should_subtract, should_scale, use_accel=False, + **kwargs, ): """Kernel for scanning a map into timestreams. 
@@ -83,7 +87,7 @@ def scan_map( @kernel(impl=ImplementationType.COMPILED, name="scan_map") -def scan_map_compiled(*args, use_accel=False): +def scan_map_compiled(*args, use_accel=False, **kwargs): return libtoast_scan_map(*args, use_accel) diff --git a/src/toast/ops/scan_map/kernels_jax.py b/src/toast/ops/scan_map/kernels_jax.py index 030a8c690..93f67e5bb 100644 --- a/src/toast/ops/scan_map/kernels_jax.py +++ b/src/toast/ops/scan_map/kernels_jax.py @@ -221,6 +221,7 @@ def scan_map_jax( should_subtract=False, should_scale=False, use_accel=False, + **kwargs, ): """ Kernel for scanning a map into timestreams. diff --git a/src/toast/ops/scan_map/kernels_numpy.py b/src/toast/ops/scan_map/kernels_numpy.py index e4dce24d5..c6c1edeb3 100644 --- a/src/toast/ops/scan_map/kernels_numpy.py +++ b/src/toast/ops/scan_map/kernels_numpy.py @@ -24,6 +24,7 @@ def scan_map_numpy( should_subtract=False, should_scale=False, use_accel=False, + **kwargs, ): nmap = weights.shape[-1] local_map = mapdata.reshape((-1, nmap)) diff --git a/src/toast/ops/scan_map/kernels_opencl.cl b/src/toast/ops/scan_map/kernels_opencl.cl new file mode 100644 index 000000000..34cfadb35 --- /dev/null +++ b/src/toast/ops/scan_map/kernels_opencl.cl @@ -0,0 +1,74 @@ +// Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +// All rights reserved. Use of this source code is governed by +// a BSD-style license that can be found in the LICENSE file. 
+ +// Kernels + + +// FIXME: macro-fy this across all types + +__kernel void scan_map_d_to_d( + int n_det, + long n_sample, + long first_sample, + __global int const * pixels_index, + __global long const * pixels, + __global int const * weight_index, + __global double const * weights, + __global int const * det_data_index, + __global double * det_data, + __global double const * mapdata, + __global long const * global2local, + long nnz, + long npix_submap, + double data_scale, + unsigned char should_zero, + unsigned char should_subtract, + unsigned char should_scale +) { + // Get the global index of this work element + int idet = get_global_id(0); + long isamp = first_sample + get_global_id(1); + + int d_indx = det_data_index[idet]; + int p_indx = pixels_index[idet]; + int w_indx = weight_index[idet]; + + size_t woff = nnz * (w_indx * n_sample + isamp); + size_t poff = p_indx * n_sample + isamp; + size_t doff = d_indx * n_sample + isamp; + + long global_submap; + long local_submap_pix; + long local_submap; + long local_pix; + + double tod_val = 0.0; + + if (pixels[poff] >= 0) { + global_submap = (long)(pixels[poff] / npix_submap); + local_submap_pix = pixels[poff] - global_submap * npix_submap; + local_submap = global2local[global_submap]; + local_pix = local_submap * npix_submap + local_submap_pix; + + for (long i = 0; i < nnz; i++) { + tod_val += weights[woff + i] * mapdata[nnz * local_pix + i]; + } + tod_val *= data_scale; + + if (should_zero) { + det_data[doff] = 0; + } + if (should_subtract) { + det_data[doff] -= tod_val; + } else if (should_scale) { + det_data[doff] *= tod_val; + } else { + det_data[doff] += tod_val; + } + } + return; +} + + + diff --git a/src/toast/ops/scan_map/kernels_opencl.py b/src/toast/ops/scan_map/kernels_opencl.py new file mode 100644 index 000000000..cc9a4a9b0 --- /dev/null +++ b/src/toast/ops/scan_map/kernels_opencl.py @@ -0,0 +1,139 @@ +# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file. +# All rights reserved. 
Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. + +import os +import numpy as np +import pyopencl as cl + +from ...accelerator import ImplementationType, kernel +from ...opencl import ( + find_source, + OpenCL, + add_kernel_deps, + get_kernel_deps, + clear_kernel_deps, +) + + +@kernel(impl=ImplementationType.OPENCL, name="scan_map") +def scan_map_opencl( + global2local, + n_pix_submap, + mapdata, + det_data, + det_data_index, + pixels, + pixels_index, + weights, + weight_index, + intervals, + data_scale=1.0, + should_zero=False, + should_subtract=False, + should_scale=False, + use_accel=False, + obs_name=None, + state=None, + **kwargs, +): + # Select the kernel we will use based on input datatypes + if mapdata.dtype.char == "d": + kname = "scan_map_d_to_" + elif mapdata.dtype.char == "f": + kname = "scan_map_f_to_" + elif mapdata.dtype.char == "i": + kname = "scan_map_i_to_" + elif mapdata.dtype.char == "l": + kname = "scan_map_l_to_" + else: + msg = f"OpenCL version of scan_map does not support map " + msg += f"dtype '{mapdata.dtype.char}'" + raise NotImplementedError(msg) + + if det_data.dtype.char == "d": + kname += "d" + elif mapdata.dtype.char == "f": + kname += "f" + elif mapdata.dtype.char == "i": + kname += "i" + elif mapdata.dtype.char == "l": + kname += "l" + else: + msg = f"OpenCL version of scan_map does not support det_data " + msg += f"dtype '{det_data.dtype.char}'" + raise NotImplementedError(msg) + + program_file = find_source(os.path.dirname(__file__), "kernels_opencl.cl") + + ocl = OpenCL() + queue = ocl.queue() + devtype = ocl.default_device_type + + scan_map_kernel = ocl.get_or_build_kernel( + "scan_map", + kname, + device_type=devtype, + source=program_file, + ) + + # Get our device arrays + dev_global2local = ocl.mem(global2local, device_type=devtype) + dev_mapdata = ocl.mem(mapdata, device_type=devtype) + dev_pixels = ocl.mem(pixels, device_type=devtype) + dev_weights = ocl.mem(weights, 
device_type=devtype) + dev_det_data = ocl.mem(det_data, device_type=devtype) + + # Allocate temporary device arrays + dev_pixels_index = ocl.mem_to_device(pixels_index, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_pixels_index.events) + + dev_weight_index = ocl.mem_to_device(weight_index, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_weight_index.events) + + dev_det_data_index = ocl.mem_to_device( + det_data_index, device_type=devtype, async_=True + ) + add_kernel_deps(state, obs_name, dev_det_data_index.events) + + # All of the events that our kernels depend on + wait_for = get_kernel_deps(state, obs_name) + # print(f"SCANMAP: {obs_name} got wait_for = {wait_for}", flush=True) + + n_det = len(det_data_index) + n_samp = weights.shape[1] + nnz = weights.shape[2] + for intr in intervals: + first_sample = intr.first + n_intr = intr.last - intr.first + 1 + ev = scan_map_kernel( + ocl.queue(device_type=devtype), + (n_det, n_intr), + None, + np.int32(n_det), + np.int64(n_samp), + np.int64(first_sample), + dev_pixels_index.data, + dev_pixels.data, + dev_weight_index.data, + dev_weights.data, + dev_det_data_index.data, + dev_det_data.data, + dev_mapdata.data, + dev_global2local.data, + np.int64(nnz), + np.int64(n_pix_submap), + np.float64(data_scale), + np.uint8(should_zero), + np.uint8(should_subtract), + np.uint8(should_scale), + wait_for=wait_for, + ) + wait_for = [ev] + clear_kernel_deps(state, obs_name) + add_kernel_deps(state, obs_name, wait_for) + + # Free temporaries + ocl.mem_remove(pixels_index, device_type=devtype) + ocl.mem_remove(weight_index, device_type=devtype) + ocl.mem_remove(det_data_index, device_type=devtype) diff --git a/src/toast/ops/scan_map/scan_map.py b/src/toast/ops/scan_map/scan_map.py index 98791c42d..04bbb96ce 100644 --- a/src/toast/ops/scan_map/scan_map.py +++ b/src/toast/ops/scan_map/scan_map.py @@ -175,6 +175,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): False, 
impl=implementation, use_accel=use_accel, + obs_name=ob.name, + **kwargs, ) return @@ -205,6 +207,7 @@ def _implementations(self): ImplementationType.COMPILED, ImplementationType.NUMPY, ImplementationType.JAX, + ImplementationType.OPENCL, ] def _supports_accel(self): @@ -487,6 +490,7 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): True, impl=implementation, use_accel=use_accel, + **kwargs, ) return @@ -510,5 +514,14 @@ def _provides(self): prov = {"meta": list(), "shared": list(), "detdata": [self.det_data]} return prov + def _implementations(self): + return [ + ImplementationType.DEFAULT, + ImplementationType.COMPILED, + ImplementationType.NUMPY, + ImplementationType.JAX, + ImplementationType.OPENCL, + ] + def _supports_accel(self): return True diff --git a/src/toast/ops/stokes_weights/CMakeLists.txt b/src/toast/ops/stokes_weights/CMakeLists.txt index 1f476a6e1..0e2d87d6a 100644 --- a/src/toast/ops/stokes_weights/CMakeLists.txt +++ b/src/toast/ops/stokes_weights/CMakeLists.txt @@ -7,5 +7,7 @@ install(FILES kernels.py kernels_numpy.py kernels_jax.py + kernels_opencl.py + kernels_opencl.cl DESTINATION ${PYTHON_SITE}/toast/ops/stokes_weights ) diff --git a/src/toast/ops/stokes_weights/kernels.py b/src/toast/ops/stokes_weights/kernels.py index 69f85c433..41184d5d4 100644 --- a/src/toast/ops/stokes_weights/kernels.py +++ b/src/toast/ops/stokes_weights/kernels.py @@ -7,20 +7,23 @@ from ... 
import qarray as qa from ..._libtoast import stokes_weights_I as libtoast_stokes_weights_I from ..._libtoast import stokes_weights_IQU as libtoast_stokes_weights_IQU -from ...accelerator import ImplementationType, kernel, use_accel_jax +from ...accelerator import ImplementationType, kernel, use_accel_jax, use_accel_opencl from .kernels_numpy import stokes_weights_I_numpy, stokes_weights_IQU_numpy if use_accel_jax: from .kernels_jax import stokes_weights_I_jax, stokes_weights_IQU_jax +if use_accel_opencl: + from .kernels_opencl import stokes_weights_I_opencl, stokes_weights_IQU_opencl + @kernel(impl=ImplementationType.COMPILED, name="stokes_weights_I") -def stokes_weights_I_compiled(*args, use_accel=False): +def stokes_weights_I_compiled(*args, use_accel=False, **kwargs): return libtoast_stokes_weights_I(*args, use_accel) @kernel(impl=ImplementationType.COMPILED, name="stokes_weights_IQU") -def stokes_weights_IQU_compiled(*args, use_accel=False): +def stokes_weights_IQU_compiled(*args, use_accel=False, **kwargs): return libtoast_stokes_weights_IQU(*args, use_accel) @@ -31,6 +34,7 @@ def stokes_weights_I( intervals, cal, use_accel=False, + **kwargs, ): """Kernel for computing trivial intensity-only Stokes pointing weights. @@ -67,6 +71,7 @@ def stokes_weights_IQU( cal, IAU, use_accel=False, + **kwargs, ): """Kernel for computing the I/Q/U Stokes pointing weights. diff --git a/src/toast/ops/stokes_weights/kernels_jax.py b/src/toast/ops/stokes_weights/kernels_jax.py index 185b65bf9..334c107ed 100644 --- a/src/toast/ops/stokes_weights/kernels_jax.py +++ b/src/toast/ops/stokes_weights/kernels_jax.py @@ -180,6 +180,7 @@ def stokes_weights_IQU_jax( cal, IAU, use_accel, + **kwargs, ): """ Compute the Stokes weights for the "IQU" mode. 
@@ -311,7 +312,7 @@ def stokes_weights_I_interval( @kernel(impl=ImplementationType.JAX, name="stokes_weights_I") -def stokes_weights_I_jax(weight_index, weights, intervals, cal, use_accel): +def stokes_weights_I_jax(weight_index, weights, intervals, cal, use_accel, **kwargs): """ Compute the Stokes weights for the "I" mode. diff --git a/src/toast/ops/stokes_weights/kernels_numpy.py b/src/toast/ops/stokes_weights/kernels_numpy.py index 431214a1e..fa5930591 100644 --- a/src/toast/ops/stokes_weights/kernels_numpy.py +++ b/src/toast/ops/stokes_weights/kernels_numpy.py @@ -21,8 +21,9 @@ def stokes_weights_IQU_numpy( cal, IAU, use_accel, + **kwargs, ): - if hwp is not None and len(hwp) == 0: + if hwp is not None and len(hwp) != len(quats): hwp = None if IAU: @@ -77,7 +78,7 @@ def stokes_weights_IQU_numpy( @kernel(impl=ImplementationType.NUMPY, name="stokes_weights_I") -def stokes_weights_I_numpy(weight_index, weights, intervals, cal, use_accel): +def stokes_weights_I_numpy(weight_index, weights, intervals, cal, use_accel, **kwargs): for idet in range(len(weight_index)): widx = weight_index[idet] for vw in intervals: diff --git a/src/toast/ops/stokes_weights/kernels_opencl.cl b/src/toast/ops/stokes_weights/kernels_opencl.cl new file mode 100644 index 000000000..0ebf0f8d3 --- /dev/null +++ b/src/toast/ops/stokes_weights/kernels_opencl.cl @@ -0,0 +1,179 @@ +// Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +// All rights reserved. Use of this source code is governed by +// a BSD-style license that can be found in the LICENSE file. + +void stokes_weights_qa_rotate( + double const * q_in, + double const * v_in, + double * v_out +) { + // The input quaternion has already been normalized on the host. 
+
+    double xw = q_in[3] * q_in[0];
+    double yw = q_in[3] * q_in[1];
+    double zw = q_in[3] * q_in[2];
+    double x2 = -q_in[0] * q_in[0];
+    double xy = q_in[0] * q_in[1];
+    double xz = q_in[0] * q_in[2];
+    double y2 = -q_in[1] * q_in[1];
+    double yz = q_in[1] * q_in[2];
+    double z2 = -q_in[2] * q_in[2];
+
+    v_out[0] = 2 * ((y2 + z2) * v_in[0] + (xy - zw) * v_in[1] +
+                    (yw + xz) * v_in[2]) + v_in[0];
+
+    v_out[1] = 2 * ((zw + xy) * v_in[0] + (x2 + z2) * v_in[1] +
+                    (yz - xw) * v_in[2]) + v_in[1];
+
+    v_out[2] = 2 * ((xz - yw) * v_in[0] + (xw + yz) * v_in[1] +
+                    (x2 + y2) * v_in[2]) + v_in[2];
+
+    return;
+}
+
+void stokes_weights_alpha(
+    double const * quats,
+    double * alpha
+) {
+    const double xaxis[3] = {1.0, 0.0, 0.0};
+    const double zaxis[3] = {0.0, 0.0, 1.0};
+    double vd[3];
+    double vo[3];
+
+    stokes_weights_qa_rotate(quats, zaxis, vd);
+    stokes_weights_qa_rotate(quats, xaxis, vo);
+
+    double ang_xy = atan2(vd[1], vd[0]);
+    double vm_x = vd[2] * cos(ang_xy);
+    double vm_y = vd[2] * sin(ang_xy);
+    double vm_z = - sqrt(1.0 - vd[2] * vd[2]);
+
+    double alpha_y = (
+        vd[0] * (vm_y * vo[2] - vm_z * vo[1]) - vd[1] * (vm_x * vo[2] - vm_z * vo[0])
+        + vd[2] * (vm_x * vo[1] - vm_y * vo[0])
+    );
+    double alpha_x = (vm_x * vo[0] + vm_y * vo[1] + vm_z * vo[2]);
+
+    (*alpha) = atan2(alpha_y, alpha_x);
+    return;
+}
+
+
+// Kernels
+
+__kernel void stokes_weights_IQU(
+    int n_det,
+    long n_sample,
+    long first_sample,
+    __global int const * quat_index,
+    __global double const * quats,
+    __global int const * weight_index,
+    __global double * weights,
+    __global double const * epsilon,
+    __global double const * gamma,
+    __global double const * cal,
+    double U_sign,
+    unsigned char IAU
+) {
+    // NOTE:  Flags are not needed here, since the quaternions
+    // have already had bad samples converted to null rotations.
+
+    // Get the global index of this work element.  As in the _hwp variant
+    // below, the sample index must include the first_sample offset of the
+    // interval being processed.
+    int idet = get_global_id(0);
+    long isamp = first_sample + get_global_id(1);
+
+    double eta = (1.0 - epsilon[idet]) / (1.0 + epsilon[idet]);
+    int q_indx = quat_index[idet];
+    int w_indx = weight_index[idet];
+    long qoff = (q_indx * 4 * n_sample) + 4 * isamp;
+
+    // Copy to private variable in order to pass to subroutines.
+ + // Get the global index of this work element + int idet = get_global_id(0); + long isamp = get_global_id(1); + + double eta = (1.0 - epsilon[idet]) / (1.0 + epsilon[idet]); + int q_indx = quat_index[idet]; + int w_indx = weight_index[idet]; + long qoff = (q_indx * 4 * n_sample) + 4 * isamp; + + // Copy to private variable in order to pass to subroutines. + double temp_quat[4]; + temp_quat[0] = quats[qoff]; + temp_quat[1] = quats[qoff + 1]; + temp_quat[2] = quats[qoff + 2]; + temp_quat[3] = quats[qoff + 3]; + + double alpha; + stokes_weights_alpha(temp_quat, &alpha); + + alpha *= 2.0; + double cang = cos(alpha); + double sang = sin(alpha); + + long woff = (w_indx * 3 * n_sample) + 3 * isamp; + weights[woff] = cal[idet]; + weights[woff + 1] = cang * eta * cal[idet]; + weights[woff + 2] = sang * eta * cal[idet] * U_sign; + + return; +} + +__kernel void stokes_weights_IQU_hwp( + int n_det, + long n_sample, + long first_sample, + __global int const * quat_index, + __global double const * quats, + __global int const * weight_index, + __global double * weights, + __global double const * hwp, + __global double const * epsilon, + __global double const * gamma, + __global double const * cal, + double U_sign, + unsigned char IAU +) { + // NOTE: Flags are not needed here, since the quaternions + // have already had bad samples converted to null rotations. + + // Get the global index of this work element + int idet = get_global_id(0); + long isamp = first_sample + get_global_id(1); + + double eta = (1.0 - epsilon[idet]) / (1.0 + epsilon[idet]); + int q_indx = quat_index[idet]; + int w_indx = weight_index[idet]; + long qoff = (q_indx * 4 * n_sample) + 4 * isamp; + + // Copy to private variable in order to pass to subroutines. 
+ double temp_quat[4]; + temp_quat[0] = quats[qoff]; + temp_quat[1] = quats[qoff + 1]; + temp_quat[2] = quats[qoff + 2]; + temp_quat[3] = quats[qoff + 3]; + + double alpha; + stokes_weights_alpha(temp_quat, &alpha); + + double ang = 2.0 * (2.0 * (gamma[idet] - hwp[isamp]) - alpha); + double cang = cos(ang); + double sang = sin(ang); + + long woff = (w_indx * 3 * n_sample) + 3 * isamp; + weights[woff] = cal[idet]; + weights[woff + 1] = cang * eta * cal[idet]; + weights[woff + 2] = -sang * eta * cal[idet] * U_sign; + + return; +} + +__kernel void stokes_weights_I( + int n_det, + long n_sample, + long first_sample, + __global int const * weight_index, + __global double * weights, + __global double const * cal +) { + // Get the global index of this work element + int idet = get_global_id(0); + long isamp = first_sample + get_global_id(1); + + int w_indx = weight_index[idet]; + + long woff = (w_indx * n_sample) + isamp; + weights[woff] = cal[idet]; + + return; +} diff --git a/src/toast/ops/stokes_weights/kernels_opencl.py b/src/toast/ops/stokes_weights/kernels_opencl.py new file mode 100644 index 000000000..91a1454f4 --- /dev/null +++ b/src/toast/ops/stokes_weights/kernels_opencl.py @@ -0,0 +1,211 @@ +# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. + +import os +import numpy as np +import pyopencl as cl + +from ... 
import qarray as qa +from ...accelerator import ImplementationType, kernel +from ...opencl import ( + find_source, + OpenCL, + add_kernel_deps, + get_kernel_deps, + clear_kernel_deps, +) + + +@kernel(impl=ImplementationType.OPENCL, name="stokes_weights_IQU") +def stokes_weights_IQU_opencl( + quat_index, + quats, + weight_index, + weights, + hwp, + intervals, + epsilon, + gamma, + cal, + IAU, + use_accel=False, + obs_name=None, + state=None, + **kwargs, +): + if hwp is not None and len(hwp) != len(quats): + hwp = None + + if IAU: + U_sign = -1.0 + else: + U_sign = 1.0 + + program_file = find_source(os.path.dirname(__file__), "kernels_opencl.cl") + + ocl = OpenCL() + + devtype = ocl.default_device_type + + if hwp is None: + stokes_weights_IQU = ocl.get_or_build_kernel( + "stokes_weights", + "stokes_weights_IQU", + device_type=devtype, + source=program_file, + ) + else: + stokes_weights_IQU = ocl.get_or_build_kernel( + "stokes_weights", + "stokes_weights_IQU_hwp", + device_type=devtype, + source=program_file, + ) + + # Get our device arrays + dev_quats = ocl.mem(quats, device_type=devtype) + dev_weights = ocl.mem(weights, device_type=devtype) + if hwp is not None: + dev_hwp = ocl.mem(hwp, device_type=devtype) + + # Allocate temporary device arrays + dev_quat_index = ocl.mem_to_device(quat_index, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_quat_index.events) + + dev_weight_index = ocl.mem_to_device(weight_index, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_weight_index.events) + + dev_epsilon = ocl.mem_to_device(epsilon, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_epsilon.events) + + dev_gamma = ocl.mem_to_device(gamma, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_gamma.events) + + dev_cal = ocl.mem_to_device(cal, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_cal.events) + + # All of the events that our kernels depend on + 
wait_for = get_kernel_deps(state, obs_name) + # print(f"STOKESIQU: {obs_name} got wait_for = {wait_for}", flush=True) + + n_det = len(quat_index) + n_samp = quats.shape[1] + for intr in intervals: + first_sample = intr.first + n_intr = intr.last - intr.first + 1 + + if hwp is None: + ev = stokes_weights_IQU( + ocl.queue(device_type=devtype), + (n_det, n_intr), + None, + np.int32(n_det), + np.int64(n_samp), + np.int64(first_sample), + dev_quat_index.data, + dev_quats.data, + dev_weight_index.data, + dev_weights.data, + dev_epsilon.data, + dev_gamma.data, + dev_cal.data, + np.float64(U_sign), + np.uint8(IAU), + wait_for=wait_for, + ) + wait_for = [ev] + else: + ev = stokes_weights_IQU( + ocl.queue(device_type=devtype), + (n_det, n_intr), + None, + np.int32(n_det), + np.int64(n_samp), + np.int64(first_sample), + dev_quat_index.data, + dev_quats.data, + dev_weight_index.data, + dev_weights.data, + dev_hwp.data, + dev_epsilon.data, + dev_gamma.data, + dev_cal.data, + np.float64(U_sign), + np.uint8(IAU), + wait_for=wait_for, + ) + wait_for = [ev] + clear_kernel_deps(state, obs_name) + add_kernel_deps(state, obs_name, wait_for) + + # Free temporaries + ocl.mem_remove(quat_index, device_type=devtype) + ocl.mem_remove(weight_index, device_type=devtype) + ocl.mem_remove(epsilon, device_type=devtype) + ocl.mem_remove(gamma, device_type=devtype) + ocl.mem_remove(cal, device_type=devtype) + + +@kernel(impl=ImplementationType.OPENCL, name="stokes_weights_I") +def stokes_weights_I_opencl( + weight_index, + weights, + intervals, + cal, + use_accel=False, + obs_name=None, + state=None, + **kwargs, +): + program_file = find_source(os.path.dirname(__file__), "kernels_opencl.cl") + + ocl = OpenCL() + queue = ocl.queue() + devtype = ocl.default_device_type + + stokes_weights_I = ocl.get_or_build_kernel( + "stokes_weights", + "stokes_weights_I", + device_type=devtype, + source=program_file, + ) + + # Get our device arrays + dev_weights = ocl.mem(weights, device_type=devtype) + + # 
Allocate temporary device arrays + + dev_weight_index = ocl.mem_to_device(weight_index, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_weight_index.events) + + dev_cal = ocl.mem_to_device(cal, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_cal.events) + + # All of the events that our kernels depend on + wait_for = get_kernel_deps(state, obs_name) + # print(f"STOKESI: {obs_name} got wait_for = {wait_for}", flush=True) + + n_det = len(weight_index) + for intr in intervals: + first_sample = intr.first + n_samp = intr.last - intr.first + 1 + ev = stokes_weights_I( + ocl.queue(device_type=devtype), + (n_det, n_samp), + None, + np.int32(n_det), + np.int64(n_samp), + np.int64(first_sample), + dev_weight_index.data, + dev_weights.data, + dev_cal.data, + wait_for=wait_for, + ) + wait_for = [ev] + clear_kernel_deps(state, obs_name) + add_kernel_deps(state, obs_name, wait_for) + + # Free temporaries + ocl.mem_remove(weight_index, device_type=devtype) + ocl.mem_remove(cal, device_type=devtype) diff --git a/src/toast/ops/stokes_weights/stokes_weights.py b/src/toast/ops/stokes_weights/stokes_weights.py index 12c3baa2b..2ccd68f8d 100644 --- a/src/toast/ops/stokes_weights/stokes_weights.py +++ b/src/toast/ops/stokes_weights/stokes_weights.py @@ -160,7 +160,9 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): view = self.detector_pointing.view # Expand detector pointing - self.detector_pointing.apply(data, detectors=detectors, use_accel=use_accel) + self.detector_pointing.apply( + data, detectors=detectors, use_accel=use_accel, **kwargs + ) for ob in data.obs: # Get the detectors we are using for this observation @@ -257,6 +259,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): bool(self.IAU), impl=implementation, use_accel=use_accel, + obs_name=ob.name, + **kwargs, ) else: stokes_weights_I( @@ -266,6 +270,8 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): cal, 
impl=implementation, use_accel=use_accel, + obs_name=ob.name, + **kwargs, ) return @@ -296,6 +302,7 @@ def _implementations(self): ImplementationType.COMPILED, ImplementationType.NUMPY, ImplementationType.JAX, + ImplementationType.OPENCL, ] def _supports_accel(self): diff --git a/src/toast/pixels.py b/src/toast/pixels.py index 5224a01f8..db8d40a93 100644 --- a/src/toast/pixels.py +++ b/src/toast/pixels.py @@ -20,9 +20,12 @@ accel_enabled, use_accel_jax, use_accel_omp, + use_accel_opencl, ) from .dist import distribute_uniform from .mpi import MPI +if use_accel_opencl: + from .opencl import OpenCL from .timing import GlobalTimers, Timer, function_timer from .utils import ( AlignedF32, @@ -56,7 +59,7 @@ class PixelDistribution(AcceleratorObject): """ def __init__(self, n_pix=None, n_submap=1000, local_submaps=None, comm=None): - super().__init__() + super().__init__(constant=True) self._n_pix = n_pix self._n_submap = n_submap if self._n_submap > self._n_pix: @@ -124,6 +127,7 @@ def clear(self): del self._glob2loc def __del__(self): + super().__del__() self.clear() @property @@ -398,16 +402,29 @@ def alltoallv_info(self): return self._alltoallv_info def _accel_exists(self): + if not hasattr(self, "_glob2loc"): + return False return accel_data_present(self._glob2loc, self._accel_name) def _accel_create(self): self._glob2loc = accel_data_create(self._glob2loc, self._accel_name) def _accel_update_device(self): - self._glob2loc = accel_data_update_device(self._glob2loc, self._accel_name) + if use_accel_omp: + _ = accel_data_update_device(self._glob2loc, self._accel_name) + elif use_accel_jax: + self._glob2loc = accel_data_update_device(self._glob2loc, self._accel_name) + elif use_accel_opencl: + dev_data = accel_data_update_device(self._glob2loc, self._accel_name) + return dev_data.events def _accel_update_host(self): - self._glob2loc = accel_data_update_host(self._glob2loc, self._accel_name) + if use_accel_omp: + _ = accel_data_update_host(self._glob2loc, 
self._accel_name) + elif use_accel_jax: + self._glob2loc = accel_data_update_host(self._glob2loc, self._accel_name) + elif use_accel_opencl: + return accel_data_update_host(self._glob2loc, self._accel_name) def _accel_reset(self): accel_data_reset(self._glob2loc, self._accel_name) @@ -528,8 +545,7 @@ def clear(self): # we keep the attribute to avoid errors in _accel_exists self.data = None if hasattr(self, "raw"): - if self.accel_exists(): - self.accel_delete() + super().__del__() if self.raw is not None: self.raw.clear() del self.raw @@ -1163,7 +1179,9 @@ def broadcast_map(self, fdata, comm_bytes=10000000): return def _accel_exists(self): - if use_accel_omp: + if not self._dist.accel_exists(): + return False + if use_accel_omp or use_accel_opencl: return accel_data_present(self.raw, self._accel_name) elif use_accel_jax: return accel_data_present(self.data) @@ -1171,31 +1189,48 @@ def _accel_exists(self): return False def _accel_create(self, zero_out=False): - if use_accel_omp: + if not self._dist.accel_exists(): + self._dist.accel_create() + if use_accel_omp or use_accel_opencl: self.raw = accel_data_create(self.raw, self._accel_name, zero_out=zero_out) elif use_accel_jax: self.data = accel_data_create(self.data, zero_out=zero_out) def _accel_update_device(self): + dist_ret = self._dist.accel_update_device() if use_accel_omp: - self.raw = accel_data_update_device(self.raw, self._accel_name) + _ = accel_data_update_device(self.raw, self._accel_name) elif use_accel_jax: self.data = accel_data_update_device(self.data) + elif use_accel_opencl: + dev_data = accel_data_update_device(self.raw, self._accel_name) + evs = list(dev_data.events) + if dist_ret is not None: + evs.extend(dist_ret) + return evs def _accel_update_host(self): + dist_ret = self._dist.accel_update_host() if use_accel_omp: - self.raw = accel_data_update_host(self.raw, self._accel_name) + _ = accel_data_update_host(self.raw, self._accel_name) elif use_accel_jax: self.data = 
accel_data_update_host(self.data) + elif use_accel_opencl: + evs = accel_data_update_host(self.raw, self._accel_name) + if evs is None: + evs = list() + if dist_ret is not None: + evs.extend(dist_ret) + return evs def _accel_reset(self): - if use_accel_omp: + if use_accel_omp or use_accel_opencl: accel_data_reset(self.raw, self._accel_name) elif use_accel_jax: accel_data_reset(self.data) def _accel_delete(self): - if use_accel_omp: + if use_accel_omp or use_accel_opencl: self.raw = accel_data_delete(self.raw, self._accel_name) elif use_accel_jax: self.data = accel_data_delete(self.data) diff --git a/src/toast/scripts/toast_env.py b/src/toast/scripts/toast_env.py index 9309778b7..2f217f5cb 100644 --- a/src/toast/scripts/toast_env.py +++ b/src/toast/scripts/toast_env.py @@ -12,7 +12,9 @@ import traceback import toast +from toast.accelerator import use_accel_opencl from toast.mpi import Comm, get_world +from toast.opencl import OpenCL from toast.utils import Environment, Logger, numba_threading_layer @@ -31,6 +33,9 @@ def main(): msg = f"{n_acc} accelerator device(s), using up to " msg += f"{proc_per_acc} processes per device" log.info(msg) + if use_accel_opencl: + ocl = OpenCL() + ocl.info() if mpiworld is None: log.info("Running with one process") else: diff --git a/src/toast/templates/amplitudes.py b/src/toast/templates/amplitudes.py index 66d174774..b0eb8b18a 100644 --- a/src/toast/templates/amplitudes.py +++ b/src/toast/templates/amplitudes.py @@ -14,9 +14,11 @@ accel_data_reset, accel_data_update_device, accel_data_update_host, + accel_wait, accel_enabled, use_accel_jax, use_accel_omp, + use_accel_opencl, ) from ..mpi import MPI from ..utils import ( @@ -183,8 +185,6 @@ def clear(self): are no longer being used and you are about to delete the object. 
""" - if self.accel_exists(): - self.accel_delete() if hasattr(self, "local"): del self.local self.local = None @@ -203,6 +203,7 @@ def clear(self): self._raw_flags = None def __del__(self): + super().__del__() self.clear() def __repr__(self): @@ -319,15 +320,18 @@ def duplicate(self): # We have no good way to copy between device buffers, # so do this on the host. The duplicate() method is # not used inside the solver loop. - self.accel_update_host() + events = self.accel_update_host() + accel_wait(events) restore = True if self.local is not None: ret.local[:] = self.local if self.local_flags is not None: ret.local_flags[:] = self.local_flags if restore: - self.accel_update_device() - ret.accel_update_device() + events = self.accel_update_device() + accel_wait(events) + events = ret.accel_update_device() + accel_wait(events) return ret @property @@ -718,7 +722,7 @@ def dot(self, other, comm_bytes=10000000): def _accel_exists(self): if self.local is None: return False - if use_accel_omp: + if use_accel_omp or use_accel_opencl: return accel_data_present( self._raw, name=self._accel_name ) and accel_data_present(self._raw_flags, name=self._accel_name) @@ -732,7 +736,7 @@ def _accel_exists(self): def _accel_create(self, zero_out=False): if self.local is None: return - if use_accel_omp: + if use_accel_omp or use_accel_opencl: _ = accel_data_create(self._raw, name=self._accel_name, zero_out=zero_out) _ = accel_data_create( self._raw_flags, name=self._accel_name, zero_out=zero_out @@ -744,27 +748,41 @@ def _accel_create(self, zero_out=False): def _accel_update_device(self): if self.local is None: return + ret = None if use_accel_omp: _ = accel_data_update_device(self._raw, name=self._accel_name) _ = accel_data_update_device(self._raw_flags, name=self._accel_name) elif use_accel_jax: self.local = accel_data_update_device(self.local) self.local_flags = accel_data_update_device(self.local_flags) + elif use_accel_opencl: + ret = list() + dev_data = 
accel_data_update_device(self._raw, name=self._accel_name) + ret.extend(dev_data.events) + dev_data = accel_data_update_device(self._raw_flags, name=self._accel_name) + ret.extend(dev_data.events) + return ret def _accel_update_host(self): if self.local is None: return + ret = None if use_accel_omp: _ = accel_data_update_host(self._raw, name=self._accel_name) _ = accel_data_update_host(self._raw_flags, name=self._accel_name) elif use_accel_jax: self.local = accel_data_update_host(self.local) self.local_flags = accel_data_update_host(self.local_flags) + elif use_accel_opencl: + ret = list() + ret.extend(accel_data_update_host(self._raw, name=self._accel_name)) + ret.extend(accel_data_update_host(self._raw_flags, name=self._accel_name)) + return ret def _accel_delete(self): if self.local is None: return - if use_accel_omp: + if use_accel_omp or use_accel_opencl: _ = accel_data_delete(self._raw, name=self._accel_name) _ = accel_data_delete(self._raw_flags, name=self._accel_name) elif use_accel_jax: @@ -776,7 +794,7 @@ def _accel_reset_local(self): return # if not self.accel_in_use(): # return - if use_accel_omp: + if use_accel_omp or use_accel_opencl: accel_data_reset(self._raw, name=self._accel_name) elif use_accel_jax: accel_data_reset(self.local) @@ -786,7 +804,7 @@ def _accel_reset_local_flags(self): return # if not self.accel_in_use(): # return - if use_accel_omp: + if use_accel_omp or use_accel_opencl: accel_data_reset(self._raw_flags, name=self._accel_name) elif use_accel_jax: accel_data_reset(self.local_flags) @@ -995,14 +1013,28 @@ def _accel_create(self, zero_out=False): def _accel_update_device(self): if not accel_enabled(): return - for k, v in self._internal.items(): - v.accel_update_device() + ret = None + if use_accel_opencl: + ret = list() + for k, v in self._internal.items(): + ret.extend(v.accel_update_device()) + else: + for k, v in self._internal.items(): + v.accel_update_device() + return ret def _accel_update_host(self): if not accel_enabled(): 
return - for k, v in self._internal.items(): - v.accel_update_host() + ret = None + if use_accel_opencl: + ret = list() + for k, v in self._internal.items(): + ret.extend(v.accel_update_host()) + else: + for k, v in self._internal.items(): + v.accel_update_host() + return ret def _accel_delete(self): if not accel_enabled(): diff --git a/src/toast/templates/offset/CMakeLists.txt b/src/toast/templates/offset/CMakeLists.txt index 942ffe0c0..f66486d97 100644 --- a/src/toast/templates/offset/CMakeLists.txt +++ b/src/toast/templates/offset/CMakeLists.txt @@ -7,5 +7,7 @@ install(FILES kernels.py kernels_numpy.py kernels_jax.py + kernels_opencl.py + kernels_opencl.cl DESTINATION ${PYTHON_SITE}/toast/templates/offset ) diff --git a/src/toast/templates/offset/kernels.py b/src/toast/templates/offset/kernels.py index e5c2f8ff0..b5b4f412c 100644 --- a/src/toast/templates/offset/kernels.py +++ b/src/toast/templates/offset/kernels.py @@ -11,7 +11,7 @@ from ..._libtoast import ( template_offset_project_signal as libtoast_offset_project_signal, ) -from ...accelerator import ImplementationType, kernel, use_accel_jax +from ...accelerator import ImplementationType, kernel, use_accel_jax, use_accel_opencl from .kernels_numpy import ( offset_add_to_signal_numpy, offset_apply_diag_precond_numpy, @@ -25,6 +25,13 @@ offset_project_signal_jax, ) +if use_accel_opencl: + from .kernels_opencl import ( + offset_add_to_signal_opencl, + offset_apply_diag_precond_opencl, + offset_project_signal_opencl, + ) + @kernel(impl=ImplementationType.DEFAULT) def offset_add_to_signal( @@ -37,6 +44,7 @@ def offset_add_to_signal( det_data, intervals, use_accel=False, + **kwargs, ): """Kernel to accumulate offset amplitudes to timestream data. @@ -84,6 +92,7 @@ def offset_project_signal( amplitude_flags, intervals, use_accel=False, + **kwargs, ): """Kernel to accumulate timestream data into offset amplitudes. 
@@ -132,6 +141,7 @@ def offset_apply_diag_precond( amplitude_flags, amplitudes_out, use_accel=False, + **kwargs, ): """ Args: diff --git a/src/toast/templates/offset/kernels_opencl.cl b/src/toast/templates/offset/kernels_opencl.cl new file mode 100644 index 000000000..d4de66e37 --- /dev/null +++ b/src/toast/templates/offset/kernels_opencl.cl @@ -0,0 +1,113 @@ +// Copyright (c) 2024-2024 by the parties listed in the AUTHORS file. +// All rights reserved. Use of this source code is governed by +// a BSD-style license that can be found in the LICENSE file. + + +#ifdef cl_khr_int64_base_atomics +#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable +void __attribute__((always_inline)) atomic_add_double( + volatile global double* addr, const double val +) { + union { + ulong u64; + double f64; + } next, expected, current; + current.f64 = *addr; + do { + expected.f64 = current.f64; + next.f64 = expected.f64 + val; + current.u64 = atom_cmpxchg( + (volatile global ulong*)addr, expected.u64, next.u64 + ); + } while(current.u64 != expected.u64); +} +#endif + + +// Kernels + +__kernel void offset_add_to_signal( + long n_sample, + long first_sample, + long step_length, + long amp_offset, + __global double const * amplitudes, + __global unsigned char const * amplitude_flags, + int det_data_index, + __global double * det_data, + unsigned char use_amp_flags +) { + // Get the global index of this work element + int idet = get_global_id(0); + long isamp = first_sample + get_global_id(1); + + size_t doff = det_data_index * n_sample + isamp; + long amp = amp_offset + (long)(isamp / step_length); + + unsigned char amp_check = 0; + if (use_amp_flags) { + amp_check = amplitude_flags[amp]; + } + + if (amp_check == 0) { + det_data[doff] += amplitudes[amp]; + } + return; +} + + +__kernel void offset_project_signal( + long n_sample, + long first_sample, + int det_data_index, + __global double const * det_data, + int det_flag_index, + __global unsigned char const * det_flags, + 
unsigned char flag_mask, + long step_length, + long amp_offset, + __global double * amplitudes, + __global unsigned char const * amplitude_flags, + unsigned char use_det_flags, + unsigned char use_amp_flags +) { + // Get the global index of this work element + int idet = get_global_id(0); + long isamp = first_sample + get_global_id(1); + + size_t doff = det_data_index * n_sample + isamp; + long amp = amp_offset + (long)(isamp / step_length); + + unsigned char det_check = 0; + if (use_det_flags) { + det_check = det_flags[doff] & flag_mask; + } + unsigned char amp_check = 0; + if (use_amp_flags) { + amp_check = amplitude_flags[amp]; + } + + if ((det_check == 0) && (amp_check == 0)) { + atomic_add_double(&(amplitudes[amp]), det_data[doff]); + } + return; +} + +__kernel void offset_apply_diag_precond( + __global double const * amplitudes_in, + __global double * amplitudes_out, + __global double const * offset_var, + __global unsigned char const * amplitude_flags, + unsigned char use_amp_flags +) { + int iamp = get_global_id(0); + + unsigned char amp_check = 0; + if (use_amp_flags) { + amp_check = amplitude_flags[iamp]; + } + if (amp_check == 0) { + amplitudes_out[iamp] = offset_var[iamp] * amplitudes_in[iamp]; + } + return; +} diff --git a/src/toast/templates/offset/kernels_opencl.py b/src/toast/templates/offset/kernels_opencl.py new file mode 100644 index 000000000..9009b07ad --- /dev/null +++ b/src/toast/templates/offset/kernels_opencl.py @@ -0,0 +1,228 @@ +# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file. +# All rights reserved. Use of this source code is governed by +# a BSD-style license that can be found in the LICENSE file. 
+ +import os +import numpy as np +import pyopencl as cl + +from ...accelerator import ImplementationType, kernel +from ...opencl import ( + find_source, + OpenCL, + add_kernel_deps, + get_kernel_deps, + clear_kernel_deps, +) + + +@kernel(impl=ImplementationType.OPENCL, name="offset_add_to_signal") +def offset_add_to_signal_opencl( + step_length, + amp_offset, + n_amp_views, + amplitudes, + amplitude_flags, + data_index, + det_data, + intervals, + use_accel=False, + obs_name=None, + state=None, + **kwargs, +): + if len(amplitude_flags) == len(amplitudes): + use_amp_flags = np.uint8(1) + else: + use_amp_flags = np.uint8(0) + + program_file = find_source(os.path.dirname(__file__), "kernels_opencl.cl") + + ocl = OpenCL() + + devtype = ocl.default_device_type + + kernel = ocl.get_or_build_kernel( + "offset", + "offset_add_to_signal", + device_type=devtype, + source=program_file, + ) + + # Get our device arrays + dev_amplitudes = ocl.mem(amplitudes, device_type=devtype) + dev_amplitude_flags = ocl.mem(amplitude_flags, device_type=devtype) + dev_det_data = ocl.mem(det_data, device_type=devtype) + + # No temporaries needed for this kernel. 
+ + # All of the events that our kernels depend on + wait_for = get_kernel_deps(state, obs_name) + # print(f"OFFADD: {obs_name} got wait_for = {wait_for}", flush=True) + + n_det = det_data.shape[0] + n_samp = det_data.shape[1] + view_offset = amp_offset + for intr, view_amps in zip(intervals, n_amp_views): + first_sample = intr.first + n_intr = intr.last - intr.first + 1 + ev = kernel( + ocl.queue(device_type=devtype), + (n_det, n_intr), + None, + np.int64(n_samp), + np.int64(first_sample), + np.int64(step_length), + np.int64(view_offset), + dev_amplitudes.data, + dev_amplitude_flags.data, + data_index, + dev_det_data.data, + use_amp_flags, + wait_for=wait_for, + ) + wait_for = [ev] + view_offset += view_amps + clear_kernel_deps(state, obs_name) + add_kernel_deps(state, obs_name, wait_for) + + +@kernel(impl=ImplementationType.OPENCL, name="offset_project_signal") +def offset_project_signal_opencl( + data_index, + det_data, + flag_index, + flag_data, + flag_mask, + step_length, + amp_offset, + n_amp_views, + amplitudes, + amplitude_flags, + intervals, + use_accel=False, + obs_name=None, + state=None, + **kwargs, +): + if len(amplitude_flags) == len(amplitudes): + use_amp_flags = np.uint8(1) + else: + use_amp_flags = np.uint8(0) + if len(flag_data) == det_data.shape[1]: + use_det_flags = np.uint8(1) + else: + use_det_flags = np.uint8(0) + + program_file = find_source(os.path.dirname(__file__), "kernels_opencl.cl") + + ocl = OpenCL() + + devtype = ocl.default_device_type + + kernel = ocl.get_or_build_kernel( + "offset", + "offset_project_signal", + device_type=devtype, + source=program_file, + ) + + # Get our device arrays + dev_amplitudes = ocl.mem(amplitudes, device_type=devtype) + dev_amplitude_flags = ocl.mem(amplitude_flags, device_type=devtype) + dev_det_data = ocl.mem(det_data, device_type=devtype) + dev_flag_data = ocl.mem(flag_data, device_type=devtype) + + # No temporaries needed for this kernel. 
+ + # All of the events that our kernels depend on + wait_for = get_kernel_deps(state, obs_name) + # print(f"OFFPROJ: {obs_name} got wait_for = {wait_for}", flush=True) + + n_det = det_data.shape[0] + n_samp = det_data.shape[1] + view_offset = amp_offset + for intr, view_amps in zip(intervals, n_amp_views): + first_sample = intr.first + n_intr = intr.last - intr.first + 1 + ev = kernel( + ocl.queue(device_type=devtype), + (n_det, n_intr), + None, + np.int64(n_samp), + np.int64(first_sample), + np.int32(data_index), + dev_det_data.data, + np.int32(flag_index), + dev_flag_data.data, + np.uint8(flag_mask), + np.int64(step_length), + np.int64(view_offset), + dev_amplitudes.data, + dev_amplitude_flags.data, + use_det_flags, + use_amp_flags, + wait_for=wait_for, + ) + wait_for = [ev] + view_offset += view_amps + clear_kernel_deps(state, obs_name) + add_kernel_deps(state, obs_name, wait_for) + + +@kernel(impl=ImplementationType.OPENCL, name="offset_apply_diag_precond") +def offset_apply_diag_precond_opencl( + offset_var, + amplitudes_in, + amplitude_flags, + amplitudes_out, + use_accel=False, + obs_name=None, + state=None, + **kwargs, +): + if len(amplitude_flags) == len(amplitudes_in): + use_amp_flags = np.uint8(1) + else: + use_amp_flags = np.uint8(0) + + program_file = find_source(os.path.dirname(__file__), "kernels_opencl.cl") + + ocl = OpenCL() + queue = ocl.queue() + devtype = ocl.default_device_type + + kernel = ocl.get_or_build_kernel( + "offset", + "offset_apply_diag_precond", + device_type=devtype, + source=program_file, + ) + + # Get our device arrays + dev_amplitudes_in = ocl.mem(amplitudes_in, device_type=devtype) + dev_amplitude_flags = ocl.mem(amplitude_flags, device_type=devtype) + dev_amplitudes_out = ocl.mem(amplitudes_out, device_type=devtype) + + # Allocate temporaries + dev_offset_var = ocl.mem_to_device(offset_var, device_type=devtype, async_=True) + add_kernel_deps(state, obs_name, dev_offset_var.events) + + # All of the events that our kernels 
depend on + wait_for = get_kernel_deps(state, obs_name) + # print(f"OFFPREC: {obs_name} got wait_for = {wait_for}", flush=True) + + n_amp = len(amplitudes_in) + ev = kernel( + ocl.queue(device_type=devtype), + (n_amp), + None, + dev_amplitudes_in.data, + dev_amplitudes_out.data, + dev_offset_var.data, + dev_amplitude_flags.data, + use_amp_flags, + wait_for=wait_for, + ) + clear_kernel_deps(state, obs_name) + add_kernel_deps(state, obs_name, ev) diff --git a/src/toast/templates/offset/offset.py b/src/toast/templates/offset/offset.py index 01fb087f6..40c51480f 100644 --- a/src/toast/templates/offset/offset.py +++ b/src/toast/templates/offset/offset.py @@ -724,6 +724,8 @@ def _add_to_signal(self, detector, amplitudes, use_accel=None, **kwargs): ob.intervals[self.view].data, impl=implementation, use_accel=use_accel, + obs_name=ob.name, + **kwargs, ) # # DEBUGGING @@ -797,6 +799,8 @@ def _project_signal(self, detector, amplitudes, use_accel=None, **kwargs): ob.intervals[self.view].data, impl=implementation, use_accel=use_accel, + obs_name=ob.name, + **kwargs, ) # restore_dev = False @@ -958,6 +962,8 @@ def _apply_precond(self, amplitudes_in, amplitudes_out, use_accel=None, **kwargs amplitudes_out.local, impl=implementation, use_accel=use_accel, + obs_name="None", + **kwargs, ) return @@ -967,6 +973,7 @@ def _implementations(self): ImplementationType.COMPILED, ImplementationType.NUMPY, ImplementationType.JAX, + ImplementationType.OPENCL, ] def _supports_accel(self): diff --git a/src/toast/tests/accelerator.py b/src/toast/tests/accelerator.py index fc75a4b71..fd1e92be5 100644 --- a/src/toast/tests/accelerator.py +++ b/src/toast/tests/accelerator.py @@ -5,6 +5,8 @@ import os import time +from collections import deque + import numpy as np import numpy.testing as nt @@ -20,9 +22,11 @@ accel_data_update_device, accel_data_update_host, accel_enabled, + accel_wait, kernel, use_accel_jax, use_accel_omp, + use_accel_opencl, ) from ..data import Data from ..observation import 
default_values as defaults @@ -37,6 +41,12 @@ from ..jax.mutableArray import MutableJaxArray, _zero_out_jitted +if use_accel_opencl: + import pyopencl + + from ..opencl import OpenCL, get_kernel_deps, add_kernel_deps + from pyopencl.elementwise import ElementwiseKernel + @trait_docs class AccelOperator(ops.Operator): @@ -52,15 +62,32 @@ class AccelOperator(ops.Operator): def __init__(self, **kwargs): super().__init__(**kwargs) - def _exec(self, data, detectors=None, use_accel=None, **kwargs): + def _exec(self, data, detectors=None, use_accel=None, state=None, **kwargs): for ob in data.obs: if use_accel: # Base class has checked that data listed in our requirements # is present. if use_accel_omp: - # Call compiled code that uses OpenMP target offload to work with this data. + # Call compiled code that uses OpenMP target offload to work + # with this data. test_accel_op_buffer(ob.detdata[self.det_data].data) test_accel_op_array(ob.detdata[self.det_data].data) + elif use_accel_opencl: + # Run a simple elementwise kernel on this + ocl = OpenCL() + ctx = ocl.context() + dev_data = ocl.mem(ob.detdata[self.det_data].data, "signal") + wait_for = get_kernel_deps(state, ob.name) + print(f"Exec wait_for = {wait_for}") + kern = ElementwiseKernel( + ctx, + "double * data", + "data[i] *= 4.0", + "test_kern", + ) + ev = kern(dev_data, wait_for=wait_for) + print(f"Exec kernel ev = {ev}") + add_kernel_deps(state, ob.name, ev) else: ob.detdata[self.det_data].data[:] *= 4 else: @@ -77,6 +104,14 @@ def _requires(self): def _provides(self): return {"detdata": [self.det_data]} + def _implementations(self): + return [ + ImplementationType.DEFAULT, + ImplementationType.COMPILED, + ImplementationType.NUMPY, + ImplementationType.OPENCL, + ] + def _supports_accel(self): return True @@ -120,6 +155,10 @@ def super_numpy(foo, use_accel=False): def awesome_jax(foo, use_accel=False): print(f"JAX (accel={use_accel}): {foo}") + @kernel(ImplementationType.OPENCL, name="my_kernel") + def 
awesome_opencl(foo, use_accel=False): + print(f"OPENCL (accel={use_accel}): {foo}") + bar = "yes" my_kernel(bar, impl=ImplementationType.DEFAULT) my_kernel(bar, impl=ImplementationType.NUMPY) @@ -127,9 +166,29 @@ def awesome_jax(foo, use_accel=False): my_kernel(bar, impl=ImplementationType.COMPILED, use_accel=True) my_kernel(bar, impl=ImplementationType.JAX) my_kernel(bar, impl=ImplementationType.JAX, use_accel=True) + my_kernel(bar, impl=ImplementationType.OPENCL) + my_kernel(bar, impl=ImplementationType.OPENCL, use_accel=True) + + def test_opencl_info(self): + if not use_accel_opencl: + if self.rank == 0: + print("Not running with OpenCL support, skipping") + return + ocl = OpenCL() + ocl.info() + buf = np.zeros(100, dtype=np.float64) + if ocl.n_gpu > 0: + ocl.mem_create(buf, device_type="gpu") + if ocl.n_cpu > 0: + ocl.mem_create(buf, device_type="cpu") + ocl.mem_dump() + if ocl.n_cpu > 0: + ocl.mem_remove(buf, device_type="cpu") + if ocl.n_gpu > 0: + ocl.mem_remove(buf, device_type="gpu") def test_memory(self): - if not (use_accel_omp or use_accel_jax): + if not (use_accel_omp or use_accel_jax or use_accel_opencl): if self.rank == 0: print("Not running with accelerator support- skipping memory test") return @@ -146,7 +205,11 @@ def test_memory(self): # Copy to device for tname, buffer in data.items(): buffer = accel_data_create(buffer) - buffer = accel_data_update_device(buffer) + if use_accel_opencl: + dev_buffer = accel_data_update_device(buffer) + accel_wait(dev_buffer.events) + else: + buffer = accel_data_update_device(buffer) data[tname] = buffer # Check that it is present @@ -159,7 +222,11 @@ def test_memory(self): # Update device copy for tname, buffer in data.items(): - data[tname] = accel_data_update_device(buffer) + if use_accel_opencl: + dev_buffer = accel_data_update_device(buffer) + accel_wait(dev_buffer.events) + else: + data[tname] = accel_data_update_device(buffer) # Reset host copy for tname, buffer in data.items(): @@ -169,7 +236,11 @@ def 
test_memory(self): # Update host copy from device for tname, buffer in data.items(): - data[tname] = accel_data_update_host(buffer) + if use_accel_opencl: + events = accel_data_update_host(buffer) + accel_wait(events) + else: + data[tname] = accel_data_update_host(buffer) # Check Values for tname, buffer in data.items(): @@ -184,7 +255,7 @@ def test_memory(self): self.assertFalse(accel_data_present(buffer)) def test_data_stage(self): - if not (use_accel_omp or use_accel_jax): + if not (use_accel_omp or use_accel_jax or use_accel_opencl): if self.rank == 0: print("Not running with accelerator support- skipping data stage test") return @@ -251,7 +322,10 @@ def test_data_stage(self): # Copy data to device data.accel_create(dnames) - data.accel_update_device(dnames) + events = data.accel_update_device(dnames) + print(f"events = {events}", flush=True) + for obsname, obsevs in events.items(): + accel_wait(obsevs) # Clear host buffers (should not impact device data) if not use_accel_jax: @@ -285,7 +359,9 @@ def test_data_stage(self): # print(ob.shared[name]) # Copy back from device - data.accel_update_host(dnames) + events = data.accel_update_host(dnames) + for obsname, obsevs in events.items(): + accel_wait(obsevs) # print("Check original:") # for ob in check_data.obs: @@ -317,8 +393,9 @@ def test_data_stage(self): # Now go and shrink the detector buffers - data.accel_create(dnames) - data.accel_update_device(dnames) + events = data.accel_update_device(dnames) + for obsname, obsevs in events.items(): + accel_wait(obsevs) for check, ob in zip(check_data.obs, data.obs): for itp, (tname, tp) in enumerate(self.types.items()): @@ -327,14 +404,24 @@ def test_data_stage(self): # This will set the host copy to zero and invalidate the device copy ob.detdata[name].change_detectors(ob.local_detectors[0:2]) check.detdata[name].change_detectors(check.local_detectors[0:2]) - ob.detdata[name].accel_update_host() + if use_accel_opencl: + events = ob.detdata[name].accel_update_host() + 
accel_wait(events) + else: + _ = ob.detdata[name].accel_update_host() # Reset host copy ob.detdata[name][:] = itp + 1 check.detdata[name][:] = itp + 1 # Update device copy - ob.detdata[name].accel_update_device() + if use_accel_opencl: + events = ob.detdata[name].accel_update_device() + accel_wait(events) + else: + _ = ob.detdata[name].accel_update_device() - data.accel_update_host(dnames) + events = data.accel_update_host(dnames) + for obsname, obsevs in events.items(): + accel_wait(obsevs) # Compare for check, ob in zip(check_data.obs, data.obs): @@ -351,7 +438,7 @@ def test_data_stage(self): close_data(data) def test_operator_stage(self): - if not (use_accel_omp or use_accel_jax): + if not (use_accel_omp or use_accel_jax or use_accel_opencl): if self.rank == 0: print("Not running with accelerator support- skipping operator test") return @@ -368,13 +455,25 @@ def test_operator_stage(self): # Stage the data data.accel_create(accel_op.requires()) - data.accel_update_device(accel_op.requires()) + events = data.accel_update_device(accel_op.requires()) + print(f"Start events = {events}") + # for obsname, obsevs in events.items(): + # accel_wait(obsevs) + state = dict() + for obsname, obsevs in events.items(): + add_kernel_deps(state, obsname, obsevs) # Run with staged data - accel_op.apply(data, use_accel=True) + accel_op.apply(data, use_accel=True, state=state) + + # Wait + for obsname, obsevs in state.items(): + accel_wait(obsevs) # Copy out - data.accel_update_host(accel_op.provides()) + events = data.accel_update_host(accel_op.provides()) + for obsname, obsevs in events.items(): + accel_wait(obsevs) # Check for ob in data.obs: @@ -467,10 +566,21 @@ def _accel_create(self): self.data = accel_data_create(self.data) def _accel_update_device(self): - self.data = accel_data_update_device(self.data) + ret = None + if use_accel_opencl: + dev_data = accel_data_update_device(self.data) + ret = dev_data.events + else: + self.data = accel_data_update_device(self.data) + 
return ret def _accel_update_host(self): - self.data = accel_data_update_host(self.data) + ret = None + if use_accel_opencl: + ret = accel_data_update_host(self.data) + else: + self.data = accel_data_update_host(self.data) + return ret def _accel_delete(self): self.data = accel_data_delete(self.data) diff --git a/src/toast/tests/ops_mapmaker.py b/src/toast/tests/ops_mapmaker.py index 6bd66f467..d55921f98 100644 --- a/src/toast/tests/ops_mapmaker.py +++ b/src/toast/tests/ops_mapmaker.py @@ -12,7 +12,7 @@ from .. import ops as ops from .. import templates -from ..accelerator import accel_enabled +from ..accelerator import accel_enabled, accel_wait from ..observation import default_values as defaults from ..pixels import PixelData, PixelDistribution from ..pixels_io_healpix import write_healpix_fits @@ -156,9 +156,12 @@ def test_offset(self): data.accel_create(pixels.requires()) data.accel_create(weights.requires()) data.accel_create(mapper.requires()) - data.accel_update_device(pixels.requires()) - data.accel_update_device(weights.requires()) - data.accel_update_device(mapper.requires()) + events = data.accel_update_device(pixels.requires()) + accel_wait(events) + events = data.accel_update_device(weights.requires()) + accel_wait(events) + events = data.accel_update_device(mapper.requires()) + accel_wait(events) binner.full_pointing = True mapper.name = "test2" diff --git a/src/toast/tests/ops_mapmaker_utils.py b/src/toast/tests/ops_mapmaker_utils.py index 56eca6ee3..352287232 100644 --- a/src/toast/tests/ops_mapmaker_utils.py +++ b/src/toast/tests/ops_mapmaker_utils.py @@ -9,7 +9,7 @@ from astropy import units as u from .. 
import ops as ops -from ..accelerator import accel_enabled +from ..accelerator import accel_enabled, accel_wait from ..mpi import MPI from ..noise import Noise from ..observation import default_values as defaults @@ -293,13 +293,17 @@ def test_zmap(self): data.accel_create(pixels.requires()) data.accel_create(weights.requires()) data.accel_create(build_zmap.requires()) - data.accel_update_device(pixels.requires()) - data.accel_update_device(weights.requires()) - data.accel_update_device(build_zmap.requires()) + events = list() + events.extend(data.accel_update_device(pixels.requires())) + events.extend(data.accel_update_device(weights.requires())) + events.extend(data.accel_update_device(build_zmap.requires())) + accel_wait(events) # runs on device build_zmap.apply(data, use_accel=use_accel) # insures everything is back from device - data.accel_update_host(build_zmap.provides()) + events = list() + events.extend(data.accel_update_host(build_zmap.provides())) + accel_wait(events) data.accel_delete(pixels.requires()) data.accel_delete(weights.requires()) data.accel_delete(build_zmap.requires()) diff --git a/src/toast/tests/ops_pixels_healpix.py b/src/toast/tests/ops_pixels_healpix.py index 80ca0266e..14ba0eb38 100644 --- a/src/toast/tests/ops_pixels_healpix.py +++ b/src/toast/tests/ops_pixels_healpix.py @@ -67,6 +67,7 @@ def test_pixels_healpix(self): npix_submap, nside, nest, + True, use_accel, ) bad = pixels[0] >= npix diff --git a/src/toast/tests/ops_pointing_healpix.py b/src/toast/tests/ops_pointing_healpix.py index 469555ec8..47b49032a 100644 --- a/src/toast/tests/ops_pointing_healpix.py +++ b/src/toast/tests/ops_pointing_healpix.py @@ -70,6 +70,7 @@ def test_pointing_matrix_bounds(self): n_pix_submap, nside, True, + True, False, ) stokes_weights_IQU( @@ -156,9 +157,6 @@ def test_pointing_matrix_weights(self): flush=True, ) failed = True - else: - print("Pointing weights agree: {} == {}".format(w1, w2), flush=True) - pass self.assertFalse(failed) return diff 
--git a/src/toast/tests/ops_scan_map.py b/src/toast/tests/ops_scan_map.py index 3bdd17a25..5015d2be4 100644 --- a/src/toast/tests/ops_scan_map.py +++ b/src/toast/tests/ops_scan_map.py @@ -9,7 +9,7 @@ from astropy import units as u from .. import ops as ops -from ..accelerator import ImplementationType +from ..accelerator import ImplementationType, accel_enabled, accel_wait from ..observation import default_values as defaults from ..pixels import PixelData from ._helpers import close_data, create_fake_sky, create_outdir, create_satellite_data @@ -195,3 +195,71 @@ def test_mask(self): self.assertTrue(ob.detdata["mask_flags"][det, i] == 0) close_data(data) + + def test_scan_pipeline(self): + # Create a fake satellite data set for testing + data = create_satellite_data(self.comm) + + # Create some detector pointing matrices + detpointing = ops.PointingDetectorSimple() + pixels = ops.PixelsHealpix( + nside=64, + create_dist="pixel_dist", + detector_pointing=detpointing, + ) + pixels.apply(data) + weights = ops.StokesWeights( + mode="IQU", + hwp_angle=defaults.hwp_angle, + detector_pointing=detpointing, + ) + weights.apply(data) + + # Create fake polarized sky pixel values locally + create_fake_sky(data, "pixel_dist", "fake_map") + map_data = data["fake_map"] + + # Scan map into timestreams + scanner = ops.ScanMap( + det_data=defaults.det_data, + pixels=pixels.pixels, + weights=weights.weights, + map_key="fake_map", + ) + + # Stage data if needed + if accel_enabled: + data.accel_create(scanner.requires()) + events = data.accel_update_device(scanner.requires()) + accel_wait(events) + + kernel_state = {"state": None} + scanner.apply(data, use_accel=accel_enabled, **kernel_state) + + # Copy back if needed + if accel_enabled: + events = data.accel_update_host(scanner.provides()) + accel_wait(events) + + # Manual check of the projection of map values to timestream + for ob in data.obs: + for det in ob.select_local_detectors( + flagmask=defaults.det_mask_invalid + ): + wt = 
ob.detdata[weights.weights][det] + local_sm, local_pix = data["pixel_dist"].global_pixel_to_submap( + ob.detdata[pixels.pixels][det] + ) + for i in range(ob.n_local_samples): + if local_pix[i] < 0: + continue + val = 0.0 + for j in range(3): + val += ( + wt[i, j] * map_data.data[local_sm[i], local_pix[i], j] + ) + np.testing.assert_almost_equal( + val, ob.detdata[defaults.det_data][det, i] + ) + + close_data(data) diff --git a/src/toast/tests/template_offset.py b/src/toast/tests/template_offset.py index 57c6815de..10423b1f4 100644 --- a/src/toast/tests/template_offset.py +++ b/src/toast/tests/template_offset.py @@ -9,7 +9,12 @@ from astropy import units as u from .. import ops -from ..accelerator import ImplementationType, accel_data_table, accel_enabled +from ..accelerator import ( + ImplementationType, + accel_data_table, + accel_enabled, + accel_wait, +) from ..observation import default_values as defaults from ..templates import AmplitudesMap, Offset from ..utils import rate_from_times @@ -135,9 +140,11 @@ def test_accel(self): } data.accel_create(data_names) - data.accel_update_device(data_names) + events = data.accel_update_device(data_names) + accel_wait(events) amps.accel_create("Offset") - amps.accel_update_device() + events = amps.accel_update_device() + accel_wait(events) # Project. 
for det in tmpl.detectors(): @@ -149,8 +156,10 @@ def test_accel(self): for ob in data.obs: tmpl.project_signal(det, amps, use_accel=True) - data.accel_update_host(data_names) - amps.accel_update_host() + events = data.accel_update_host(data_names) + accel_wait(events) + events = amps.accel_update_host() + accel_wait(events) # Verify for ob in data.obs: diff --git a/src/toast/traits.py b/src/toast/traits.py index 35bdc45b0..0330c9360 100644 --- a/src/toast/traits.py +++ b/src/toast/traits.py @@ -29,7 +29,12 @@ signature_has_traits, ) -from .accelerator import ImplementationType, use_accel_jax, use_accel_omp +from .accelerator import ( + ImplementationType, + use_accel_jax, + use_accel_omp, + use_accel_opencl, +) from .trait_utils import fix_quotes, string_to_trait, trait_to_string from .utils import Logger, import_from_name, object_fullname @@ -267,7 +272,7 @@ class constructor. kernel_implementation = UseEnum( ImplementationType, default_value=ImplementationType.DEFAULT, - help="Which kernel implementation to use (DEFAULT, COMPILED, NUMPY, JAX).", + help="Which kernel implementation to use (DEFAULT, COMPILED, NUMPY, JAX, OPENCL).", ) def __init__(self, **kwargs): @@ -322,7 +327,13 @@ def select_kernels(self, use_accel=None): if use_accel is None: return ImplementationType.DEFAULT, False elif use_accel: - if use_accel_jax: + if use_accel_opencl: + if ImplementationType.OPENCL not in impls: + msg = f"OPENCL accelerator use is enabled, " + msg += f"but not supported by {self.name}" + raise RuntimeError(msg) + return ImplementationType.OPENCL, True + elif use_accel_jax: if ImplementationType.JAX not in impls: msg = f"JAX accelerator use is enabled, " msg += f"but not supported by {self.name}"