feat: experimental sharding backend #1544

Open
polvalente wants to merge 31 commits into main from pv-feat/experimental-sharding-backend
Conversation

polvalente (Contributor)
Adds a proof-of-concept implementation of sharding as an Nx meta-compiler.

In the current proposal, we shard the inputs according to an arbitrary slicing configuration, and the compiler then does its best to propagate those slices through to the output. The compiler then builds a separate {args, function, reducer} tuple for each output shard, where:

  • args: the input arguments sliced down to the data sections required for that specific output shard
  • function: a new compilation of the input function, specialized to the sliced arguments
  • reducer: a function responsible for inserting the result shard into the correct place in an output accumulator tensor (see the sketch after this list)
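
For intuition, here is a minimal sketch of what such a reducer could look like, assuming the output shard is written back into the accumulator with Nx.put_slice at offsets derived from the slicing configuration. The name example_reducer and the hard-coded offsets are hypothetical and not taken from this PR:

# Hypothetical sketch: a reducer places a computed shard into the output
# accumulator at the offsets that shard corresponds to.
example_reducer = fn shard_result, acc ->
  # The real offsets would be derived by the sharding compiler from the
  # slicing configuration; they are hard-coded here purely for illustration.
  offsets = [0, 0, 0]
  Nx.put_slice(acc, offsets, shard_result)
end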

The following example shows how the example function is split into 4 separate shards: arg1 is split into 2 shards, and only the first axis of arg0 can be sharded, because its other two axes feed into contracting axes of the dot product. In the sharding configuration, each key is an axis index and each value is a list of index ranges, with one shard per range.

arg0_sharding = %{} # inputs are taken to be fully sharded if no specification is given
arg1_sharding = %{4 => [0..0, 1..1]}

Nx.default_backend(Nx.BinaryBackend)

fun = fn l, r ->
  x = Nx.add(l, Nx.tensor([[1]]))
  x = Nx.transpose(x, axes: [0, 2, 1])
  y = Nx.subtract(r, 1)
  y = Nx.squeeze(y, axes: [0, 1])
  Nx.dot(x, [2, 1], y, [1, 0])
end

# fun = &Nx.dot(&1, [1, 2], &2, [1, 0])
# fun = &Nx.add(&1, &2)

inputs = [
  Nx.iota({2, 2, 3}, type: :f32),
  Nx.add(Nx.iota({1, 1, 3, 2, 2}), 10)
]

{output_holder, shards} =
  Nx.Defn.jit_apply(
    fun,
    inputs,
    compiler: Nx.Defn.ShardingCompiler,
    sharding_config: [arg0_sharding, arg1_sharding],
    sharding_compiler: Nx.Defn.Evaluator,
    sharding_compiler_options: []
  )

sharded_result =
  shards
  |> Task.async_stream(fn {arg, fun, caster} ->
    dbg(self())
    {fun.(arg), caster}
  end)
  |> Enum.reduce(output_holder, fn {:ok, {result, caster}}, acc ->
    caster.(result, acc)
  end)
  |> IO.inspect()

# Ensure that the sharded result is the same as the result for the function applied to the unsharded inputs
IO.inspect(Nx.equal(sharded_result, apply(fun, inputs)) |> Nx.all() |> Nx.to_number() |> Kernel.==(1))
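
The same shards can also be folded sequentially, which can be handy when debugging without concurrency. A minimal sketch reusing the {arg, fun, caster} tuples and output_holder bound above:

sequential_result =
  Enum.reduce(shards, output_holder, fn {arg, fun, caster}, acc ->
    # Run the shard's compiled function and write its result into the accumulator
    caster.(fun.(arg), acc)
  end)

IO.inspect(Nx.equal(sequential_result, sharded_result) |> Nx.all() |> Nx.to_number() |> Kernel.==(1))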

polvalente force-pushed the pv-feat/experimental-sharding-backend branch from e67a399 to 68565bc on October 11, 2024
polvalente self-assigned this on Oct 29, 2024