diff --git a/python/tvm/tirx/transform/common.py b/python/tvm/tirx/transform/common.py index d90903daf967..d7ebd557af0e 100644 --- a/python/tvm/tirx/transform/common.py +++ b/python/tvm/tirx/transform/common.py @@ -36,7 +36,6 @@ from tvm.tirx.stmt_functor import StmtExprMutator, StmtMutator -# FIXME: this pass does not replace var in the shape/layout of a buffer class BufferReplacer(StmtExprMutator): """ Replace buffer with another buffer. @@ -63,6 +62,10 @@ def mutate_buffer(self, buffer: Buffer): self.buffer_attr_var_mutated = False new_data = self.visit_expr(buffer.data) new_shape = [self.visit_expr(expr) for expr in buffer.shape] + new_strides = [self.visit_expr(expr) for expr in buffer.strides] + new_elem_offset = ( + self.visit_expr(buffer.elem_offset) if buffer.elem_offset is not None else None + ) if isinstance(buffer.layout, TileLayout): new_shard = [] new_replicate = [] @@ -90,8 +93,8 @@ def mutate_buffer(self, buffer: Buffer): buffer.dtype, buffer.name, new_data, - buffer.strides, - buffer.elem_offset, + new_strides, + new_elem_offset, buffer.scope(), buffer.data_alignment, buffer.offset_factor, diff --git a/tests/python/tirx/test_op.py b/tests/python/tirx/test_op.py index 480e6cd3ddbc..4c417033d697 100644 --- a/tests/python/tirx/test_op.py +++ b/tests/python/tirx/test_op.py @@ -57,6 +57,21 @@ def test_buffer_replacer_no_shared_default(): assert len(r2.buffer_map) == 0 +def test_buffer_replacer_replaces_strides_and_elem_offset(): + """Vars in buffer strides/elem_offset must be replaced, not passed through.""" + from tvm.tirx import BufferStore, Var + from tvm.tirx.transform.common import BufferReplacer + + n = Var("n", "int32") + m = Var("m", "int32") + A = decl_buffer((64,), "float32", strides=[n], elem_offset=n) + store = BufferStore(A, 1.0, [0]) + + new = BufferReplacer(var_map={n: m})(store) + assert new.buffer.strides[0].same_as(m) + assert new.buffer.elem_offset.same_as(m) + + def test_gemm_async_partial_scale_factor(): """Regression test for F7: gemm_async must reject partial scale factors.""" from tvm.tirx.script.builder.tirx import gemm_async