[CK_TILE] Sparge attention#3727

Draft
gino-lu wants to merge 18 commits into develop from ginolu/sparge_attention

gino-lu commented Apr 15, 2026

Proposed changes

Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

gino-lu and others added 18 commits March 19, 2026 23:28
- Add sparge_tool.hpp: host-side Sparge block-map builder (mean-sim
  scoring, CDF/topk selection) and VSA delta-LUT converter.
- Add test_sparge_jenga_sparse_attn.cpp and
  test_sparge_vsa_sparse_attn.cpp as end-to-end demos.
- Update CMakeLists.txt to register both new executables.

Note: block size is currently fixed at 128; flexible block size
support is not yet addressed.
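
The two ingredients named in this commit (mean-pooled scoring, then CDF or top-k selection) can be illustrated with a small host-side sketch. This is not the code in sparge_tool.hpp. It assumes each Q-block and K-block is mean-pooled to one vector, pooled blocks are scored with a scaled dot product and a softmax, and K-blocks are kept either until their cumulative probability reaches a threshold (CDF) or up to a fixed count (top-k). All function names and the `cdf_threshold` parameter are hypothetical, and any intra-block similarity test the real builder may apply is omitted.

```python
import math

def mean_pool(rows, block):
    """Average consecutive groups of `block` row vectors into one vector each."""
    pooled = []
    for start in range(0, len(rows), block):
        chunk = rows[start:start + block]
        dim = len(chunk[0])
        pooled.append([sum(r[d] for r in chunk) / len(chunk) for d in range(dim)])
    return pooled

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def select_cdf(probs, cdf_threshold):
    """Keep the most probable blocks until their summed mass reaches the threshold."""
    order = sorted(range(len(probs)), key=lambda j: probs[j], reverse=True)
    keep, mass = [False] * len(probs), 0.0
    for j in order:
        keep[j] = True
        mass += probs[j]
        if mass >= cdf_threshold:
            break
    return keep

def select_topk(probs, k):
    """Keep the k most probable blocks."""
    order = sorted(range(len(probs)), key=lambda j: probs[j], reverse=True)
    keep = [False] * len(probs)
    for j in order[:k]:
        keep[j] = True
    return keep

def build_block_map(q, k, blkq, blkk, cdf_threshold):
    """Return one boolean row per Q-block marking which K-blocks to attend to."""
    pooled_q, pooled_k = mean_pool(q, blkq), mean_pool(k, blkk)
    scale = 1.0 / math.sqrt(len(q[0]))  # assumed softmax scale, 1/sqrt(head_dim)
    block_map = []
    for qv in pooled_q:
        scores = [scale * sum(a * b for a, b in zip(qv, kv)) for kv in pooled_k]
        block_map.append(select_cdf(softmax(scores), cdf_threshold))
    return block_map
```

With probabilities `[0.1, 0.6, 0.3]`, a CDF threshold of 0.8 keeps the second and third blocks, while top-1 keeps only the second.
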
Add bm0 field to fmha_jenga_fwd_traits so callers can specify the
preferred Q-tile size. Codegen now emits separate tile configs for
bm0=64 (sparge blockmap) and bm0=128 (original), with CppConstraint
guards to select the right kernel at runtime.

End-to-end test passes for both jenga and vsa paths. Performance is
known to be suboptimal at this stage; tile sizes and warp counts for
the bm0=64 path have not been tuned.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

- Split SpargeKStatsKernel/Pipeline out of BlockMap (Kernel A produces
  per-block K stats workspace consumed by Kernel B), removing redundant
  K-stat recomputation across Q-blocks.
- Add example/ck_tile/50_sparse_attn/README.md (status vs upstream pinned
  to ae5b629, unported items, usage, references).
- Add example/ck_tile/50_sparse_attn/docs/{speedup_vs_sparsity,kernel_breakdown}.png
  + reusable plot_sparge_perf.py (b=2 h=32 s=16384 d=128 fp16 perf snapshot).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
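
To make the saving from the Kernel A / Kernel B split concrete, here is a rough count of how often a K-block statistic gets computed per (batch, head). It assumes the fused kernel recomputed every K-block statistic for every Q-block, which is how the commit message describes the redundancy. The block sizes (64 for Q, 128 for K) and the sequence length are taken from other commits in this PR and may not match the configuration this commit was measured with. The function name is hypothetical.

```python
def k_stat_evaluations(num_q_blocks, num_k_blocks, split_kernels):
    """Count K-block statistic computations for one (batch, head) pair."""
    if split_kernels:
        # Kernel A computes each K-block statistic once and stores it in the
        # workspace; Kernel B only reads it.
        return num_k_blocks
    # Fused kernel: every Q-block recomputes the statistics of every K-block.
    return num_q_blocks * num_k_blocks

seq_len, blkq, blkk = 16384, 64, 128          # assumed shapes
num_q_blocks = seq_len // blkq                # 256
num_k_blocks = seq_len // blkk                # 128

fused = k_stat_evaluations(num_q_blocks, num_k_blocks, split_kernels=False)
split = k_stat_evaluations(num_q_blocks, num_k_blocks, split_kernels=True)
print(fused, split, fused // split)           # 32768 128 256
```

Under these assumptions the split removes a factor equal to the number of Q-blocks from the K-stat work, at the cost of one workspace round trip.
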
Add s_barrier after sched_barrier when K-tail and V share LDS buffer,
mirroring upstream PR #4742. Applies to both async_vsa and async_jenga pipelines.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>

Replace process-lifetime lazy hipMalloc K-stats workspace with a caller-owned
buffer; expose sparge_blockmap_get_workspace_size() / compute_workspace_layout()
host helpers. Split the combined sparge_blockmap_fwd into stage launchers
(sparge_kstats_fwd_oneshot + sparge_blockmap_only_fwd_oneshot) so the chained
launch is timed end-to-end.

Make pooled_k storage dtype follow KDataType (fp16/bf16) instead of fp32 to halve
workspace footprint and match dense-FMHA precision. Tighten per-head superparam
pointers to required (non-null) and assert N_k <= 256 in jenga MakeKargs to
document the 256-bool LDS staging cap. Drop the obsolete VSA extra-LDS staging.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
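
The footprint claim above can be checked with simple arithmetic. The sketch below is not the implementation of sparge_blockmap_get_workspace_size() or compute_workspace_layout(), whose signatures are not shown here. It assumes the pooled-K buffer holds one head_dim-sized vector per K-block, that sections are packed back to back at a 256-byte alignment, and it uses the b=2 h=32 s=16384 d=128 shape from the perf snapshot with a K-block size of 128. All names and the alignment value are hypothetical.

```python
def align_up(n, alignment):
    """Round n up to the next multiple of alignment."""
    return (n + alignment - 1) // alignment * alignment

def pooled_k_bytes(batch, heads, seqlen_k, head_dim, blkk, elem_bytes):
    """Bytes needed for one pooled vector per K-block (assumed layout)."""
    num_k_blocks = (seqlen_k + blkk - 1) // blkk
    return batch * heads * num_k_blocks * head_dim * elem_bytes

def workspace_layout(sections, alignment=256):
    """sections: list of (name, size_in_bytes). Returns ({name: offset}, total)."""
    offsets, cursor = {}, 0
    for name, size in sections:
        cursor = align_up(cursor, alignment)  # start each section aligned
        offsets[name] = cursor
        cursor += size
    return offsets, align_up(cursor, alignment)

fp32_bytes = pooled_k_bytes(2, 32, 16384, 128, 128, elem_bytes=4)
fp16_bytes = pooled_k_bytes(2, 32, 16384, 128, 128, elem_bytes=2)
print(fp32_bytes, fp16_bytes)  # 4194304 2097152

# A caller-owned buffer would be allocated once with `total` bytes and the
# offsets passed to both stage launchers (hypothetical usage).
offsets, total = workspace_layout([("pooled_k", fp16_bytes), ("other_stats", 8192)])
```

For this shape the pooled-K section drops from 4 MiB to 2 MiB when stored as fp16 or bf16 instead of fp32.
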
Strip internal R-tag / phase labels (R20, R21A/B, Round 8/13f, Track F, B2.v3,
Phase 1/2/3) from comments — replace with descriptive names so future readers
don't need the change-log. Reflow long signature in fmha_fwd_trek.hpp.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>

Wire SpargeAttn CPU reference into test_sparge: build the block_map on host via
sparge::build_block_map_meansim and cross-check against the GPU-produced map;
self-check the VSA delta-LUT (valid count + reachable kb indices); split PASS/FAIL
into separate block_map / LUT / attention-output lines for clearer diagnosis.

Set sparge_tool::SpargeParams::BLKQ default to 64 to match SpargeAttn SM90
convention (cite upstream qk_int_sv_f8_cuda_sm90.cu:143-144); tighten bf16
tolerance back to the dense FMHA baseline (4e-2 atol, 1e-2 rtol).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
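
The delta-LUT self-check described above (valid count plus reachable kb indices) can be sketched as follows. The encoding is an assumption, not the format used by the VSA kernel: the first entry is the index of the first selected K-block and each later entry is the step from the previous selected block. Function names are hypothetical.

```python
def block_row_to_delta_lut(row):
    """Encode one boolean block-map row as (deltas, valid_count)."""
    selected = [j for j, keep in enumerate(row) if keep]
    deltas, prev = [], 0
    for j in selected:
        deltas.append(j - prev)  # first entry is an absolute index (prev = 0)
        prev = j
    return deltas, len(selected)

def decode_delta_lut(deltas):
    """Recover the K-block indices by accumulating the deltas."""
    indices, kb = [], 0
    for step in deltas:
        kb += step
        indices.append(kb)
    return indices

def check_delta_lut(deltas, valid_count, num_k_blocks):
    """Valid count must match, and every decoded index must be in range
    and strictly increasing."""
    if valid_count != len(deltas):
        return False
    kb = 0
    for i, step in enumerate(deltas):
        if step < 0 or (i > 0 and step == 0):
            return False
        kb += step
        if kb >= num_k_blocks:
            return False
    return True

row = [False, True, True, False, True]
deltas, valid = block_row_to_delta_lut(row)
print(deltas, valid)  # [1, 1, 2] 3
```

Decoding the LUT and rebuilding the boolean row gives a round-trip check against the block map, which is the kind of cross-check the split PASS/FAIL lines make easy to localize.
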
…0 instantiation

Preserve the R25 Step 1 "A1 / redesign D" state before redesigning toward "B"
(per-CTA PV-skip matching upstream shipped reference). This snapshot lets us
restore A1 if the B redesign fails.

A1 redesign D pipeline (per-warp, arithmetic-only PV-skip, wrapped in
`if constexpr (kEnablePVSkip)`):
  - include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_sparge.hpp
  - include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_sparge_kernel.hpp

V0 instantiation wiring (per gino_tmp/R25/programmer/v0_instance/REPORT.md):
  - example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_sparge.py
  - example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp
  - example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp
  - example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp
  - example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py
  - example/ck_tile/50_sparse_attn/CMakeLists.txt
  - example/ck_tile/01_fmha/CMakeLists.txt
  - example/ck_tile/50_sparse_attn/test_sparge.cpp (-pv_skip_compile=0|1 CLI)

This commit excludes all *_REVIEW.{hpp,cpp} mirror files (left untracked) and
all build artefacts. _vsa.hpp / _jenga.hpp are not modified.

Tag: R25-step1-A1-paper-aligned points at this commit.

- Per-head pv_threshold via head_remap LUT (CLI: -pv_threshold_per_head);
  sentinel 1e30 routes to kEnablePVSkip=false bucket
- kEnablePVSkip bool → PVSkipMode enum {kNone, kPerWarp, kPerBlock};
  new kPerBlock matches upstream sm80 (LDS vote, V loads unconditional).
  CLI: -pv_mode={none,warp,block}, default warp
- README: PV-skip modes section + MI300X 3-curve sparsity chart

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
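
The difference between the three PV-skip modes can be sketched on the host. The skip criterion here is an assumption: a warp votes to skip the P·V product for a K-tile when, for every row it owns, the tile's local score maximum sits more than `pv_threshold` below the running maximum, so the tile's softmax weight is negligible. The Python enum mirrors the `PVSkipMode` values named in the commit, and the sentinel handling follows the commit's description; everything else (function names, data layout) is hypothetical.

```python
from enum import Enum

NO_SKIP_SENTINEL = 1e30  # per-head threshold value that disables PV-skip

class PVSkipMode(Enum):
    NONE = "none"
    PER_WARP = "warp"
    PER_BLOCK = "block"

def warp_votes_skip(local_max, running_max, pv_threshold):
    """True if every row's tile max is more than pv_threshold below its running max."""
    return all(run - loc > pv_threshold for loc, run in zip(local_max, running_max))

def pv_skip_decisions(mode, warps, pv_threshold):
    """warps: list of (local_max_rows, running_max_rows). Returns one bool per warp."""
    if mode is PVSkipMode.NONE or pv_threshold >= NO_SKIP_SENTINEL:
        return [False] * len(warps)
    votes = [warp_votes_skip(loc, run, pv_threshold) for loc, run in warps]
    if mode is PVSkipMode.PER_WARP:
        return votes                      # each warp decides on its own
    return [all(votes)] * len(warps)      # per-block: skip only on a unanimous vote

# Two warps with one row each: the first tile is negligible, the second is not.
warps = [([0.0], [30.0]), ([0.0], [5.0])]
print(pv_skip_decisions(PVSkipMode.PER_WARP, warps, 20.0))   # [True, False]
print(pv_skip_decisions(PVSkipMode.PER_BLOCK, warps, 20.0))  # [False, False]
```

Per-warp skipping can skip more work, while the per-block vote keeps all warps of a block on the same control path, which is the trade-off the two modes expose through `-pv_mode`.
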