Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
abacbe5
code init
pggPL Feb 5, 2026
ca67a05
fix
pggPL Feb 6, 2026
55c9fd5
code drop
pggPL Feb 19, 2026
09bb7ea
Remove redundant nvte_set/get_grouped_tensor_swizzled_scales
pggPL Mar 16, 2026
63df695
Merge gitlab/main into grouped_gemm_nvfp4_and_hopper
pggPL Mar 17, 2026
44ab70d
Add Hopper support for grouped GEMM and refactor cuBLAS version checks
pggPL Mar 17, 2026
3689c10
Add NVFP4 support for discrete-input grouped GEMM and skip FP8 tensor…
pggPL Mar 18, 2026
d6d26bc
Add alignment assertions for MXFP8/NVFP4 scale offsets in grouped GEM…
pggPL Mar 18, 2026
c3ba64b
Merge remote-tracking branch 'origin/main' into grouped_gemm_nvfp4_an…
pggPL Apr 26, 2026
8bdd739
Fix grouped GEMM: NVFP4 columnwise transa=N + relax MXFP8 alignment f…
pggPL Apr 26, 2026
eba3468
Clarify swap_dims comment in build_grouped_gemm_multi_inputA_args
pggPL Apr 26, 2026
4375cf2
Merge remote-tracking branch 'upstream/main' into grouped_gemm_nvfp4_…
pggPL May 8, 2026
6c7a515
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2026
526a04a
Fix grouped GEMM scale_inv offsets for NVFP4 and FP8 block scaling
pggPL May 8, 2026
0f49cc3
Relax NVFP4 amax contiguity; consolidate scale_inv offset helpers; te…
pggPL May 11, 2026
ce342dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2026
a4df7bd
Remove unused float_size in GroupedGemmSetupWorkspace::from_buffers
pggPL May 11, 2026
d6a1597
Fix Hopper grouped GEMM alpha beta handling
pggPL Mar 16, 2026
d71d614
fix
pggPL Mar 16, 2026
b86fc7e
Address code review: NVFP4 amax check, swap_dims default, test refactor
pggPL May 11, 2026
59b90b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2026
7b1c8aa
Merge branch 'main' into grouped_gemm_nvfp4_and_hopper
pggPL May 11, 2026
3f62928
Address code review feedback (#2971)
pggPL May 15, 2026
9c46760
Simplify make_*_operand signatures to take use_rowwise directly
pggPL May 15, 2026
6ad2b95
Test infrastructure fixes for NVFP4 columnwise + skip unsupported combos
pggPL May 15, 2026
77135c1
Trim verbose comments in grouped GEMM test
pggPL May 15, 2026
5e35141
Fix FP8 block scaling NT/NN failures on Hopper
pggPL May 15, 2026
3ec190a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2026
3657dfc
Use logical grouped GEMM shapes
pggPL May 15, 2026
a099bdb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2026
8a19fa2
Fix mixed FP8 block grouped GEMM scaling
pggPL May 15, 2026
71df120
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2026
ea7e43e
Trim unsupported FP8 block grouped GEMM tests
pggPL May 15, 2026
e9a8ac8
Merge branch 'main' into grouped_gemm_nvfp4_and_hopper
vthumbe1503 May 18, 2026
6702e89
Address review comment
vthumbe1503 May 20, 2026
1a976a8
Merge branch 'main' into grouped_gemm_nvfp4_and_hopper
vthumbe1503 May 20, 2026
2ebc9fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2026
49a9550
Merge branch 'main' into grouped_gemm_nvfp4_and_hopper
vthumbe1503 May 20, 2026
b2c64b0
Fix tests to relax alignment requirements for swizzling tests
vthumbe1503 May 21, 2026
f43dfac
Merge branch 'main' into grouped_gemm_nvfp4_and_hopper
vthumbe1503 May 21, 2026
106c00a
Fix alignment calculation in required_setup_size
vthumbe1503 May 21, 2026
d495796
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2026
13d76ca
use nvfp4 alpha only for nvfp4
vthumbe1503 May 21, 2026
8cdc0f8
Merge branch 'main' into grouped_gemm_nvfp4_and_hopper
vthumbe1503 May 22, 2026
681d45f
Clean up grouped GEMM tests and document alignment
pggPL May 25, 2026
6c476d4
Merge remote-tracking branch 'upstream/main' into grouped_gemm_nvfp4_…
pggPL May 25, 2026
0f7f255
Post-merge cleanup in grouped GEMM helpers and swizzle tests
pggPL May 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
943 changes: 466 additions & 477 deletions tests/cpp/operator/test_grouped_gemm.cu

Large diffs are not rendered by default.

208 changes: 146 additions & 62 deletions tests/cpp/test_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,8 @@ Tensor::Tensor(const std::string& name,
switch (scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING:
case NVTE_BLOCK_SCALING_1D:
case NVTE_BLOCK_SCALING_2D: {
case NVTE_BLOCK_SCALING_2D:
case NVTE_NVFP4_1D_SCALING: {
// Column-wise data shape is transposed
if (shape.ndim > 0) {
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
Expand All @@ -325,8 +326,7 @@ Tensor::Tensor(const std::string& name,
}
break;
}
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING: {
case NVTE_MXFP8_1D_SCALING: {
// Column-wise data matches shape
for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
Expand Down Expand Up @@ -1072,13 +1072,18 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const bool has_columnwise = tensors[0]->columnwise();
NVTE_CHECK(has_rowwise || has_columnwise, "Tensors must have at least one data layout.");

const NVTEShape shape = has_rowwise ? tensors[0]->rowwise_shape()
: tensors[0]->columnwise_shape();
const DType dtype = tensors[0]->dtype();
const size_t num_tensors = tensors.size();
const size_t elem_size = typeToNumBits(dtype) / 8;
const size_t bits_per_elem = typeToNumBits(dtype);
const bool is_sub_byte = (bits_per_elem < 8);
const size_t elem_size = is_sub_byte ? 0 : bits_per_elem / 8;
GroupedBuffers grouped;
grouped.elem_size = elem_size;
grouped.elem_size = elem_size; // Only used for D output extraction (always >= 1 byte dtype)

// Helper: convert element count to byte count (handles sub-byte types like FP4)
auto elems_to_bytes = [bits_per_elem](int64_t elems) -> size_t {
return static_cast<size_t>((elems * static_cast<int64_t>(bits_per_elem)) / 8);
};
grouped.num_tensors = num_tensors;
grouped.dtype = dtype;
grouped.scaling_mode = scaling_mode;
Expand All @@ -1088,12 +1093,13 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
std::vector<int64_t> first_dims(num_tensors);
std::vector<int64_t> last_dims(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
const auto s = has_rowwise ? tensors[i]->rowwise_shape()
: tensors[i]->columnwise_shape();
const auto s = tensors[i]->shape();
NVTE_CHECK(s.ndim == 2, "Grouped tensor build expects 2D tensors.");
first_dims[i] = static_cast<int64_t>(s.data[0]);
last_dims[i] = static_cast<int64_t>(s.data[1]);
grouped.tensor_bytes[i] = bytes(s, dtype);
const auto storage_shape = has_rowwise ? tensors[i]->rowwise_shape()
: tensors[i]->columnwise_shape();
grouped.tensor_bytes[i] = bytes(storage_shape, dtype);
}

const bool same_first = std::all_of(first_dims.begin(), first_dims.end(),
Expand All @@ -1107,9 +1113,14 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
// cuBLAS requires aligned pointers for vectorized loads
static std::mt19937 gen(12345);
std::uniform_int_distribution<int64_t> dist(0, 3);
// Calculate elements needed for 16-byte alignment in bytes, rounded up
const size_t align_elements =
std::max<size_t>(1, (16 + elem_size - 1) / elem_size); // 16 bytes / element_size
// Calculate elements needed for 16-byte alignment
size_t align_elements;
if (is_sub_byte) {
// Sub-byte types (e.g. FP4): 16 bytes = 16*8/bits_per_elem elements
align_elements = (16 * 8) / bits_per_elem;
} else {
align_elements = std::max<size_t>(1, (16 + elem_size - 1) / elem_size);
}
return dist(gen) * static_cast<int64_t>(align_elements);
};

Expand Down Expand Up @@ -1157,7 +1168,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
const int64_t total_elems = need_offsets
? (offsets[last_idx] + numel(last_idx))
: (logical_first * logical_last);
const size_t total_bytes = static_cast<size_t>(total_elems) * elem_size;
const size_t total_bytes = elems_to_bytes(total_elems);

NVTEGroupedTensor h = grouped.handle.get();

Expand All @@ -1167,8 +1178,8 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
if (has_rowwise) {
grouped.data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes,
const size_t offset_bytes_i = elems_to_bytes(offsets[i]);
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.data.get()) + offset_bytes_i,
tensors[i]->rowwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
Expand All @@ -1181,8 +1192,8 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
if (has_columnwise) {
grouped.columnwise_data = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
const size_t offset_bytes = static_cast<size_t>(offsets[i]) * elem_size;
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.columnwise_data.get()) + offset_bytes,
const size_t offset_bytes_i = elems_to_bytes(offsets[i]);
NVTE_CHECK_CUDA(cudaMemcpy(static_cast<char*>(grouped.columnwise_data.get()) + offset_bytes_i,
tensors[i]->columnwise_dptr(),
grouped.tensor_bytes[i],
cudaMemcpyDeviceToDevice));
Expand Down Expand Up @@ -1221,6 +1232,33 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor));
}

// Shared gather of per-tensor scale_inv buffers into a contiguous device buffer.
// Returns (device buffer, total element count). Used by all block-scaling recipes
// (MXFP8 / NVFP4 / FP8 block) — they only differ in element size and CPU getter.
auto gather_scale_inv = [&](size_t bytes_per_elem, auto get_shape_fn,
auto get_cpu_ptr_fn) -> std::pair<CudaPtr<>, size_t> {
size_t total_elems = 0;
std::vector<size_t> elem_offsets(num_tensors);
std::vector<size_t> numels(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
elem_offsets[i] = total_elems;
const NVTEShape sshape = get_shape_fn(tensors[i]);
size_t numel = 1;
for (size_t d = 0; d < sshape.ndim; ++d) numel *= sshape.data[d];
numels[i] = numel;
total_elems += numel;
}
CudaPtr<> buffer = cuda_alloc(total_elems * bytes_per_elem);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
NVTE_CHECK_CUDA(cudaGetLastError());
void* dst = static_cast<char*>(buffer.get()) + elem_offsets[i] * bytes_per_elem;
NVTE_CHECK_CUDA(cudaMemcpy(dst, get_cpu_ptr_fn(tensors[i]),
numels[i] * bytes_per_elem, cudaMemcpyHostToDevice));
}
return {std::move(buffer), total_elems};
};

if (isFp8Type(dtype) && scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
// FP8 tensor scaling: one float scale_inv per tensor
// For delayed scaling, rowwise and columnwise share the same scale
Expand All @@ -1243,67 +1281,113 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor,
sizeof(scale_tensor));
} else if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
// MXFP8: E8M0 scale_inv per block of 32 elements
// Helper to gather scale_inv from individual tensors into a contiguous buffer
auto gather_scales = [&](
auto get_shape_fn,
auto get_cpu_ptr_fn) -> std::pair<CudaPtr<>, size_t> {
// Compute total size and offsets
size_t total_bytes = 0;
std::vector<size_t> scale_offsets(num_tensors);
std::vector<size_t> numels(num_tensors);

for (size_t i = 0; i < num_tensors; ++i) {
scale_offsets[i] = total_bytes;
const NVTEShape shape = get_shape_fn(tensors[i]);
size_t numel = 1;
for (size_t d = 0; d < shape.ndim; ++d) {
numel *= shape.data[d];
}
numels[i] = numel;
total_bytes += numel; // E8M0 is 1 byte per element
}

// Allocate and copy
CudaPtr<> buffer = cuda_alloc(total_bytes);
for (size_t i = 0; i < num_tensors; ++i) {
tensors[i]->to_cpu();
NVTE_CHECK_CUDA(cudaGetLastError());
void* dst = static_cast<char*>(buffer.get()) + scale_offsets[i];
const void* src = get_cpu_ptr_fn(tensors[i]);
NVTE_CHECK_CUDA(cudaMemcpy(dst, src, numels[i], cudaMemcpyHostToDevice));
}
return {std::move(buffer), total_bytes};
};

// Gather rowwise scale_inv if available
// MXFP8: E8M0 scale_inv per block of 32 elements (1 byte per scale element).
if (has_rowwise) {
auto [row_buffer, row_total] = gather_scales(
auto [row_buffer, row_total] = gather_scale_inv(
/*bytes_per_elem=*/1,
[](Tensor* t) { return t->rowwise_scale_inv_shape(); },
[](Tensor* t) { return t->rowwise_cpu_scale_inv_ptr<uint8_t>(); });
[](Tensor* t) -> const void* { return t->rowwise_cpu_scale_inv_ptr<uint8_t>(); });
grouped.scale_inv = std::move(row_buffer);

NVTEShape row_shape = nvte_make_shape(&row_total, 1);
NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat8E8M0, row_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor));
}

