Skip to content

Commit 3a8d28f

Browse files
vraspar and Copilot committed
Add tensor size validation for MatMulBnb4 to prevent OOB read via K/N attribute mismatch
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent aaa4944 commit 3a8d28f

2 files changed

Lines changed: 59 additions & 0 deletions

File tree

onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ class MatMulBnb4 final : public OpKernel {
1919
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
2020
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
2121
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
22+
ORT_ENFORCE(K_ > 0, "K must be positive, got ", K_);
23+
ORT_ENFORCE(N_ > 0, "N must be positive, got ", N_);
24+
ORT_ENFORCE(block_size_ > 0, "block_size must be positive, got ", block_size_);
2225
ORT_ENFORCE(
2326
quant_type_ == FP4 || quant_type_ == NF4,
2427
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
@@ -50,6 +53,24 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
5053
const uint8_t* b_quant_data = b_quant->Data<uint8_t>();
5154
const float* absmax_data = absmax->Data<float>();
5255

56+
const int64_t numel = K_ * N_;
57+
const int64_t expected_b_quant_size = (numel + 1) / 2;
58+
const int64_t expected_absmax_size = (numel + block_size_ - 1) / block_size_;
59+
60+
if (b_quant->Shape().Size() < expected_b_quant_size) {
61+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
62+
"b_quant tensor size (", b_quant->Shape().Size(),
63+
") is too small for K=", K_, " and N=", N_,
64+
". Expected at least ", expected_b_quant_size, " elements.");
65+
}
66+
if (absmax->Shape().Size() < expected_absmax_size) {
67+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
68+
"absmax tensor size (", absmax->Shape().Size(),
69+
") is too small for K=", K_, ", N=", N_,
70+
", block_size=", block_size_,
71+
". Expected at least ", expected_absmax_size, " elements.");
72+
}
73+
5374
AllocatorPtr allocator;
5475
auto status = ctx->GetTempSpaceAllocator(&allocator);
5576
ORT_RETURN_IF_ERROR(status);

onnxruntime/test/contrib_ops/matmul_bnb4_test.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,44 @@ void RunTest(int64_t quant_type, int64_t M, int64_t N, int64_t K, int64_t block_
115115
}
116116
}
117117

118+
// Verifies the kernel rejects a packed 4-bit weight tensor ("B") that is
// smaller than the size implied by the K/N attributes.
// With K=32 and N=2 the weight has numel = 64, so B must carry at least
// (64 + 1) / 2 = 32 packed bytes; we hand it only 4 bytes (which would be
// valid for K=4, N=2) and expect Compute to fail with INVALID_ARGUMENT.
TEST(MatMulBnb4, RejectsUndersizedBQuantTensor) {
  constexpr int64_t kK = 32;
  constexpr int64_t kN = 2;

  OpTester tester("MatMulBnb4", 1, kMSDomain);
  tester.AddAttribute<int64_t>("K", kK);
  tester.AddAttribute<int64_t>("N", kN);
  tester.AddAttribute<int64_t>("block_size", 32LL);
  tester.AddAttribute<int64_t>("quant_type", 1LL);  // NF4

  tester.AddInput<float>("A", {1, kK}, std::vector<float>(kK, 0.0f));
  // Deliberately undersized: 4 bytes instead of the 32 the attributes imply.
  tester.AddInput<uint8_t>("B", {4}, std::vector<uint8_t>(4, 0));
  tester.AddInput<float>("absmax", {2}, std::vector<float>(2, 1.0f));
  tester.AddOutput<float>("Y", {1, kN}, std::vector<float>(kN, 0.0f));

  tester.Run(OpTester::ExpectResult::kExpectFailure, "b_quant tensor size");
}
134+
135+
// Verifies the kernel rejects an "absmax" scale tensor that is smaller than
// the size implied by the K/N/block_size attributes.
// With K=32, N=2, block_size=32 the weight has numel = 64, so the kernel
// expects ceil(64 / 32) = 2 absmax entries; we supply a single element and
// expect Compute to fail with INVALID_ARGUMENT.
TEST(MatMulBnb4, RejectsUndersizedAbsmaxTensor) {
  constexpr int64_t kK = 32;
  constexpr int64_t kN = 2;
  constexpr int64_t kBlockSize = 32;
  constexpr int64_t kNumel = kK * kN;
  // Bytes in the 4-bit-packed weight tensor (two values per byte).
  constexpr int64_t kPackedBytes = (kNumel + 1) / 2;

  OpTester tester("MatMulBnb4", 1, kMSDomain);
  tester.AddAttribute<int64_t>("K", kK);
  tester.AddAttribute<int64_t>("N", kN);
  tester.AddAttribute<int64_t>("block_size", kBlockSize);
  tester.AddAttribute<int64_t>("quant_type", 1LL);  // NF4

  tester.AddInput<float>("A", {1, kK}, std::vector<float>(kK, 0.0f));
  tester.AddInput<uint8_t>("B", {kPackedBytes}, std::vector<uint8_t>(kPackedBytes, 0));
  // Deliberately undersized: one element instead of the two required.
  tester.AddInput<float>("absmax", {1}, std::vector<float>(1, 1.0f));
  tester.AddOutput<float>("Y", {1, kN}, std::vector<float>(kN, 0.0f));

  tester.Run(OpTester::ExpectResult::kExpectFailure, "absmax tensor size");
}
155+
118156
TEST(MatMulBnb4, DISABLED_Float32) {
119157
for (auto qt : {0, 1}) {
120158
for (auto M : {1, 2, 100}) {

0 commit comments

Comments (0)