
Improve sharding propagation for triangle updates outgoing #5901

Description

@wujingyue

Repro: #5890

```shell
$ mpirun -np 1 -x NVFUSER_DUMP=pre_segmenter_logging pytest tests/python/multidevice/test_alphafold3.py -k outgoing --only-mpi -vs
```

The code of interest:

```python
match direction:
    case Direction.OUTGOING:
        # z_out = einsum("bikc,bjkc->bijc", a, b)
        a = fd.ops.permute(a, [0, 3, 1, 2])  # [b, c, i, k]
        b = fd.ops.permute(b, [0, 3, 2, 1])  # [b, c, k, j]
    case Direction.INCOMING:
        # z_out = einsum("bkic,bkjc->bijc", a, b)
        a = fd.ops.permute(a, [0, 3, 2, 1])  # [b, c, i, k]
        b = fd.ops.permute(b, [0, 3, 1, 2])  # [b, c, k, j]
z = fd.ops.matmul(a, b)  # [b, c, i, j]
z = fd.ops.permute(z, [0, 2, 3, 1])  # [b, i, j, c]
```
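To sanity-check that the permute/matmul/permute sequence in the code of interest computes the einsum written in each comment, here is a small NumPy sketch. It is not nvFuser code: `np.transpose` stands in for `fd.ops.permute` (same axis-order semantics), `np.matmul` stands in for `fd.ops.matmul`, and the helper names and tensor sizes are hypothetical.

```python
import numpy as np


def outgoing_via_matmul(a, b):
    # a: [b, i, k, c], b: [b, j, k, c]
    a_p = np.transpose(a, (0, 3, 1, 2))  # [b, c, i, k]
    b_p = np.transpose(b, (0, 3, 2, 1))  # [b, c, k, j]
    z = np.matmul(a_p, b_p)  # batched over [b, c] -> [b, c, i, j]
    return np.transpose(z, (0, 2, 3, 1))  # [b, i, j, c]


def incoming_via_matmul(a, b):
    # a: [b, k, i, c], b: [b, k, j, c]
    a_p = np.transpose(a, (0, 3, 2, 1))  # [b, c, i, k]
    b_p = np.transpose(b, (0, 3, 1, 2))  # [b, c, k, j]
    z = np.matmul(a_p, b_p)  # [b, c, i, j]
    return np.transpose(z, (0, 2, 3, 1))  # [b, i, j, c]


# Small random pair representations; i, j and k all have size N.
rng = np.random.default_rng(0)
B, N, C = 2, 5, 3
a = rng.standard_normal((B, N, N, C))
b = rng.standard_normal((B, N, N, C))

assert np.allclose(outgoing_via_matmul(a, b), np.einsum("bikc,bjkc->bijc", a, b))
assert np.allclose(incoming_via_matmul(a, b), np.einsum("bkic,bkjc->bijc", a, b))
```

Because the random inputs are square in i, j and k, a wrong permutation would still produce a correctly shaped result but different values, so the `allclose` checks catch it.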


The current heuristic for forward propagation prefers the second input of the matmul (usually the weight). Therefore, in the einsum output, j is sharded on DIDy instead of DIDx. This breaks back-propagation from z_in (i sharded on DIDy, j on DIDx) to the einsum output, because z_in wants j to be sharded on DIDx.

By the way, this is not a problem for "incoming" mode: following the same heuristic, the einsum output does have j sharded on DIDx.
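The difference between the two modes can be reproduced with a toy model. This is a hypothetical sketch, not nvFuser's propagation code, and all names in it are made up for illustration. It assumes that a and b are projections of z_in and therefore keep z_in's layout (dim 1 on DIDy, dim 2 on DIDx), and it reduces "prefer the second input" to copying mesh axes from b onto the output dims that b shares with the output.

```python
from typing import Dict, Optional, Tuple

# Toy model (hypothetical, not nvFuser's API): a sharding maps each named
# dim to a mesh axis, or None if that dim is not sharded.
Sharding = Dict[str, Optional[str]]


def inherit_z_in_layout(dims: Tuple[str, ...]) -> Sharding:
    # Assumption: the tensor is a projection of z_in [b, i, j, c], which has
    # dim 1 on DIDy and dim 2 on DIDx, so it keeps that layout under the
    # einsum's index names.
    return {dims[0]: None, dims[1]: "DIDy", dims[2]: "DIDx", dims[3]: None}


def propagate_from_second_input(b_sharding: Sharding, out_dims: Tuple[str, ...]) -> Sharding:
    # Simplified "prefer the second input" heuristic: output dims that also
    # appear in the second input copy its mesh axis; the rest stay unsharded.
    return {d: b_sharding.get(d) for d in out_dims}


OUT_DIMS = ("b", "i", "j", "c")
# What back-propagation from z_in expects for the einsum output.
WANTED = inherit_z_in_layout(OUT_DIMS)

# Outgoing: b is indexed "bjkc". Incoming: b is indexed "bkjc".
outgoing = propagate_from_second_input(inherit_z_in_layout(("b", "j", "k", "c")), OUT_DIMS)
incoming = propagate_from_second_input(inherit_z_in_layout(("b", "k", "j", "c")), OUT_DIMS)

for name, out in [("outgoing", outgoing), ("incoming", incoming)]:
    print(name, "j on", out["j"], "| z_in wants", WANTED["j"])
# outgoing j on DIDy | z_in wants DIDx
# incoming j on DIDx | z_in wants DIDx
```

Under these assumptions, the outgoing mode's second input carries DIDy on j, so copying from it conflicts with z_in, while the incoming mode's second input already carries DIDx on j.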


cc @DejunL