// Gather columnwise scale_inv if available
if (has_columnwise) {
auto [col_buffer, col_total] = gather_scales(
auto [col_buffer, col_total] = gather_scale_inv(
/*bytes_per_elem=*/1,
[](Tensor* t) { return t->columnwise_scale_inv_shape(); },
[](Tensor* t) { return t->columnwise_cpu_scale_inv_ptr<uint8_t>(); });
[](Tensor* t) -> const void* { return t->columnwise_cpu_scale_inv_ptr<uint8_t>(); });
grouped.columnwise_scale_inv = std::move(col_buffer);

NVTEShape col_shape = nvte_make_shape(&col_total, 1);
NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat8E8M0, col_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor));
}

// Mark as having swizzled scales (required for GEMM)
const uint8_t swizzled = 1;
nvte_set_grouped_tensor_param(h, kNVTEGroupedWithGEMMSwizzledScales, &swizzled,
sizeof(swizzled));
} else if (scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
// FP8 block scaling: float32 scale_inv per block of 128 elements.
if (has_rowwise) {
auto [row_buffer, row_total] = gather_scale_inv(
/*bytes_per_elem=*/sizeof(float),
[](Tensor* t) { return t->rowwise_scale_inv_shape(); },
[](Tensor* t) -> const void* { return t->rowwise_cpu_scale_inv_ptr<float>(); });
grouped.scale_inv = std::move(row_buffer);
NVTEShape row_shape = nvte_make_shape(&row_total, 1);
NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat32, row_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor));
}
if (has_columnwise) {
auto [col_buffer, col_total] = gather_scale_inv(
/*bytes_per_elem=*/sizeof(float),
[](Tensor* t) { return t->columnwise_scale_inv_shape(); },
[](Tensor* t) -> const void* { return t->columnwise_cpu_scale_inv_ptr<float>(); });
grouped.columnwise_scale_inv = std::move(col_buffer);
NVTEShape col_shape = nvte_make_shape(&col_total, 1);
NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat32, col_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor));
}
} else if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
// NVFP4: E4M3 scale_inv per block of 16 elements (swizzled for GEMM, 1 byte per scale).
if (has_rowwise) {
auto [row_buffer, row_total] = gather_scale_inv(
/*bytes_per_elem=*/1,
[](Tensor* t) { return t->rowwise_scale_inv_shape(); },
[](Tensor* t) -> const void* { return t->rowwise_cpu_scale_inv_ptr<fp8e4m3>(); });
grouped.scale_inv = std::move(row_buffer);
NVTEShape row_shape = nvte_make_shape(&row_total, 1);
NVTEBasicTensor row_tensor{grouped.scale_inv.get(), kNVTEFloat8E4M3, row_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &row_tensor, sizeof(row_tensor));
}
if (has_columnwise) {
auto [col_buffer, col_total] = gather_scale_inv(
/*bytes_per_elem=*/1,
[](Tensor* t) { return t->columnwise_scale_inv_shape(); },
[](Tensor* t) -> const void* { return t->columnwise_cpu_scale_inv_ptr<fp8e4m3>(); });
grouped.columnwise_scale_inv = std::move(col_buffer);
NVTEShape col_shape = nvte_make_shape(&col_total, 1);
NVTEBasicTensor col_tensor{grouped.columnwise_scale_inv.get(), kNVTEFloat8E4M3, col_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &col_tensor, sizeof(col_tensor));
}

