feat: __getitem__ logic for MLIR backend #779

Open · wants to merge 1 commit into main
Showing changes from all commits
6 changes: 6 additions & 0 deletions sparse/mlir_backend/_constructors.py
@@ -364,6 +364,12 @@ def __del__(self):
for field in self._obj.get__fields_():
free_memref(field)

def __getitem__(self, key) -> "Tensor":
        # Imported lazily to avoid a cyclic dependency.
from ._ops import getitem

return getitem(self, key)

@_hold_self_ref_in_ret
def to_scipy_sparse(self) -> sps.sparray | np.ndarray:
return self._obj.to_sps(self.shape)
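For context, the new `Tensor.__getitem__` simply forwards to `_ops.getitem`, so indexing reads like NumPy/SciPy. A minimal usage sketch (the top-level `import sparse` is an assumption; `sparse.asarray` and `to_scipy_sparse` are the entry points exercised by the tests below):

import scipy.sparse as sps

import sparse  # assumed entry point, as used in the tests below

arr = sps.random_array((20, 30), density=0.5, format="csr")
tensor = sparse.asarray(arr)
sub = tensor[1, 1:]             # dispatches to _ops.getitem
result = sub.to_scipy_sparse()  # materialize the sliced tensor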
117 changes: 111 additions & 6 deletions sparse/mlir_backend/_ops.py
@@ -1,4 +1,5 @@
import ctypes
from types import EllipsisType

import mlir.execution_engine
import mlir.passmanager
@@ -85,12 +86,39 @@ def get_reshape_module(
def reshape(a, shape):
return tensor.reshape(out_tensor_type, a, shape)

reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
if DEBUG:
(CWD / "reshape_module.mlir").write_text(str(module))
pm.run(module.operation)
if DEBUG:
(CWD / "reshape_module_opt.mlir").write_text(str(module))

return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])


@fn_cache
def get_slice_module(
in_tensor_type: ir.RankedTensorType,
out_tensor_type: ir.RankedTensorType,
offsets: tuple[int, ...],
sizes: tuple[int, ...],
strides: tuple[int, ...],
) -> ir.Module:
with ir.Location.unknown(ctx):
module = ir.Module.create()

with ir.InsertionPoint(module.body):

@func.FuncOp.from_py_func(in_tensor_type)
def getitem(a):
return tensor.extract_slice(out_tensor_type, a, [], [], [], offsets, sizes, strides)

getitem.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
if DEBUG:
(CWD / "getitem_module.mlir").write_text(str(module))
pm.run(module.operation)
if DEBUG:
(CWD / "getitem_module_opt.mlir").write_text(str(module))

return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
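
Note on the design: the offsets, sizes, and strides are baked into the compiled module as static attributes of `tensor.extract_slice` (the three empty lists are the unused dynamic operands), so `@fn_cache` memoizes one `ExecutionEngine` per distinct slice signature rather than passing slice parameters at runtime.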

@@ -195,3 +223,80 @@ def broadcast_to(x: Tensor, /, shape: tuple[int, ...], dimensions: list[int]) ->
)

return Tensor(ret_obj, shape=shape)


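# Pads a key that has fewer entries than the tensor's rank with a trailing
# Ellipsis, e.g. (2,) -> (2, Ellipsis) for ndim == 2.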
def _add_missing_dims(key: tuple, ndim: int) -> tuple:
if len(key) < ndim and Ellipsis not in key:
return key + (...,)
return key


def _expand_ellipsis(key: tuple, ndim: int) -> tuple:
    if Ellipsis in key:
        if sum(e is Ellipsis for e in key) > 1:
            raise Exception(f"Ellipsis may be used at most once: {key}")
        to_expand = ndim - len(key) + 1
        if to_expand <= 0:
            raise Exception(f"Invalid use of Ellipsis in {key}: too many indices for {ndim} dimension(s)")
        idx = key.index(Ellipsis)
        # Replace the Ellipsis with enough full slices to bring the key up to `ndim` entries.
        return key[:idx] + tuple(slice(None) for _ in range(to_expand)) + key[idx + 1 :]
    return key


def _decompose_slices(
    key: tuple,
    shape: tuple[int, ...],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    # Translate each key element into the (offset, size, stride) triple
    # consumed by `tensor.extract_slice`.
    offsets = []
    sizes = []
    strides = []

    for key_elem, size in zip(key, shape, strict=True):
        if isinstance(key_elem, slice):
            offset = key_elem.start if key_elem.start is not None else 0
            size = key_elem.stop - offset if key_elem.stop is not None else size - offset
            stride = key_elem.step if key_elem.step is not None else 1
        elif isinstance(key_elem, int):
            offset = key_elem
            size = 1  # a single element along this axis; the dimension is kept with extent 1
            stride = 1
        else:
            raise Exception(f"Unsupported index: {key_elem!r}")
        offsets.append(offset)
        sizes.append(size)
        strides.append(stride)

    return tuple(offsets), tuple(sizes), tuple(strides)


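# The extent of each result axis is the sliced extent divided by its stride,
# e.g. sizes (20, 4) with strides (1, 2) -> result shape (20, 2).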
def _get_new_shape(sizes, strides) -> tuple[int, ...]:
return tuple(size // stride for size, stride in zip(sizes, strides, strict=False))


def getitem(
x: Tensor,
key: int | slice | EllipsisType | tuple[int | slice | EllipsisType, ...],
) -> Tensor:
    if not isinstance(key, tuple):
        key = (key,)
    if None in key:
        raise Exception(f"Adding a new axis with `None` (np.newaxis) isn't supported: {key}")

ret_obj = x._format_class()

key = _add_missing_dims(key, x.ndim)
key = _expand_ellipsis(key, x.ndim)
offsets, sizes, strides = _decompose_slices(key, x.shape)

new_shape = _get_new_shape(sizes, strides)
out_tensor_type = x._obj.get_tensor_definition(new_shape)

slice_module = get_slice_module(
x._obj.get_tensor_definition(x.shape),
out_tensor_type,
offsets,
sizes,
strides,
)

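    # `ret_obj` is an out-parameter: the engine writes the result memref
    # descriptor into it through the pointer-to-pointer argument.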
slice_module.invoke("getitem", ctypes.pointer(ctypes.pointer(ret_obj)), *x._obj.to_module_arg())

return Tensor(ret_obj, shape=out_tensor_type.shape)
36 changes: 36 additions & 0 deletions sparse/mlir_backend/tests/test_simple.py
@@ -341,3 +341,39 @@ def test_broadcast_to(dtype):

assert result.format == "csr"
np.testing.assert_allclose(result.todense(), np.repeat(np_arr[np.newaxis], 3, axis=0))


@pytest.mark.skip(reason="https://discourse.llvm.org/t/illegal-operation-when-slicing-csr-csc-coo-tensor/81404")
@parametrize_dtypes
@pytest.mark.parametrize(
"index",
[
0,
(2,),
(2, 3),
(..., slice(0, 4, 2)),
(1, slice(1, None, 1)),
        # TODO: The cases below need an update to the ownership mechanism.
        # `tensor[:, :]` returns the same memref that was passed in. The
        # mechanism treats the result as MLIR-allocated and frees it, even
        # though it may still be owned by SciPy/NumPy, so freeing that
        # SciPy/NumPy-managed memory causes a segfault.
# ...,
# slice(None),
# (slice(None), slice(None)),
],
)
def test_indexing_2d(rng, dtype, index):
SHAPE = (20, 30)
DENSITY = 0.5

for format in ["csr", "csc", "coo"]:
arr = sps.random_array(SHAPE, density=DENSITY, format=format, dtype=dtype, random_state=rng)
arr.sum_duplicates()

tensor = sparse.asarray(arr)

actual = tensor[index].to_scipy_sparse()
expected = arr.todense()[index]

np.testing.assert_array_equal(actual.todense(), expected)
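
To make the helper pipeline concrete, here is how the parametrized index `(..., slice(0, 4, 2))` is normalized for `SHAPE = (20, 30)` (values traced by hand from `_ops.py` above; illustrative only):

key = (..., slice(0, 4, 2))
key = _add_missing_dims(key, ndim=2)  # unchanged: an Ellipsis is already present
key = _expand_ellipsis(key, ndim=2)   # (slice(None), slice(0, 4, 2))
offsets, sizes, strides = _decompose_slices(key, (20, 30))
# offsets == (0, 0), sizes == (20, 4), strides == (1, 2)
_get_new_shape(sizes, strides)        # result shape: (20, 2)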