
Commit 2a723b7

feat(compute): T2.2 capture-aware allocWeight routing via cudaMallocAsync
When CaptureAwareAllocator is active (set by BeginCapture/WithCapture), allocWeight routes through cudaMallocAsync on the capture stream so allocations are recorded as graph nodes. This avoids the silent hang caused by cudaMallocManaged during CUDA graph capture on GB10. Similarly, uploadBytes routes through cudaMemcpyAsync on the capture stream instead of the synchronous CPU copy used by the managed-memory path, which is illegal during capture.

The ensureNotCapturing guard now only fires when capture is active but the allocator was NOT properly switched via BeginCapture/WithCapture.

Changes:
- Add IsCapturing() to CaptureAwareAllocator interface
- Implement IsCapturing() on cuda.MemPool and gpuapi.CUDAMemPool
- Add async allocation/copy routing in allocWeight and uploadBytes
- Add function variable indirections for MallocManaged, MallocAsync, and MemcpyAsync to enable CPU-mock testing
- Add 7 unit tests covering all routing paths
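For readers skimming the diff below, the capture-aware contract referenced in this message is easiest to see as an interface. The following is a sketch reconstructed from the commit message and from the methods the test fake implements; the real declaration lives in the gpuapi package and is not shown in this commit, so names and details may differ.

// Sketch of the gpuapi.CaptureAwareAllocator contract, reconstructed from the
// test fakes below; the actual definition in gpuapi may differ.
type CaptureAwareAllocator interface {
	// SetCaptureStream puts the pool into capture mode on the given stream
	// (engaged by BeginCapture/WithCapture).
	SetCaptureStream(stream unsafe.Pointer)
	// ClearCaptureStream restores normal allocation behavior.
	ClearCaptureStream()
	// IsCapturing reports whether capture mode is active; allocWeight and
	// uploadBytes consult it to choose the async routing path.
	IsCapturing() bool
}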
1 parent 6efe00c commit 2a723b7

5 files changed

Lines changed: 379 additions & 6 deletions

File tree

compute/capture_alloc_test.go

Lines changed: 329 additions & 0 deletions
@@ -0,0 +1,329 @@
package compute

import (
	"errors"
	"sync/atomic"
	"testing"
	"unsafe"

	"github.com/zerfoo/ztensor/internal/cuda"
	"github.com/zerfoo/ztensor/internal/gpuapi"
)

// --- fake CaptureAwareAllocator pool for tests ---

type fakeCapturePool struct {
	capturing bool
}

func (p *fakeCapturePool) Alloc(int, int) (unsafe.Pointer, error) { return nil, nil }
func (p *fakeCapturePool) Free(int, unsafe.Pointer, int)          {}
func (p *fakeCapturePool) AllocManaged(int, int) (unsafe.Pointer, error) { return nil, nil }
func (p *fakeCapturePool) FreeManaged(int, unsafe.Pointer, int)          {}
func (p *fakeCapturePool) Drain() error                  { return nil }
func (p *fakeCapturePool) Stats() (int, int)             { return 0, 0 }
func (p *fakeCapturePool) SetCaptureStream(_ unsafe.Pointer) { p.capturing = true }
func (p *fakeCapturePool) ClearCaptureStream()               { p.capturing = false }
func (p *fakeCapturePool) IsCapturing() bool                 { return p.capturing }

var (
	_ gpuapi.MemPool               = (*fakeCapturePool)(nil)
	_ gpuapi.CaptureAwareAllocator = (*fakeCapturePool)(nil)
)

// --- fake non-capture-aware pool (like CUDAArenaPool) ---

type fakeBasicPool struct{}

func (p *fakeBasicPool) Alloc(int, int) (unsafe.Pointer, error) { return nil, nil }
func (p *fakeBasicPool) Free(int, unsafe.Pointer, int)          {}
func (p *fakeBasicPool) AllocManaged(int, int) (unsafe.Pointer, error) { return nil, nil }
func (p *fakeBasicPool) FreeManaged(int, unsafe.Pointer, int)          {}
func (p *fakeBasicPool) Drain() error      { return nil }
func (p *fakeBasicPool) Stats() (int, int) { return 0, 0 }

var _ gpuapi.MemPool = (*fakeBasicPool)(nil)

// --- test helpers ---

// swapMallocAsyncFn replaces the package-level mallocAsyncFn and returns
// a restore closure.
func swapMallocAsyncFn(fn func(int, *cuda.Stream) (unsafe.Pointer, error)) func() {
	prev := mallocAsyncFn
	mallocAsyncFn = fn
	return func() { mallocAsyncFn = prev }
}

// swapMallocManagedFn replaces the package-level mallocManagedFn and returns
// a restore closure.
func swapMallocManagedFn(fn func(int) (unsafe.Pointer, error)) func() {
	prev := mallocManagedFn
	mallocManagedFn = fn
	return func() { mallocManagedFn = prev }
}

// swapMemcpyAsyncFn replaces the package-level memcpyAsyncFn and returns
// a restore closure.
func swapMemcpyAsyncFn(fn func(unsafe.Pointer, unsafe.Pointer, int, cuda.MemcpyKind, *cuda.Stream) error) func() {
	prev := memcpyAsyncFn
	memcpyAsyncFn = fn
	return func() { memcpyAsyncFn = prev }
}

// --- allocWeight tests ---

// TestAllocWeight_UsesAsyncWhenCapturing verifies that allocWeight routes
// through cudaMallocAsync when CaptureAwareAllocator is active.
func TestAllocWeight_UsesAsyncWhenCapturing(t *testing.T) {
	var asyncCalled atomic.Bool
	var requestedSize int
	var sentinel byte

	restore := swapMallocAsyncFn(func(size int, _ *cuda.Stream) (unsafe.Pointer, error) {
		asyncCalled.Store(true)
		requestedSize = size
		return unsafe.Pointer(&sentinel), nil
	})
	defer restore()

	// Also stub captureStatusFn so ensureNotCapturing does not interfere.
	restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
		return cuda.CaptureStatusActive, nil
	})
	defer restoreStatus()

	pool := &fakeCapturePool{capturing: true}
	e := &GPUEngine[float32]{
		stream: fakePtrStream{},
		pool:   pool,
	}

	ptr, err := e.allocWeight(4096)
	if err != nil {
		t.Fatalf("allocWeight during capture: unexpected error: %v", err)
	}
	if !asyncCalled.Load() {
		t.Fatal("allocWeight during capture: expected cudaMallocAsync to be called")
	}
	if requestedSize != 4096 {
		t.Fatalf("allocWeight during capture: async alloc size = %d, want 4096", requestedSize)
	}
	if ptr != unsafe.Pointer(&sentinel) {
		t.Fatal("allocWeight during capture: returned pointer does not match async allocation")
	}
}

// TestAllocWeight_UsesManagedWhenNotCapturing verifies that allocWeight
// still uses cudaMallocManaged when capture is NOT active and managedMem
// is true.
func TestAllocWeight_UsesManagedWhenNotCapturing(t *testing.T) {
	var managedCalled atomic.Bool
	var sentinel byte

	restoreManaged := swapMallocManagedFn(func(size int) (unsafe.Pointer, error) {
		managedCalled.Store(true)
		return unsafe.Pointer(&sentinel), nil
	})
	defer restoreManaged()

	var asyncCalled atomic.Bool
	restoreAsync := swapMallocAsyncFn(func(_ int, _ *cuda.Stream) (unsafe.Pointer, error) {
		asyncCalled.Store(true)
		return nil, nil
	})
	defer restoreAsync()

	restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
		return cuda.CaptureStatusNone, nil
	})
	defer restoreStatus()

	pool := &fakeCapturePool{capturing: false}
	e := &GPUEngine[float32]{
		stream:     fakePtrStream{},
		pool:       pool,
		managedMem: true,
	}

	ptr, err := e.allocWeight(4096)
	if err != nil {
		t.Fatalf("allocWeight (not capturing, managed): unexpected error: %v", err)
	}
	if !managedCalled.Load() {
		t.Fatal("allocWeight (not capturing, managed): expected cudaMallocManaged to be called")
	}
	if asyncCalled.Load() {
		t.Fatal("allocWeight (not capturing, managed): cudaMallocAsync should NOT be called")
	}
	if ptr != unsafe.Pointer(&sentinel) {
		t.Fatal("allocWeight (not capturing, managed): returned pointer does not match managed allocation")
	}
}

// TestAllocWeight_GuardFiresWithoutCaptureAwareAllocator verifies that
// ensureNotCapturing still blocks allocWeight when capture is active
// but the pool does NOT implement CaptureAwareAllocator (e.g.,
// CUDAArenaPool). This is the "raw capture without BeginCapture" path.
func TestAllocWeight_GuardFiresWithoutCaptureAwareAllocator(t *testing.T) {
	restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
		return cuda.CaptureStatusActive, nil
	})
	defer restoreStatus()

	e := &GPUEngine[float32]{
		stream: fakePtrStream{},
		pool:   &fakeBasicPool{},
	}

	ptr, err := e.allocWeight(4096)
	if err == nil {
		t.Fatal("allocWeight with non-capture-aware pool during capture: expected error, got nil")
	}
	if !errors.Is(err, ErrCaptureIncompatibleAllocation) {
		t.Fatalf("allocWeight: expected ErrCaptureIncompatibleAllocation, got %v", err)
	}
	if ptr != nil {
		t.Fatalf("allocWeight: expected nil pointer on guard trip, got %p", ptr)
	}
}

// TestAllocWeight_GuardSkippedWhenCaptureAwareAllocatorActive verifies
// that ensureNotCapturing does NOT fire when CaptureAwareAllocator is
// properly engaged via BeginCapture/WithCapture.
func TestAllocWeight_GuardSkippedWhenCaptureAwareAllocatorActive(t *testing.T) {
	var ensureNotCapturingReached atomic.Bool
	restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
		ensureNotCapturingReached.Store(true)
		return cuda.CaptureStatusActive, nil
	})
	defer restoreStatus()

	restoreAsync := swapMallocAsyncFn(func(_ int, _ *cuda.Stream) (unsafe.Pointer, error) {
		var sentinel byte
		return unsafe.Pointer(&sentinel), nil
	})
	defer restoreAsync()

	pool := &fakeCapturePool{capturing: true}
	e := &GPUEngine[float32]{
		stream: fakePtrStream{},
		pool:   pool,
	}

	_, err := e.allocWeight(4096)
	if err != nil {
		t.Fatalf("allocWeight with capture-aware allocator active: unexpected error: %v", err)
	}
	if ensureNotCapturingReached.Load() {
		t.Fatal("ensureNotCapturing should NOT be called when CaptureAwareAllocator is active")
	}
}

// --- uploadBytes tests ---

// TestUploadBytes_UsesAsyncWhenCapturing verifies that uploadBytes routes
// through cudaMemcpyAsync when CaptureAwareAllocator is active.
func TestUploadBytes_UsesAsyncWhenCapturing(t *testing.T) {
	var asyncCalled atomic.Bool
	var copiedSize int
	var copiedKind cuda.MemcpyKind

	restoreMemcpy := swapMemcpyAsyncFn(func(_ unsafe.Pointer, _ unsafe.Pointer, count int, kind cuda.MemcpyKind, _ *cuda.Stream) error {
		asyncCalled.Store(true)
		copiedSize = count
		copiedKind = kind
		return nil
	})
	defer restoreMemcpy()

	restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
		return cuda.CaptureStatusActive, nil
	})
	defer restoreStatus()

	pool := &fakeCapturePool{capturing: true}
	e := &GPUEngine[float32]{
		stream: fakePtrStream{},
		pool:   pool,
	}

	src := []byte{0x01, 0x02, 0x03, 0x04}
	var devMem byte
	err := e.uploadBytes(unsafe.Pointer(&devMem), src)
	if err != nil {
		t.Fatalf("uploadBytes during capture: unexpected error: %v", err)
	}
	if !asyncCalled.Load() {
		t.Fatal("uploadBytes during capture: expected cudaMemcpyAsync to be called")
	}
	if copiedSize != 4 {
		t.Fatalf("uploadBytes during capture: copied size = %d, want 4", copiedSize)
	}
	if copiedKind != cuda.MemcpyHostToDevice {
		t.Fatalf("uploadBytes during capture: copy kind = %v, want MemcpyHostToDevice", copiedKind)
	}
}

// TestUploadBytes_UsesSyncWhenNotCapturing verifies that uploadBytes
// falls through to the normal (non-async) path when capture is NOT active.
func TestUploadBytes_UsesSyncWhenNotCapturing(t *testing.T) {
	var asyncCalled atomic.Bool
	restoreMemcpy := swapMemcpyAsyncFn(func(_ unsafe.Pointer, _ unsafe.Pointer, _ int, _ cuda.MemcpyKind, _ *cuda.Stream) error {
		asyncCalled.Store(true)
		return nil
	})
	defer restoreMemcpy()

	restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
		return cuda.CaptureStatusNone, nil
	})
	defer restoreStatus()

	pool := &fakeCapturePool{capturing: false}
	e := &GPUEngine[float32]{
		stream:     fakePtrStream{},
		pool:       pool,
		managedMem: true,
	}

	// With managedMem=true and not capturing, uploadBytes does a direct CPU copy.
	// We can't test the actual copy without a real managed pointer, but we can
	// verify cudaMemcpyAsync was NOT called.
	src := []byte{0x01, 0x02}
	buf := make([]byte, 2)
	err := e.uploadBytes(unsafe.Pointer(&buf[0]), src)
	if err != nil {
		t.Fatalf("uploadBytes (not capturing, managed): unexpected error: %v", err)
	}
	if asyncCalled.Load() {
		t.Fatal("uploadBytes (not capturing, managed): cudaMemcpyAsync should NOT be called")
	}
	// Verify the sync copy worked.
	if buf[0] != 0x01 || buf[1] != 0x02 {
		t.Fatalf("uploadBytes (not capturing, managed): sync copy produced %v, want [1 2]", buf)
	}
}

// TestUploadBytes_GuardFiresWithoutCaptureAwareAllocator verifies that
// ensureNotCapturing still blocks uploadBytes when capture is active
// but the pool does NOT implement CaptureAwareAllocator.
func TestUploadBytes_GuardFiresWithoutCaptureAwareAllocator(t *testing.T) {
	restoreStatus := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
		return cuda.CaptureStatusActive, nil
	})
	defer restoreStatus()

	e := &GPUEngine[float32]{
		stream: fakePtrStream{},
		pool:   &fakeBasicPool{},
	}

	src := []byte{0x01}
	err := e.uploadBytes(nil, src)
	if err == nil {
		t.Fatal("uploadBytes with non-capture-aware pool during capture: expected error, got nil")
	}
	if !errors.Is(err, ErrCaptureIncompatibleAllocation) {
		t.Fatalf("uploadBytes: expected ErrCaptureIncompatibleAllocation, got %v", err)
	}
}
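Because every CUDA entry point these tests exercise is swapped for a stub via the helpers above (swapMallocAsyncFn, swapMallocManagedFn, swapMemcpyAsyncFn, swapCaptureStatusFn), the file should run on a CPU-only machine, e.g. with go test ./compute -run 'TestAllocWeight|TestUploadBytes' (package path assumed from the file location in this diff).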

compute/gpu_engine.go

Lines changed: 35 additions & 6 deletions
@@ -588,6 +588,15 @@ var (
 	graphDestroyFn = cuda.GraphDestroy
 )
 
+// mallocManagedFn, mallocAsyncFn, and memcpyAsyncFn are indirection points
+// for the CUDA allocation and copy functions used by allocWeight and uploadBytes.
+// Tests swap them to verify capture-aware routing without real CUDA hardware.
+var (
+	mallocManagedFn = cuda.MallocManaged
+	mallocAsyncFn   = cuda.MallocAsync
+	memcpyAsyncFn   = cuda.MemcpyAsync
+)
+
 // ensureNotCapturing returns ErrCaptureIncompatibleAllocation if the
 // engine's stream is currently capturing a CUDA graph. On CPU-only
 // runtimes or when the stream handle is nil, returns nil (no capture
@@ -614,24 +623,44 @@ func (e *GPUEngine[T]) ensureNotCapturing() error {
 
 // allocWeight allocates permanent memory for a weight tensor.
 // Uses cudaMallocManaged on devices with managed memory support,
-// otherwise uses cudaMalloc. Returns ErrCaptureIncompatibleAllocation
-// if invoked while a CUDA graph capture is active on the engine's stream.
+// otherwise uses cudaMalloc.
+//
+// When CaptureAwareAllocator is active (set by BeginCapture/WithCapture),
+// allocations route through cudaMallocAsync on the capture stream so they
+// are recorded as graph nodes. This avoids the silent hang caused by
+// cudaMallocManaged during CUDA graph capture on GB10.
+//
+// Returns ErrCaptureIncompatibleAllocation only if capture is active but
+// the allocator was NOT properly switched via BeginCapture/WithCapture.
 func (e *GPUEngine[T]) allocWeight(byteSize int) (unsafe.Pointer, error) {
+	if cap, ok := e.pool.(gpuapi.CaptureAwareAllocator); ok && cap.IsCapturing() {
+		s := cuda.StreamFromPtr(e.Stream())
+		return mallocAsyncFn(byteSize, s)
+	}
 	if err := e.ensureNotCapturing(); err != nil {
 		return nil, err
 	}
 	if e.managedMem {
-		return cuda.MallocManaged(byteSize)
+		return mallocManagedFn(byteSize)
 	}
 	return e.runtime.Malloc(byteSize)
 }
 
 // uploadBytes copies src bytes into a device (or managed) pointer.
 // With managed memory, this is a direct CPU memcpy (no H2D needed).
-// Without managed memory, this uses cudaMemcpy H2D. Returns
-// ErrCaptureIncompatibleAllocation if invoked while a CUDA graph capture
-// is active on the engine's stream.
+// Without managed memory, this uses cudaMemcpy H2D.
+//
+// When CaptureAwareAllocator is active, uses cudaMemcpyAsync on the
+// capture stream so the copy is recorded as a graph node. The synchronous
+// CPU copy used by the managed-memory path is illegal during capture.
+//
+// Returns ErrCaptureIncompatibleAllocation only if capture is active but
+// the allocator was NOT properly switched via BeginCapture/WithCapture.
 func (e *GPUEngine[T]) uploadBytes(devPtr unsafe.Pointer, src []byte) error {
+	if cap, ok := e.pool.(gpuapi.CaptureAwareAllocator); ok && cap.IsCapturing() {
+		s := cuda.StreamFromPtr(e.Stream())
+		return memcpyAsyncFn(devPtr, unsafe.Pointer(&src[0]), len(src), cuda.MemcpyHostToDevice, s)
+	}
 	if err := e.ensureNotCapturing(); err != nil {
 		return err
 	}
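The routing check added above only takes effect once something flips IsCapturing() on the pool. Per the commit message that switch happens in BeginCapture/WithCapture, which are not part of this diff; the sketch below is a hypothetical illustration of that wiring, assuming only the CaptureAwareAllocator methods exercised by the test fakes, and is not the repository's actual BeginCapture implementation.

// Hypothetical wiring sketch (not code from this commit): arm a capture-aware
// pool so that allocWeight and uploadBytes take the cudaMallocAsync /
// cudaMemcpyAsync branch while a CUDA graph capture is being recorded.
func (e *GPUEngine[T]) beginCaptureSketch() {
	if p, ok := e.pool.(gpuapi.CaptureAwareAllocator); ok {
		// After this call p.IsCapturing() reports true, so weight allocations
		// and uploads are recorded as graph nodes on the capture stream.
		p.SetCaptureStream(e.Stream())
	}
	// ... begin CUDA stream capture on e.Stream() here ...
	// On EndCapture / WithCapture exit, p.ClearCaptureStream() would restore
	// the normal cudaMallocManaged / cudaMalloc routing.
}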
