
Commit 9f9eb5c

feat(compute): T2.3 pre-allocate workspace buffers at UploadWeights to avoid capture-time alloc
Add preAllocateWorkspaces(), which eagerly initializes the FP8 scratchpad (scaleOne pointer + struct) and the cuBLASLt handle at the end of UploadWeights, before any CUDA graph capture region begins. These two objects previously used lazy initialization (getFP8Scratch, getLtHandle), which triggered cudaMalloc on first use -- hanging silently on GB10 when that first use happened inside capture. Also add a captureAllocCount atomic counter that tracks allocWeight attempts during active capture; EndCapture resets the counter and logs a warning if it is non-zero, and CaptureAllocCount() exposes the counter for testing.
1 parent 2a723b7 commit 9f9eb5c
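The hazard being fixed is the classic lazy-init-under-capture pattern. Below is a minimal sketch of the lazy pattern and the eager fix; the package, engine type, and cudaMalloc stub are hypothetical illustration names, not the engine's actual code.

package sketch

import "unsafe"

// cudaMalloc stands in for the real device allocator (stub for the sketch).
var cudaMalloc func(n int) (unsafe.Pointer, error)

type engine struct {
	workspace unsafe.Pointer // device scratch buffer, nil until first use
}

// getWorkspace is the lazy pattern: the first caller pays for the
// allocation. If that first call lands inside a CUDA graph capture
// region, cudaMalloc is illegal during capture and, per the commit
// message, hung silently on GB10.
func (e *engine) getWorkspace() (unsafe.Pointer, error) {
	if e.workspace == nil {
		p, err := cudaMalloc(1 << 20)
		if err != nil {
			return nil, err
		}
		e.workspace = p
	}
	return e.workspace, nil
}

// preAllocate is the eager fix: run the same initialization once, before
// any capture begins, so capture-time callers only read a populated pointer.
func (e *engine) preAllocate() error {
	_, err := e.getWorkspace()
	return err
}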

2 files changed: 262 additions & 0 deletions


compute/gpu_engine.go

Lines changed: 57 additions & 0 deletions
@@ -87,6 +87,12 @@ type GPUEngine[T tensor.Numeric] struct {
 	// when cuBLAS receives very large matrices (e.g., 128256x4096 LM head).
 	// Default: DefaultMaxAllocBytes (4 GB).
 	maxAllocBytes int64
+
+	// captureAllocCount tracks allocWeight calls that occur during an active
+	// CUDA graph capture. A properly pre-allocated workload should see zero.
+	// Incremented atomically in allocWeight when capture is detected;
+	// checked and reset in EndCapture.
+	captureAllocCount atomic.Int64
 }
 
 // NewGPUEngine creates a new GPUEngine backed by CUDA via the GRAL abstraction.
@@ -570,6 +576,10 @@ func (e *GPUEngine[T]) UploadWeights(tensors []*tensor.TensorNumeric[float32]) e
 			"device", fmt.Sprintf("%d", e.deviceID),
 			"method", method)
 	}
+	// Pre-allocate all workspace buffers that would otherwise be lazily
+	// initialized on first use. This ensures no cudaMalloc occurs inside
+	// a subsequent CUDA graph capture region.
+	e.preAllocateWorkspaces()
 	return nil
 }
 
@@ -638,6 +648,7 @@ func (e *GPUEngine[T]) allocWeight(byteSize int) (unsafe.Pointer, error) {
 		return mallocAsyncFn(byteSize, s)
 	}
 	if err := e.ensureNotCapturing(); err != nil {
+		e.captureAllocCount.Add(1)
 		return nil, err
 	}
 	if e.managedMem {
@@ -714,6 +725,10 @@ func (e *GPUEngine[T]) EndCapture() (GraphHandle, error) {
 	if cap, ok := e.pool.(gpuapi.CaptureAwareAllocator); ok {
 		defer cap.ClearCaptureStream()
 	}
+	if n := e.captureAllocCount.Swap(0); n > 0 {
+		e.logger.Warn("allocWeight called during capture",
+			"count", fmt.Sprintf("%d", n))
+	}
 	s := cuda.StreamFromPtr(e.Stream())
 	graph, err := streamEndCaptureFn(s)
 	if err != nil {
@@ -819,6 +834,48 @@ func (e *GPUEngine[T]) Close() error {
 	return firstErr
 }
 
+// CaptureAllocCount returns the cumulative number of allocWeight calls that
+// were attempted while a CUDA graph capture was active. A properly
+// pre-allocated workload should observe zero after EndCapture.
+func (e *GPUEngine[T]) CaptureAllocCount() int64 {
+	return e.captureAllocCount.Load()
+}
+
+// preAllocateWorkspaces eagerly initializes all lazy-allocated workspace
+// buffers so that no cudaMalloc occurs inside a CUDA graph capture region.
+// Called at the end of UploadWeights, after all weight tensors are on GPU.
+//
+// For dense float32 workloads, pool.Alloc (arena-backed) is capture-safe via
+// CaptureAwareAllocator, but objects allocated outside the arena — the FP8
+// scratchpad and the cuBLASLt handle — use cudaMalloc and would hang if first
+// touched during capture on GB10.
+func (e *GPUEngine[T]) preAllocateWorkspaces() {
+	// 1. FP8 scratchpad: allocate scaleOne and the struct itself so that the
+	// first FP8 MatMul during capture does not trigger cudaMalloc.
+	if e.fp8Scratch == nil {
+		if s, err := e.getFP8Scratch(); err != nil {
+			e.logger.Warn("preAllocateWorkspaces: FP8 scratchpad init failed",
+				"error", err.Error())
+		} else {
+			_ = s // assigned to e.fp8Scratch inside getFP8Scratch
+		}
+	}
+
+	// 2. cuBLASLt handle: cublasLtCreate allocates internal CUDA state.
+	if e.ltHandle == nil {
+		if h, err := e.getLtHandle(); err != nil {
+			e.logger.Warn("preAllocateWorkspaces: cuBLASLt handle init failed",
+				"error", err.Error())
+		} else {
+			_ = h // assigned to e.ltHandle inside getLtHandle
+		}
+	}
+
+	e.logger.Info("workspace buffers pre-allocated",
+		"fp8Scratch", fmt.Sprintf("%v", e.fp8Scratch != nil),
+		"ltHandle", fmt.Sprintf("%v", e.ltHandle != nil))
+}
+
 // OOMFallbackCount returns the number of times GPU OOM triggered CPU fallback.
 func (e *GPUEngine[T]) OOMFallbackCount() int64 {
 	return e.oomFallbackCount.Load()
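For reference, a minimal usage sketch of the intended call order after this change, assuming it lives in package compute with fmt imported. BeginCapture and runDecodeStep are assumed names for illustration; only UploadWeights, EndCapture, and CaptureAllocCount appear in this diff. The counter is read before EndCapture, since EndCapture swaps it back to zero.

// buildGraph sketches the pre-allocate-then-capture sequence.
func buildGraph[T tensor.Numeric](e *GPUEngine[T], ws []*tensor.TensorNumeric[float32]) error {
	// UploadWeights now ends with preAllocateWorkspaces, so the FP8
	// scratchpad and cuBLASLt handle already exist before capture.
	if err := e.UploadWeights(ws); err != nil {
		return err
	}
	if err := e.BeginCapture(); err != nil { // hypothetical counterpart to EndCapture
		return err
	}
	runDecodeStep(e) // enqueue the work to be captured (assumed helper)

	// Inspect the counter before EndCapture resets it: a properly
	// pre-allocated workload should read zero here.
	if n := e.CaptureAllocCount(); n != 0 {
		return fmt.Errorf("%d capture-time allocation attempts", n)
	}
	_, err := e.EndCapture()
	return err
}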

compute/workspace_prealloc_test.go

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
+package compute
+
+import (
+	"errors"
+	"testing"
+
+	"github.com/zerfoo/ztensor/internal/cuda"
+	"github.com/zerfoo/ztensor/log"
+	"github.com/zerfoo/ztensor/numeric"
+	"github.com/zerfoo/ztensor/tensor"
+)
+
+// TestPreAllocateWorkspaces_FP8ScratchInitialized verifies that after
+// UploadWeights, the FP8 scratchpad is non-nil (eagerly initialized).
+func TestPreAllocateWorkspaces_FP8ScratchInitialized(t *testing.T) {
+	eng := newPreallocEngine(t)
+	if eng.fp8Scratch != nil {
+		t.Fatal("precondition: fp8Scratch should be nil before UploadWeights")
+	}
+
+	if err := eng.UploadWeights(nil); err != nil {
+		t.Fatalf("UploadWeights: %v", err)
+	}
+
+	if eng.fp8Scratch == nil {
+		t.Fatal("fp8Scratch should be non-nil after UploadWeights")
+	}
+	if eng.fp8Scratch.scaleOne == nil {
+		t.Fatal("fp8Scratch.scaleOne should be non-nil after pre-allocation")
+	}
+}
+
+// TestPreAllocateWorkspaces_CalledByUploadWeights verifies that
+// preAllocateWorkspaces fires at the end of UploadWeights even when
+// called with an empty weight list (the pre-allocation is unconditional).
+func TestPreAllocateWorkspaces_CalledByUploadWeights(t *testing.T) {
+	eng := newPreallocEngine(t)
+
+	if err := eng.UploadWeights([]*tensor.TensorNumeric[float32]{}); err != nil {
+		t.Fatalf("UploadWeights: %v", err)
+	}
+
+	if eng.fp8Scratch == nil {
+		t.Fatal("fp8Scratch should be non-nil after UploadWeights")
+	}
+	if eng.fp8Scratch.scaleOne == nil {
+		t.Fatal("fp8Scratch.scaleOne should be non-nil after pre-allocation")
+	}
+}
+
+// TestPreAllocateWorkspaces_TableDriven exercises workspace pre-allocation
+// with varying weight list sizes. Pre-allocation is unconditional, so
+// fp8Scratch should be non-nil regardless of weight count.
+func TestPreAllocateWorkspaces_TableDriven(t *testing.T) {
+	tests := []struct {
+		name       string
+		numWeights int
+	}{
+		{name: "no weights", numWeights: 0},
+		{name: "one nil entry", numWeights: 1},
+		{name: "three nil entries", numWeights: 3},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			eng := newPreallocEngine(t)
+			pool := eng.pool.(*fakeMemPool)
+
+			// Pass nil tensor entries -- UploadWeights skips them.
+			weights := make([]*tensor.TensorNumeric[float32], tt.numWeights)
+			if err := eng.UploadWeights(weights); err != nil {
+				t.Fatalf("UploadWeights: %v", err)
+			}
+
+			// Fatal here: the scaleOne check below would otherwise
+			// dereference a nil fp8Scratch.
+			if eng.fp8Scratch == nil {
+				t.Fatal("fp8Scratch should be non-nil after UploadWeights")
+			}
+			if eng.fp8Scratch.scaleOne == nil {
+				t.Error("fp8Scratch.scaleOne should be non-nil")
+			}
+			// scaleOne alloc is the minimum: 1 pool.Alloc from getFP8Scratch.
+			if pool.allocCount < 1 {
+				t.Errorf("expected at least 1 alloc from pre-allocation, got %d", pool.allocCount)
+			}
+		})
+	}
+}
+
+// TestCaptureAllocCount_ZeroAfterPrealloc verifies that captureAllocCount
+// stays at zero when allocWeight is not called during capture. This is the
+// expected state for a properly pre-allocated workload.
+func TestCaptureAllocCount_ZeroAfterPrealloc(t *testing.T) {
+	eng := newPreallocEngine(t)
+	if err := eng.UploadWeights(nil); err != nil {
+		t.Fatalf("UploadWeights: %v", err)
+	}
+
+	if got := eng.CaptureAllocCount(); got != 0 {
+		t.Fatalf("CaptureAllocCount after UploadWeights: got %d, want 0", got)
+	}
+}
+
+// TestCaptureAllocCount_IncrementsOnCaptureTimeAlloc verifies that
+// allocWeight increments captureAllocCount when capture is active.
+func TestCaptureAllocCount_IncrementsOnCaptureTimeAlloc(t *testing.T) {
+	restore := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
+		return cuda.CaptureStatusActive, nil
+	})
+	defer restore()
+
+	eng := &GPUEngine[float32]{stream: fakePtrStream{}}
+
+	// First attempt — should fail with capture sentinel and increment counter.
+	_, err := eng.allocWeight(4096)
+	if !errors.Is(err, ErrCaptureIncompatibleAllocation) {
+		t.Fatalf("allocWeight: expected ErrCaptureIncompatibleAllocation, got %v", err)
+	}
+
+	if got := eng.CaptureAllocCount(); got != 1 {
+		t.Fatalf("CaptureAllocCount after 1 attempt: got %d, want 1", got)
+	}
+
+	// Second attempt — count should increase.
+	_, _ = eng.allocWeight(8192)
+	if got := eng.CaptureAllocCount(); got != 2 {
+		t.Fatalf("CaptureAllocCount after 2 attempts: got %d, want 2", got)
+	}
+}
+
+// TestCaptureAllocCount_ResetByEndCapture verifies that EndCapture resets
+// the captureAllocCount to zero after logging.
+func TestCaptureAllocCount_ResetByEndCapture(t *testing.T) {
+	// Arrange: inject a capture-active status for allocWeight, then swap to
+	// a non-capture status for EndCapture.
+	captureActive := true
+	restore := swapCaptureStatusFn(func(_ *cuda.Stream) (cuda.CaptureStatus, error) {
+		if captureActive {
+			return cuda.CaptureStatusActive, nil
+		}
+		return cuda.CaptureStatusNone, nil
+	})
+	defer restore()
+
+	eng := &GPUEngine[float32]{
+		stream: fakePtrStream{},
+		logger: log.Nop(),
+	}
+
+	// Trigger two allocWeight attempts during capture.
+	_, _ = eng.allocWeight(4096)
+	_, _ = eng.allocWeight(8192)
+	if got := eng.CaptureAllocCount(); got != 2 {
+		t.Fatalf("CaptureAllocCount before EndCapture: got %d, want 2", got)
+	}
+
+	// EndCapture will fail (no real graph) but should still reset the counter.
+	captureActive = false
+	oldEnd := streamEndCaptureFn
+	streamEndCaptureFn = func(_ *cuda.Stream) (*cuda.Graph, error) {
+		return nil, errors.New("synthetic: no graph")
+	}
+	defer func() { streamEndCaptureFn = oldEnd }()
+
+	_, _ = eng.EndCapture()
+
+	if got := eng.CaptureAllocCount(); got != 0 {
+		t.Fatalf("CaptureAllocCount after EndCapture: got %d, want 0", got)
+	}
+}
+
+// TestPreAllocateWorkspaces_Idempotent verifies that calling
+// preAllocateWorkspaces multiple times does not leak or double-allocate.
+func TestPreAllocateWorkspaces_Idempotent(t *testing.T) {
+	eng := newPreallocEngine(t)
+	pool := eng.pool.(*fakeMemPool)
+
+	eng.preAllocateWorkspaces()
+	allocsAfterFirst := pool.allocCount
+
+	eng.preAllocateWorkspaces()
+	allocsAfterSecond := pool.allocCount
+
+	if allocsAfterSecond != allocsAfterFirst {
+		t.Fatalf("second preAllocateWorkspaces caused %d new allocs, want 0",
+			allocsAfterSecond-allocsAfterFirst)
+	}
+}
+
+// newPreallocEngine builds a GPUEngine suitable for testing workspace
+// pre-allocation without real CUDA hardware.
+func newPreallocEngine(t *testing.T) *GPUEngine[float32] {
+	t.Helper()
+	pool := newFakeMemPool()
+	return &GPUEngine[float32]{
+		cpu:           NewCPUEngine[float32](numeric.Float32Ops{}),
+		runtime:       fakeRuntime{},
+		pool:          pool,
+		stream:        fakeStream{},
+		logger:        log.Nop(),
+		deviceID:      0,
+		dtype:         DTypeF32,
+		maxAllocBytes: DefaultMaxAllocBytes,
+	}
+}
