Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarks/benchmark_rht_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
import torch.utils.benchmark as benchmark

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
import transformer_engine.pytorch.cpp_extensions as ext

from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

scale_padding_to = 1
permute_scale = False

TORCH_TO_TE_FLOAT_MAP = {
torch.bfloat16: tex.DType.kBFloat16,
torch.bfloat16: TE_DType.kBFloat16,
}


Expand All @@ -31,7 +31,7 @@ def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):

# Quantize
nvfp4_quantizer = NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
fp4_dtype=TE_DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_amax_reduction=False,
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/quickstart_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model):

def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"):
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType

fp8_type = tex.DType.kFloat8E4M3 if fp8_format == "e4m3" else tex.DType.kFloat8E5M2
fp8_type = TE_DType.kFloat8E4M3 if fp8_format == "e4m3" else TE_DType.kFloat8E5M2
scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
quantizer = Float8Quantizer(scale=scale, amax=amax_history, fp8_dtype=fp8_type)
Expand Down
13 changes: 7 additions & 6 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch import (
autocast,
Expand Down Expand Up @@ -323,34 +324,34 @@ def run_dpa_with_cp(
).cuda()
if scaling_mode == "delayed":
qkv_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=TE_DType.kFloat8E4M3,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
dout_quantizer = Float8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
fp8_dtype=TE_DType.kFloat8E5M2,
scale=torch.tensor([1], dtype=torch.float32).cuda(),
amax=torch.tensor([0], dtype=torch.float32).cuda(),
)
if scaling_mode == "current":
qkv_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=TE_DType.kFloat8E4M3,
device="cuda",
)
dout_quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
fp8_dtype=TE_DType.kFloat8E5M2,
device="cuda",
)
if scaling_mode == "mxfp8":
qkv_quantizer = MXFP8Quantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=TE_DType.kFloat8E4M3,
rowwise=True,
columnwise=True,
)
qkv_quantizer.optimize_for_gemm = True
qkv_quantizer.internal = False
dout_quantizer = MXFP8Quantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
fp8_dtype=TE_DType.kFloat8E5M2,
rowwise=True,
columnwise=True,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/pytorch/debug/run_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch
import torch.distributed as dist
import transformer_engine
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch import is_fp8_available
Expand Down Expand Up @@ -683,7 +683,7 @@ def _run_test_with_combinations(
)

# test_fake_quant_fp8
dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None]
dtype_options = [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2, None]
_run_test_with_combinations(
test_fake_quant_fp8,
dtype_options,
Expand Down
14 changes: 7 additions & 7 deletions tests/pytorch/debug/test_api_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

import torch
from transformer_engine.pytorch import Float8Tensor, Float8Quantizer
from transformer_engine.pytorch.constants import TE_DType

import nvdlfw_inspect.api as debug_api

try:
import transformer_engine
import transformer_engine_torch as tex
except (ImportError, ModuleNotFoundError):
print("Could not find TransformerEngine package.")
exit(1)
Expand Down Expand Up @@ -128,12 +128,12 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
default_quantizer1 = Float8Quantizer(
scale=torch.tensor([1]).cuda(),
amax=torch.tensor([0]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=TE_DType.kFloat8E4M3,
)
default_quantizer2 = Float8Quantizer(
scale=torch.tensor([1]).cuda(),
amax=torch.tensor([0]).cuda(),
fp8_dtype=tex.DType.kFloat8E5M2,
fp8_dtype=TE_DType.kFloat8E5M2,
)

output1 = debug_api.transformer_engine.modify_tensor(
Expand All @@ -145,7 +145,7 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
tensor=tensor,
)
assert type(output1) == Float8Tensor
assert output1._fp8_dtype == tex.DType.kFloat8E4M3
assert output1._fp8_dtype == TE_DType.kFloat8E4M3

output2 = debug_api.transformer_engine.modify_tensor(
"decoder.1.mlp.fc1",
Expand All @@ -156,7 +156,7 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
iteration=0,
)
assert type(output2) == Float8Tensor
assert output2._fp8_dtype == tex.DType.kFloat8E5M2
assert output2._fp8_dtype == TE_DType.kFloat8E5M2

assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1",
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=TE_DType.kFloat8E4M3,
)
tensor_fp8 = quantizer(tensor)

Expand Down Expand Up @@ -372,7 +372,7 @@ def log_stats():
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_dtype=TE_DType.kFloat8E4M3,
)

