[Feature] Support unaligned barrier sync #2295
📝 Walkthrough
Adds an optional …
Changes
Unaligned Barrier Synchronization
        << PrintExpr(args[2]);
    if (args.size() == 4) {
      this->stream << ", "
                   << (GetBoolImm(args[3], "storage_sync aligned") ? "true"
what's "storage_sync aligned"?
53eacd3 to 9de3806
Actionable comments posted: 1
🧹 Nitpick comments (2)
testing/python/transform/test_tilelang_transform_layout_inference.py (1)
185-187: ⚡ Quick win: Prefer structural layout markers over exact local array size literals.
Asserting `values[16]` vs `values[32]` is fragile. Keep the fallback/partition intent check structural (e.g., scalar fragment path present and vector-pack fallback absent) to reduce false failures.
Based on learnings: For Python tests in `testing/python/transform`, assertions should target structural behavior rather than specific numeric literals.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@testing/python/transform/test_tilelang_transform_layout_inference.py` around lines 185 - 187, Replace fragile literal-size assertions on kernel_source with structural checks: instead of asserting "signed char values[16]" and not "signed char values[32]", verify the scalar-fragment code path exists (e.g., assert presence of the scalar fragment identifier or code block that handles single-element processing) and verify the vector-packed fallback is absent by keeping the existing check against "make_longlong4(" (or asserting absence of any vector-pack helper like "make_longlong" patterns). Update the test to assert these structural markers on kernel_source rather than specific numeric array sizes.
testing/python/transform/test_tilelang_transform_thread_sync.py (1)
112-113: ⚡ Quick win: Avoid hard-coding barrier IDs/count literals in transform-pass assertions.
These checks are brittle to unrelated pass/internal allocation changes. Assert structural behavior (aligned vs unaligned form and occurrence patterns) instead of exact `(3, 64/128)` literals.
Suggested assertion style
- assert "tl::__sync_thread_partial<false>(3, 128);" in src, src
+ assert "tl::__sync_thread_partial<false>(" in src, src
- assert "tl::__sync_thread_partial(3, 64);" in src, src
+ assert "tl::__sync_thread_partial(" in src, src
  assert "tl::__sync_thread_partial<false>(" not in src, src
- assert src.count("tl::__sync_thread_partial(3, 64);") == 2, src
+ assert src.count("tl::__sync_thread_partial(") == 2, src
  assert "tl::__sync_thread_partial<false>(" not in src, src
Based on learnings: For Python tests of the tilelang transform passes, focus assertions on structural patterns in generated source and avoid relying on specific numeric literals.
Also applies to: 128-130, 146-147
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@testing/python/transform/test_tilelang_transform_thread_sync.py` around lines 112 - 113, The assertion currently checks for a hard-coded barrier literal "tl::__sync_thread_partial<false>(3, 128);" which is brittle; update the test in test_tilelang_transform_thread_sync.py to assert structural patterns instead: check that the generated source (src) contains the function call name "tl::__sync_thread_partial<false>" and then assert whether the call appears in the aligned form (e.g., with a power-of-two second argument pattern) or unaligned form by using a regex or substring checks for the surrounding token patterns rather than exact numeric literals; replace the exact-match assertions (the one shown and the ones at lines ~128-130 and ~146-147) with these structural/occurrence pattern checks on src.
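One way to express the structural check suggested above is to classify each call by its form and ignore the numeric arguments entirely. The sketch below is only an illustration: the sample source string and the helper name `count_sync_forms` are hypothetical, and it assumes (based on the assertions quoted in this review) that the `<false>` template form marks an unaligned barrier while the bare form is aligned.

```python
import re

# Hypothetical generated source; the barrier id and thread counts are made up.
SRC = """
tl::__sync_thread_partial<false>(3, 128);
tl::__sync_thread_partial(3, 64);
tl::__sync_thread_partial(3, 64);
"""

# Match the callee with an optional template argument and capture that argument
# (None for the bare form). The numeric call arguments are never inspected.
SYNC_CALL = re.compile(r"tl::__sync_thread_partial(?:<(\w+)>)?\s*\(")


def count_sync_forms(src: str) -> dict:
    """Count aligned vs unaligned partial-sync calls (assumed: <false> == unaligned)."""
    counts = {"aligned": 0, "unaligned": 0}
    for match in SYNC_CALL.finditer(src):
        form = "unaligned" if match.group(1) == "false" else "aligned"
        counts[form] += 1
    return counts


counts = count_sync_forms(SRC)
assert counts == {"aligned": 2, "unaligned": 1}, counts
```

A test written this way keeps passing if the pass allocates a different barrier id or thread count, but still fails if an aligned barrier silently becomes unaligned or the number of barriers changes.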
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/backend/cuda/codegen/codegen_cuda.cc`:
- Around line 1425-1439: The optional aligned selector (args[3]) is downcast
with Downcast<Bool> without checking it's a boolean immediate; add an ICHECK
before the downcast to ensure args.size()==4 implies args[3] is a Bool immediate
(e.g., check args[3].dtype().is_bool() and/or that it is a BoolImm) so the code
in the tl::__sync_thread_partial emission path (the block using args,
Downcast<Bool>, PrintExpr and this->stream) fails with a clear diagnostic
instead of a generic downcast failure; apply the same guard in the similar
section around lines 2404-2415 where Downcast<Bool> is used.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 340d38ad-aae7-4cd9-ba89-1fa9ba7299aa
📒 Files selected for processing (11)
- src/backend/cuda/codegen/codegen_cuda.cc
- src/backend/cuda/codegen/codegen_cutedsl.cc
- src/op/parallel.cc
- src/tl_templates/cuda/common.h
- src/transform/thread_storage_sync.cc
- testing/python/language/test_tilelang_language_sync_threads.py
- testing/python/transform/test_tilelang_transform_layout_inference.py
- testing/python/transform/test_tilelang_transform_thread_sync.py
- tilelang/contrib/cutedsl/reduce.py
- tilelang/contrib/cutedsl/utils.py
- tilelang/language/builtin.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tilelang/contrib/cutedsl/reduce.py
9de3806 to ffcd43d
🧹 Nitpick comments (2)
tilelang/language/builtin.py (1)
361-377: ⚡ Quick win: Validate `read` type in `tma_store_wait` before lowering
`read` is documented as `bool`, but currently any value is forwarded into the intrinsic. This can surface later as codegen downcast failures instead of a clear frontend error. Add a runtime bool check (same pattern used by `aligned` in this file).
Suggested patch
 def tma_store_wait(count: int = 0, read: bool = True):
@@
-    return tirx.call_intrin("handle", tirx.op.Op.get("tl.tma_store_wait"), count, read)
+    if not isinstance(read, bool):
+        raise TypeError(f"Expect read to be bool, but got {type(read)}.")
+    return tirx.call_intrin("handle", tirx.op.Op.get("tl.tma_store_wait"), count, read)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tilelang/language/builtin.py` around lines 361 - 377, The tma_store_wait function forwards the documented bool parameter read directly into the intrinsic which can cause downstream codegen type errors; add a runtime type check similar to the aligned validation in this file: verify isinstance(read, bool) and raise a TypeError with a clear message if not, before the existing return that calls tirx.call_intrin("handle", tirx.op.Op.get("tl.tma_store_wait"), count, read) so only booleans are lowered.
testing/python/language/test_tilelang_language_sync_threads.py (1)
52-52: ⚡ Quick win: Make source assertions less brittle to formatting-only codegen changes
These checks currently depend on exact spacing/text rendering. Consider asserting stable tokens (callee + unaligned marker) instead of full-line exact string equality.
Suggested patch
- assert "tl::__sync_thread_partial(1, 128, false);" in src, src
+ assert "tl::__sync_thread_partial(" in src and ", false);" in src, src
@@
- assert "tl::__sync_thread_partial(1, 128, false);" in src, src
+ assert "tl::__sync_thread_partial(" in src and ", false);" in src, src
@@
- assert "tl::__named_barrier_arrive<false>(2, 128);" in src, src
+ assert "tl::__named_barrier_arrive<false>(" in src, src
Also applies to: 64-64, 76-76
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@testing/python/language/test_tilelang_language_sync_threads.py` at line 52, The assertion is brittle because it matches exact spacing of the generated line ("tl::__sync_thread_partial(1, 128, false);"); update the three assertions in test_tilelang_language_sync_threads.py (the ones checking for tl::__sync_thread_partial at lines ~52, ~64, ~76) to instead assert on stable tokens such as the callee name and key arguments/markers (for example assert "tl::__sync_thread_partial" in src and assert "false" in the same context, or use a simple regex matching tl::__sync_thread_partial\\s*\\(.*1.*128.*false.*\\)); locate the assertions by the string "tl::__sync_thread_partial" and replace exact-string checks with token-based or regex checks so formatting-only codegen changes won’t break the test.
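If the key arguments should stay in the assertion while spacing is allowed to vary, a small regex helper is one option. This is a sketch under assumptions, not code from the PR: `has_call` is a hypothetical helper name, and the sample string is made up to exaggerate the whitespace differences.

```python
import re


def has_call(src: str, callee: str, *args: str) -> bool:
    """Return True if `callee(args...)` occurs in src, ignoring whitespace differences."""
    # Escape every literal piece so characters such as '<', '>' or '(' are matched verbatim.
    arg_pattern = r"\s*,\s*".join(re.escape(arg) for arg in args)
    pattern = re.escape(callee) + r"\s*\(\s*" + arg_pattern + r"\s*\)"
    return re.search(pattern, src) is not None


# Hypothetical codegen output with unusual spacing.
src = "tl::__sync_thread_partial( 1,128 , false );"
assert has_call(src, "tl::__sync_thread_partial", "1", "128", "false")
assert not has_call(src, "tl::__sync_thread_partial", "1", "128", "true")
```

Compared with the plain substring checks in the suggested patch, this keeps the `false` marker tied to the specific call rather than to any occurrence of `false` in the kernel source.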
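The `read` validation suggested for `tma_store_wait` can be tried in isolation. In the sketch below, `call_intrin_stub` is a hypothetical stand-in for `tirx.call_intrin` so that the snippet runs without TileLang installed; only the `isinstance` check and the error message mirror the suggested patch.

```python
def call_intrin_stub(dtype, op_name, *args):
    """Hypothetical stand-in for tirx.call_intrin; it only records its arguments."""
    return (dtype, op_name, args)


def tma_store_wait(count: int = 0, read: bool = True):
    # Reject non-bool values up front so the failure is a clear frontend error
    # rather than a downcast failure during codegen.
    if not isinstance(read, bool):
        raise TypeError(f"Expect read to be bool, but got {type(read)}.")
    return call_intrin_stub("handle", "tl.tma_store_wait", count, read)


assert tma_store_wait(1, False) == ("handle", "tl.tma_store_wait", (1, False))

try:
    tma_store_wait(0, read=1)  # an int is not accepted, even though 1 is truthy
except TypeError as err:
    assert "bool" in str(err)
else:
    raise AssertionError("expected TypeError for a non-bool read")
```

Note that `isinstance(1, bool)` is `False` in Python (`bool` subclasses `int`, not the other way round), so integer flags such as `0` and `1` are rejected by this check.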
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e680c950-46ba-4827-b4b4-29d2662652a4
📒 Files selected for processing (7)
- src/backend/cuda/codegen/codegen_cuda.cc
- src/backend/cuda/codegen/codegen_cutedsl.cc
- src/tl_templates/cuda/common.h
- testing/python/language/test_tilelang_language_sync_threads.py
- tilelang/contrib/cutedsl/reduce.py
- tilelang/contrib/cutedsl/utils.py
- tilelang/language/builtin.py
🚧 Files skipped from review as they are similar to previous changes (4)
- src/backend/cuda/codegen/codegen_cutedsl.cc
- tilelang/contrib/cutedsl/reduce.py
- tilelang/contrib/cutedsl/utils.py
- src/tl_templates/cuda/common.h
ffcd43d to d65e9f8
Actionable comments posted: 1
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_sync_threads.py (1)
52-52: ⚡ Quick win: Make codegen assertions less formatting-fragile.
These checks are good, but exact full-statement matches are brittle to whitespace/template formatting changes. Prefer asserting stable semantic substrings (e.g., intrinsic name + `false` specialization).
Proposed adjustment
- assert "tl::__sync_thread_partial(1, 128, false);" in src, src
+ assert "tl::__sync_thread_partial(" in src and "false" in src, src
...
- assert "tl::__sync_thread_partial(1, 128, false);" in src, src
+ assert "tl::__sync_thread_partial(" in src and "false" in src, src
...
- assert "tl::__named_barrier_arrive<false>(2, 128);" in src, src
+ assert "tl::__named_barrier_arrive<false>" in src, src
Also applies to: 64-64, 76-76
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@testing/python/language/test_tilelang_language_sync_threads.py` at line 52, The assertions that currently check the exact generated statement string (e.g., "tl::__sync_thread_partial(1, 128, false);") are too formatting-fragile; change them to assert on stable semantic substrings instead — for example verify that "tl::__sync_thread_partial(" is in src and that the specialization flag substring ", false" (or "false" near the intrinsic) is present (or assert both "tl::__sync_thread_partial(" and "false" are in src) for the three occurrences that currently match the full-statement (the checks referencing tl::__sync_thread_partial at the three assertion sites); update those assertions to use substring membership checks rather than exact full-statement equality.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: e9ba6796-6ee4-4894-80a3-161fbac6983b
📒 Files selected for processing (7)
- src/backend/cuda/codegen/codegen_cuda.cc
- src/backend/cuda/codegen/codegen_cutedsl.cc
- src/tl_templates/cuda/common.h
- testing/python/language/test_tilelang_language_sync_threads.py
- tilelang/contrib/cutedsl/reduce.py
- tilelang/contrib/cutedsl/utils.py
- tilelang/language/builtin.py
🚧 Files skipped from review as they are similar to previous changes (4)
- tilelang/contrib/cutedsl/reduce.py
- src/backend/cuda/codegen/codegen_cutedsl.cc
- tilelang/contrib/cutedsl/utils.py
- src/tl_templates/cuda/common.h
    if barrier_id is not None or not aligned:
        if barrier_id is None:
            barrier_id = 0
        args.append(barrier_id)
    if arrive_count is not None:
    if arrive_count is not None or not aligned:
        if arrive_count is None:
            arrive_count = 0
        args.append(arrive_count)
Disallow arrive_count without barrier_id in aligned mode.
T.sync_threads(arrive_count=...) is currently accepted and lowers with a single positional arg, which is ambiguous and can map to the wrong barrier id at codegen time.
Proposed fix
 def sync_threads(barrier_id: int = None, arrive_count: int = None, aligned: bool = True):
     """Synchronize all threads in a block."""
     args = []
     if not isinstance(aligned, bool):
         raise TypeError(f"Expect aligned to be bool, but got {type(aligned)}.")
+    if aligned and barrier_id is None and arrive_count is not None:
+        raise ValueError("T.sync_threads(arrive_count=...) requires barrier_id.")
     if barrier_id is not None or not aligned:
         if barrier_id is None:
             barrier_id = 0
         args.append(barrier_id)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tilelang/language/builtin.py` around lines 959 - 966, In the sync lowering
logic (the block that handles barrier_id, arrive_count and aligned — look for
variables barrier_id, arrive_count, aligned and the T.sync_threads path), add a
validation that if aligned is True and arrive_count is provided while barrier_id
is None, you raise a clear error (TypeError/ValueError) and do not lower; this
prevents interpreting a single positional arg as barrier_id. Update the
conditional so you only append arrive_count when barrier_id is present (or not
aligned) and ensure the error message references both arrive_count and
barrier_id to guide callers.
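To see why the guard matters, the argument-building logic can be isolated into a plain function. This is a minimal sketch assuming the conditionals quoted in the review; `build_sync_args` is a hypothetical helper, and it deliberately stops at the positional argument list because how `aligned` itself is passed to the intrinsic is not shown in this excerpt.

```python
def build_sync_args(barrier_id=None, arrive_count=None, aligned=True):
    """Hypothetical helper mirroring the barrier_id/arrive_count handling under review."""
    if not isinstance(aligned, bool):
        raise TypeError(f"Expect aligned to be bool, but got {type(aligned)}.")
    # Proposed guard: a lone arrive_count would otherwise be lowered as the only
    # positional argument and could be read as a barrier id at codegen time.
    if aligned and barrier_id is None and arrive_count is not None:
        raise ValueError("T.sync_threads(arrive_count=...) requires barrier_id.")
    args = []
    if barrier_id is not None or not aligned:
        args.append(0 if barrier_id is None else barrier_id)
    if arrive_count is not None or not aligned:
        args.append(0 if arrive_count is None else arrive_count)
    return args


assert build_sync_args() == []                      # plain block-wide sync
assert build_sync_args(1, 128) == [1, 128]          # named barrier, aligned
assert build_sync_args(aligned=False) == [0, 0]     # unaligned form fills defaults
assert build_sync_args(arrive_count=128, aligned=False) == [0, 128]

try:
    build_sync_args(arrive_count=128)               # ambiguous without barrier_id
except ValueError as err:
    assert "barrier_id" in str(err)
else:
    raise AssertionError("expected ValueError")
```

In the unaligned path both slots are always filled, so the ambiguity only exists for aligned calls, which is where the guard applies.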
Summary by CodeRabbit
New Features
- Added an `aligned` parameter to sync primitives (`sync_threads()` and `named_barrier_arrive()`), letting callers choose aligned vs unaligned barrier behavior (default: True). Generated kernels will reflect the chosen form.
Tests