// Mark as having swizzled scales (required for NVFP4 GEMM)
uint8_t swizzled = 1;
nvte_set_grouped_tensor_param(h, kNVTEGroupedWithGEMMSwizzledScales, &swizzled, sizeof(swizzled));

// Gather per-tensor amax values for NVFP4 global scale computation
auto gather_amax = [&](NVTETensorParam param) -> CudaPtr<> {
// Check if first tensor has this amax
NVTEBasicTensor first_amax = nvte_get_tensor_param(tensors[0]->data(), param);
if (first_amax.data_ptr == nullptr) return CudaPtr<>();

std::vector<float> amax_cpu(num_tensors);
for (size_t i = 0; i < num_tensors; ++i) {
NVTEBasicTensor amax_bt = nvte_get_tensor_param(tensors[i]->data(), param);
NVTE_CHECK(amax_bt.data_ptr != nullptr, "Tensor ", i, " is missing amax");
float val;
NVTE_CHECK_CUDA(cudaMemcpy(&val, amax_bt.data_ptr, sizeof(float), cudaMemcpyDeviceToHost));
amax_cpu[i] = val;
}
CudaPtr<> dev = cuda_alloc(sizeof(float) * num_tensors);
NVTE_CHECK_CUDA(cudaMemcpy(dev.get(), amax_cpu.data(),
sizeof(float) * num_tensors, cudaMemcpyHostToDevice));
return dev;
};

