Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions python/tvm/tirx/transform/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from tvm.tirx.stmt_functor import StmtExprMutator, StmtMutator


# FIXME: this pass does not replace var in the shape/layout of a buffer
class BufferReplacer(StmtExprMutator):
"""
Replace buffer with another buffer.
Expand All @@ -63,6 +62,10 @@ def mutate_buffer(self, buffer: Buffer):
self.buffer_attr_var_mutated = False
new_data = self.visit_expr(buffer.data)
new_shape = [self.visit_expr(expr) for expr in buffer.shape]
new_strides = [self.visit_expr(expr) for expr in buffer.strides]
Comment thread
guan404ming marked this conversation as resolved.
new_elem_offset = (
self.visit_expr(buffer.elem_offset) if buffer.elem_offset is not None else None
)
if isinstance(buffer.layout, TileLayout):
new_shard = []
new_replicate = []
Expand Down Expand Up @@ -90,8 +93,8 @@ def mutate_buffer(self, buffer: Buffer):
buffer.dtype,
buffer.name,
new_data,
buffer.strides,
buffer.elem_offset,
new_strides,
new_elem_offset,
buffer.scope(),
buffer.data_alignment,
buffer.offset_factor,
Expand Down
15 changes: 15 additions & 0 deletions tests/python/tirx/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,21 @@ def test_buffer_replacer_no_shared_default():
assert len(r2.buffer_map) == 0


def test_buffer_replacer_replaces_strides_and_elem_offset():
"""Vars in buffer strides/elem_offset must be replaced, not passed through."""
from tvm.tirx import BufferStore, Var
from tvm.tirx.transform.common import BufferReplacer

n = Var("n", "int32")
m = Var("m", "int32")
A = decl_buffer((64,), "float32", strides=[n], elem_offset=n)
store = BufferStore(A, 1.0, [0])

new = BufferReplacer(var_map={n: m})(store)
assert new.buffer.strides[0].same_as(m)
assert new.buffer.elem_offset.same_as(m)


def test_gemm_async_partial_scale_factor():
"""Regression test for F7: gemm_async must reject partial scale factors."""
from tvm.tirx.script.builder.tirx import gemm_async
Expand Down
Loading