def fp8_tensor(t):
Expand Down
44 changes: 22 additions & 22 deletions tests/pytorch/debug/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as tepytorch
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.quantization import _default_sf_compute
from transformer_engine.pytorch import (
Expand Down Expand Up @@ -57,7 +57,7 @@ def _cast_to_fp8(tensor, scale, dtype):


def _get_current_scale(tensor, fp8_dtype):
if fp8_dtype == tex.DType.kFloat8E4M3:
if fp8_dtype == TE_DType.kFloat8E4M3:
fp8_max = Format.E4M3.value.max_fwd
else:
fp8_max = Format.E5M2.value.max_fwd
Expand Down Expand Up @@ -93,19 +93,19 @@ def _emulate_linear(
input: torch.Tensor,
weight: torch.Tensor,
fprop_fp8: bool = False,
fprop_input_fake_quant: tex.DType = None,
fprop_input_fake_quant: TE_DType = None,
fprop_input_scale: torch.Tensor = None,
fprop_weight_fake_quant: tex.DType = None,
fprop_weight_fake_quant: TE_DType = None,
fprop_weight_scale: torch.Tensor = None,
dgrad_fp8: bool = False,
dgrad_gradient_fake_quant: tex.DType = None,
dgrad_gradient_fake_quant: TE_DType = None,
dgrad_gradient_scale: torch.Tensor = None,
dgrad_weight_fake_quant: tex.DType = None,
dgrad_weight_fake_quant: TE_DType = None,
dgrad_weight_scale: torch.Tensor = None,
wgrad_fp8: bool = False,
wgrad_gradient_fake_quant: tex.DType = None,
wgrad_gradient_fake_quant: TE_DType = None,
wgrad_gradient_scale: torch.Tensor = None,
wgrad_input_fake_quant: tex.DType = None,
wgrad_input_fake_quant: TE_DType = None,
wgrad_input_scale: torch.Tensor = None,
loss_multiplier: float = 1.0,
activation_sync=None,
Expand All @@ -116,10 +116,10 @@ def _emulate_linear(
activation = _fp8_gemm_kernel(
input,
_scalar(fprop_input_scale or 1.0),
tex.DType.kFloat8E4M3,
TE_DType.kFloat8E4M3,
weight,
_scalar(fprop_weight_scale or 1.0),
tex.DType.kFloat8E4M3,
TE_DType.kFloat8E4M3,
_2X_ACC_FPROP,
)
activation = activation.clone().detach().contiguous().requires_grad_(True)
Expand Down Expand Up @@ -152,10 +152,10 @@ def _emulate_linear(
dgrad = _fp8_gemm_kernel(
weight.T,
_scalar(dgrad_weight_scale or 1.0),
tex.DType.kFloat8E4M3,
TE_DType.kFloat8E4M3,
gradient,
_scalar(dgrad_gradient_scale or 1.0),
tex.DType.kFloat8E5M2,
TE_DType.kFloat8E5M2,
_2X_ACC_DGRAD,
).T
else:
Expand All @@ -176,10 +176,10 @@ def _emulate_linear(
wgrad = _fp8_gemm_kernel(
input.T,
_scalar(wgrad_input_scale or 1.0),
tex.DType.kFloat8E4M3,
TE_DType.kFloat8E4M3,
gradient.T,
_scalar(wgrad_gradient_scale or 1.0),
tex.DType.kFloat8E5M2,
TE_DType.kFloat8E5M2,
_2X_ACC_WGRAD,
).T
else:
Expand Down Expand Up @@ -470,17 +470,17 @@ def set_scaling_factors(model, input_kwargs, fp8_kwargs):
def set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs):
# Compute per tensor scaling factor if respective flag in input_kwargs is set.
if input_kwargs["fprop_inp"]:
fp8_kwargs["fprop_input_scale"] = tex.DType.kFloat8E4M3
fp8_kwargs["fprop_input_scale"] = TE_DType.kFloat8E4M3
if input_kwargs["fprop_weight"]:
fp8_kwargs["fprop_weight_scale"] = tex.DType.kFloat8E4M3
fp8_kwargs["fprop_weight_scale"] = TE_DType.kFloat8E4M3
if input_kwargs["dgrad_grad"]:
fp8_kwargs["dgrad_gradient_scale"] = tex.DType.kFloat8E5M2
fp8_kwargs["dgrad_gradient_scale"] = TE_DType.kFloat8E5M2
if input_kwargs["dgrad_weight"]:
fp8_kwargs["dgrad_weight_scale"] = tex.DType.kFloat8E4M3
fp8_kwargs["dgrad_weight_scale"] = TE_DType.kFloat8E4M3
if input_kwargs["wgrad_grad"]:
fp8_kwargs["wgrad_gradient_scale"] = tex.DType.kFloat8E5M2
fp8_kwargs["wgrad_gradient_scale"] = TE_DType.kFloat8E5M2
if input_kwargs["wgrad_input"]:
fp8_kwargs["wgrad_input_scale"] = tex.DType.kFloat8E4M3
fp8_kwargs["wgrad_input_scale"] = TE_DType.kFloat8E4M3


@create_config_file
Expand Down Expand Up @@ -651,7 +651,7 @@ def init_and_warmup():


all_combinations = list(
itertools.product([tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None], repeat=6)
itertools.product([TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2, None], repeat=6)
)
subset_combinations = random.sample(all_combinations, 10)

Expand Down Expand Up @@ -687,7 +687,7 @@ def test_fake_quant_fp8(
def fake_quant_fp8_create_config(
fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file
):
format_to_str = {tex.DType.kFloat8E4M3: "FP8E4M3", tex.DType.kFloat8E5M2: "FP8E5M2"}
format_to_str = {TE_DType.kFloat8E4M3: "FP8E4M3", TE_DType.kFloat8E5M2: "FP8E5M2"}
gemms = ""

def _add_tensor(quant_format, tensor):
Expand Down
5 changes: 3 additions & 2 deletions tests/pytorch/distributed/run_gemm_with_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.cpp_extensions.gemm import get_cublas_workspace_size_bytes
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.module.base import fill_userbuffers_buffer_for_all_gather

warnings.filterwarnings("ignore", category=DeprecationWarning)
Expand Down Expand Up @@ -473,7 +474,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
if opts.quantization == "fp8":
# Structure to maintain amax and scale/scale_inv information for the kernel and input
num_gemms = 6 if ub_obj2 is not None else 3
fp8_dtype = tex.DType.kFloat8E4M3
fp8_dtype = TE_DType.kFloat8E4M3
fp8_scales = torch.ones(num_gemms, dtype=torch.float, device="cuda")
fp8_amaxes = torch.zeros(num_gemms, dtype=torch.float, device="cuda")

Expand Down Expand Up @@ -516,7 +517,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
)
elif opts.quantization == "mxfp8":
fp8_dtype = tex.DType.kFloat8E4M3
fp8_dtype = TE_DType.kFloat8E4M3
inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
ker_quantizer = MXFP8Quantizer(fp8_dtype)
if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
Expand Down
13 changes: 6 additions & 7 deletions tests/pytorch/distributed/run_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import torch
from torch import nn
import torch.distributed as dist
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
MXFP8BlockScaling,
DelayedScaling,
Expand All @@ -27,7 +26,7 @@
QParams,
)
from transformer_engine.pytorch import Float8CurrentScalingQuantizer, NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE, TE_DType
from transformer_engine.pytorch.distributed import gather_along_first_dim
from run_layer_with_overlap import _compare_tensors

Expand Down Expand Up @@ -399,7 +398,7 @@ def _test_quantizer(input_dtype, fp8_dtype):

Args:
input_dtype (torch.dtype): The data type of the input.
fp8_dtype (tex.DType): The data type of the fp8.
fp8_dtype (TE_DType): The data type of the fp8.
"""

M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE
Expand Down Expand Up @@ -443,7 +442,7 @@ def test_quantizer():
return

input_dtypes = [torch.float32, torch.bfloat16]
fp8_dtypes = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
fp8_dtypes = [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2]

for input_dtype in input_dtypes:
for fp8_dtype in fp8_dtypes:
Expand Down Expand Up @@ -514,7 +513,7 @@ def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls):

Args:
input_dtype (torch.dtype): The data type of the input.
low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8.
low_precision_dtype (TE_DType): The data type of the low precision, can be fp4 or fp8.
"""

M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2
Expand Down Expand Up @@ -623,8 +622,8 @@ def test_quantized_all_gather():
return

input_dtypes = [torch.bfloat16]
fp4_dtype = [tex.DType.kFloat4E2M1]
fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
fp4_dtype = [TE_DType.kFloat4E2M1]
fp8_dtype = [TE_DType.kFloat8E4M3, TE_DType.kFloat8E5M2]
quantizer_cls_nvfp4 = [NVFP4Quantizer]
# add FP8 quantizers if needed
quantizer_cls_fp8 = []
Expand Down
Loading
Loading