grouped.amax_dev = gather_amax(kNVTEAmax);
if (grouped.amax_dev.get()) {
NVTEShape amax_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor amax_tensor{grouped.amax_dev.get(), kNVTEFloat32, amax_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedAmax, &amax_tensor, sizeof(amax_tensor));
}

grouped.columnwise_amax_dev = gather_amax(kNVTEColumnwiseAmax);
if (grouped.columnwise_amax_dev.get()) {
NVTEShape amax_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor amax_tensor{grouped.columnwise_amax_dev.get(), kNVTEFloat32, amax_shape};
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseAmax, &amax_tensor, sizeof(amax_tensor));
}

}

return grouped;
Expand Down
4 changes: 4 additions & 0 deletions tests/cpp/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ class Tensor {

NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; }

NVTEShape shape() const noexcept { return tensor_.shape(); }

NVTEShape rowwise_scale_inv_shape() const {
NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
return tensor_.get_rowwise_scale_inv().shape;
Expand Down Expand Up @@ -596,6 +598,8 @@ struct GroupedBuffers {
CudaPtr<int64_t> last_dims_dev;
CudaPtr<int64_t> offsets_dev;
CudaPtr<> columnwise_data;
CudaPtr<> amax_dev; // Per-tensor amax for NVFP4 grouped GEMM
CudaPtr<> columnwise_amax_dev; // Per-tensor columnwise amax for NVFP4 grouped GEMM
NVTEShape logical_shape{};
std::vector<int64_t> offsets_host;
std::vector<size_t> tensor_bytes;
Expand Down
7 changes: 5 additions & 2 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2935,10 +2935,13 @@ def _apply_grouped_bias_ref(
@pytest.mark.parametrize("accumulate", [False, True])
@pytest.mark.parametrize("use_bias_scale", [False, True])
def test_grouped_gemm_grouped_tensor(z, m, n, k, case, layout, accumulate, use_bias_scale) -> None:
if torch.cuda.get_device_capability() < (9, 0):
pytest.skip("Grouped GEMM requires Hopper (SM90) or newer.")
if torch.cuda.get_device_capability() < (10, 0):
if tex.get_cublasLt_version() < 130400:
pytest.skip("Grouped GEMM on Hopper requires cuBLAS 13.4+.")
if tex.get_cublasLt_version() < 130300:
pytest.skip("Grouped GEMM requires cuBLAS 13.3+.")
if torch.cuda.get_device_capability() < (10, 0):
pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.")
if not is_bf16_available():
pytest.skip("bfloat16 is required for grouped GEMM test.")

Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_M

inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; }

inline bool is_fp8_block_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_BLOCK_SCALING_1D || mode == NVTE_BLOCK_SCALING_2D;
}

inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
end, " in a vector with ", shape.size(), " entries");
Expand Down
Loading
Loading