@@ -115,6 +115,44 @@ void RunTest(int64_t quant_type, int64_t M, int64_t N, int64_t K, int64_t block_
115115 }
116116}
117117
118+ TEST (MatMulBnb4, RejectsUndersizedBQuantTensor) {
119+ // K=32, N=2 → numel=64, expected b_quant size = (64+1)/2 = 32
120+ // Provide only 4 bytes (valid for K=4, N=2) but claim K=32, N=2
121+ OpTester test (" MatMulBnb4" , 1 , kMSDomain );
122+ test.AddAttribute <int64_t >(" K" , 32LL );
123+ test.AddAttribute <int64_t >(" N" , 2LL );
124+ test.AddAttribute <int64_t >(" block_size" , 32LL );
125+ test.AddAttribute <int64_t >(" quant_type" , 1LL ); // NF4
126+
127+ test.AddInput <float >(" A" , {1 , 32 }, std::vector<float >(32 , 0 .0f ));
128+ test.AddInput <uint8_t >(" B" , {4 }, std::vector<uint8_t >(4 , 0 )); // too small
129+ test.AddInput <float >(" absmax" , {2 }, std::vector<float >(2 , 1 .0f ));
130+ test.AddOutput <float >(" Y" , {1 , 2 }, std::vector<float >(2 , 0 .0f ));
131+
132+ test.Run (OpTester::ExpectResult::kExpectFailure , " b_quant tensor size" );
133+ }
134+
135+ TEST (MatMulBnb4, RejectsUndersizedAbsmaxTensor) {
136+ // K=32, N=2, block_size=32 → numel=64, expected absmax size = (64+32-1)/32 = 2
137+ // Provide only 1 absmax element
138+ int64_t K = 32 , N = 2 , block_size = 32 ;
139+ int64_t numel = K * N;
140+ int64_t quantized_numel = (numel + 1 ) / 2 ;
141+
142+ OpTester test (" MatMulBnb4" , 1 , kMSDomain );
143+ test.AddAttribute <int64_t >(" K" , K);
144+ test.AddAttribute <int64_t >(" N" , N);
145+ test.AddAttribute <int64_t >(" block_size" , block_size);
146+ test.AddAttribute <int64_t >(" quant_type" , 1LL ); // NF4
147+
148+ test.AddInput <float >(" A" , {1 , K}, std::vector<float >(K, 0 .0f ));
149+ test.AddInput <uint8_t >(" B" , {quantized_numel}, std::vector<uint8_t >(quantized_numel, 0 ));
150+ test.AddInput <float >(" absmax" , {1 }, std::vector<float >(1 , 1 .0f )); // too small
151+ test.AddOutput <float >(" Y" , {1 , N}, std::vector<float >(N, 0 .0f ));
152+
153+ test.Run (OpTester::ExpectResult::kExpectFailure , " absmax tensor size" );
154+ }
155+
118156TEST (MatMulBnb4, DISABLED_Float32) {
119157 for (auto qt : {0 , 1 }) {
120158 for (auto M : {1 , 2 , 100 }) {
0 commit comments