Add F4 (float4_e2m1fn_x2) and F8_E8M0 (float8_e8m0fnu) dtype support #69
Conversation
…ypes
Models like DeepSeek V4-Flash use two additional dtypes that were missing
from fastsafetensors's DType enum:
- F8_E8M0 (torch.float8_e8m0fnu, PyTorch 2.5+): unsigned 8-bit exponent-only
format used for per-tile quantization scales. One byte per element.
- F4 (torch.float4_e2m1fn_x2, PyTorch 2.10+): packed FP4 format, two 4-bit
values per byte. safetensors stores the shape in FP4-element count while
PyTorch float4_e2m1fn_x2 counts packed pairs, so the byte size per logical
element is 0.5.
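
A minimal sketch of the resulting byte math (illustrative only; the dict and helper below are not the commit's actual code):

```python
# Illustrative per-element sizes in bytes, including the fractional
# size for packed FP4 (two 4-bit values per byte).
DTYPE_SIZES = {
    "F8_E8M0": 1.0,   # one byte per element
    "F4": 0.5,        # two FP4 values packed per byte
    "BF16": 2.0,
}

def nbytes(dtype: str, nelements: int) -> int:
    # int(...) is safe because a valid F4 tensor has an even number of
    # FP4 values in its last dimension, so the product is integral.
    return int(nelements * DTYPE_SIZES[dtype])

assert nbytes("F4", 2048 * 4096) == 4 * 1024 * 1024   # 8M FP4 values -> 4 MiB
assert nbytes("F8_E8M0", 2048) == 2048
```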
Without these variants, add_filenames() raises ValueError on any safetensors
file that contains F4 or F8_E8M0 tensors, making fastsafetensors unusable with
those models.
Changes:
- st_types.py: add F4 and F8_E8M0 to DType enum
- frameworks/_torch.py:
* map DType.F8_E8M0 -> torch.float8_e8m0fnu (guarded by hasattr)
* map DType.F4 -> torch.float4_e2m1fn_x2 (guarded by hasattr)
* add U8 workaround for both (no NCCL support)
* get_dtype_size() returns 0.5 for F4 (two FP4 values per byte)
* get_storage_shape() collapses packed sub-byte shapes to flat byte count
so the DLPack tensor does not overread the buffer
* get_native_shape() restores the PyTorch-native shape after DLPack import
* get_empty_tensor() uses get_native_shape() so shape is correct for alloc
* TorchTensor.reshape() added to support shape adjustment in common.py
- frameworks/__init__.py:
* get_dtype_size() return type changed from int to float
* get_storage_shape() added (default: identity)
* get_native_shape() added (default: identity)
* TensorBase.reshape() added (default: raises NotImplementedError)
- dlpack.py: add DLPack mappings for F4 and F8_E8M0 as opaque uint8 bytes (see the sketch after this list)
- common.py:
* validation: nbytes = int(nelements * get_dtype_size()) for sub-byte safety
* get_tensors: use get_storage_shape() and get_native_shape() for packed types
- tests: add test_float4_e2m1fn_x2 and test_float8_e8m0fnu
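
Roughly, the dlpack.py change amounts to the following; this is a sketch, and the table name and layout are illustrative (only the (kDLUInt, 8, 1) descriptor comes from this change):

```python
# DLPack type code kDLUInt is 1; each descriptor is (type_code, bits, lanes).
KDL_UINT = 1

DLPACK_DTYPE = {
    "F8_E8M0": (KDL_UINT, 8, 1),  # one opaque byte per element
    "F4":      (KDL_UINT, 8, 1),  # one opaque byte per packed pair of FP4 values
}
```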
Validated against DeepSeek V4-Flash MP=2 safetensors shards (~82 GB each).
All six dtypes (BF16, F32, F4, F8_E4M3, F8_E8M0, I64) load bit-exactly.
@uasind
Can you please resolve lint issues, DCO, and my comments. Thanks!
(added) please rebase main before pushing new changes.
```python
ratio = int(round(1.0 / size))  # e.g. 2 for F4 (2 FP4 per byte)
if len(st_shape) > 1:
    return list(st_shape[:-1]) + [st_shape[-1] // ratio]
```
Can you check st_shape[-1] % ratio == 0 and raise Exception on malformed safetensors files?
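
One possible shape of that check, sketched as a standalone helper (the function name is made up here, not the PR's actual code):

```python
def to_packed_shape(st_shape: list, size: float) -> list:
    """Convert a safetensors FP4-element shape to a packed-pair shape."""
    ratio = int(round(1.0 / size))  # e.g. 2 for F4 (2 FP4 values per byte)
    if st_shape[-1] % ratio != 0:
        raise ValueError(
            f"malformed safetensors shape {st_shape}: last dimension "
            f"must be a multiple of {ratio}"
        )
    if len(st_shape) > 1:
        return list(st_shape[:-1]) + [st_shape[-1] // ratio]
    return [st_shape[-1] // ratio]
```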
@uasind So, I will wait for your responses until next Sunday (5/17) JST. After that period, I will close this PR and implement F4 and F8_E8M0 support in another PR with your name in the commit message. Thank you for your understanding.
Summary
Adds support for two dtypes used by modern mixed-precision models (DeepSeek V4-Flash, DeepSeek V3, etc.) that currently cause `add_filenames()` to fail:

- F8_E8M0 (`torch.float8_e8m0fnu`, PyTorch 2.5+): unsigned 8-bit exponent-only format used for per-tile quantization scales. One byte per element. Straightforward addition: same storage size as F8_E4M3.
- F4 (`torch.float4_e2m1fn_x2`, PyTorch 2.10+): packed FP4 format, two 4-bit values per byte. safetensors stores the shape in FP4-element count while PyTorch `float4_e2m1fn_x2` counts packed pairs (one byte each). This requires shape adjustment on load.

Without these, any safetensors file containing F4 or F8_E8M0 tensors raises a `ValueError` from `add_filenames()`.
Changes
fastsafetensors/st_types.py

- add `F4` and `F8_E8M0` to the `DType` enum

fastsafetensors/frameworks/_torch.py

- `DType.F8_E8M0` → `torch.float8_e8m0fnu` (guarded by `hasattr`)
- `DType.F4` → `torch.float4_e2m1fn_x2` (guarded by `hasattr`)
- `U8` NCCL workaround for both (NCCL has no float8_e8m0/float4 support)
- `get_dtype_size()` returns `0.5` for F4 (two FP4 values per byte)
- `get_storage_shape()`: collapses packed sub-byte shapes to a flat byte count for DLPack so the workaround-dtype view doesn't overread the buffer
- `get_native_shape()`: converts the safetensors FP4-element shape to the PyTorch packed-pair shape
- `get_empty_tensor()`: uses `get_native_shape()` for the correct allocation shape
- `TorchTensor.reshape()`: needed by `common.py` after DLPack import

fastsafetensors/frameworks/__init__.py (see the sketch after this list)

- `get_dtype_size()` return type: `int` → `float` (to accommodate `0.5` for F4)
- `get_storage_shape()` added (default: identity, no change for existing dtypes)
- `get_native_shape()` added (default: identity, no change for existing dtypes)
- `TensorBase.reshape()` added (default: `NotImplementedError`)

fastsafetensors/dlpack.py

- map F4 and F8_E8M0 to `(kDLUInt, 8, 1)`, consistent with their U8 workaround

fastsafetensors/common.py

- `nbytes = int(nelements * get_dtype_size())`, safe for fractional sizes
- `get_tensors()`: uses `get_storage_shape()` / `get_native_shape()` for packed dtypes

tests/test_fastsafetensors.py

- `test_float8_e8m0fnu`: bit-exact round-trip via `_test_type`
- `test_float4_e2m1fn_x2`: bit-exact round-trip (uses a uint8 view, since randn/cast don't support float4)
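A schematic of the new framework-layer defaults described above, as a hedged sketch: class and method names follow this description rather than the exact signatures in `frameworks/__init__.py`.

```python
from typing import List

class FrameworkOpBase:
    """Sketch of the framework-ops defaults; existing dtypes keep identity behavior."""

    def get_dtype_size(self, dtype) -> float:
        # Return type widened from int to float so packed F4 can report 0.5.
        raise NotImplementedError

    def get_storage_shape(self, shape: List[int], dtype) -> List[int]:
        # Default: identity. The torch backend overrides this for packed
        # sub-byte dtypes, collapsing to a flat byte count for DLPack.
        return shape

    def get_native_shape(self, shape: List[int], dtype) -> List[int]:
        # Default: identity. The torch backend converts the safetensors
        # FP4-element shape to the float4_e2m1fn_x2 packed-pair shape.
        return shape


class TensorBase:
    def reshape(self, shape: List[int]) -> "TensorBase":
        # Backends that need post-DLPack shape adjustment override this.
        raise NotImplementedError
```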
F4 shape handling detail

safetensors stores an F4 weight of logical shape `[2048, 4096]` (8M FP4 values) as 4 MiB (0.5 bytes per FP4 value). `torch.float4_e2m1fn_x2` represents paired values: `[2048, 4096]` in that dtype would be 8 MiB, which is wrong. The correct PyTorch shape is `[2048, 2048]` (4M packed pairs = 4 MiB).

- `get_storage_shape()` handles DLPack by using a flat `[4194304]` uint8 shape (avoids buffer overread).
- `get_native_shape()` then reshapes to `[2048, 2048]` after the view.
- `get_empty_tensor()` also uses `get_native_shape()` so broadcast/scatter allocations are correctly sized.
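
The same numbers, worked as a short illustrative snippet (variable names are mine, not the PR's):

```python
from math import prod

st_shape = [2048, 4096]              # FP4-element shape from the safetensors header
n_fp4 = prod(st_shape)               # 8,388,608 FP4 values
storage_shape = [n_fp4 // 2]         # [4194304] opaque uint8 bytes for the DLPack view
native_shape = st_shape[:-1] + [st_shape[-1] // 2]   # [2048, 2048] packed pairs

assert storage_shape[0] == 4 * 1024 * 1024           # 4 MiB of raw storage
assert prod(native_shape) == storage_shape[0]        # same byte count, packed shape
```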
Tested against DeepSeek V4-Flash MP=2 safetensors shards (~82 GiB each) containing all six dtypes (BF16, F32, F4, F8_E4M3, F8_E8M0, I64). With this PR:
- `add_filenames()` succeeds on V4-Flash shards
- loaded tensors are bit-exact against `safetensors.torch.load_file`
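
The validation flow looks roughly like this; the sketch below assumes the loader API shown in the fastsafetensors README (`SafeTensorsFileLoader`, `SingleGroup`), and the file and tensor names are placeholders:

```python
import torch
from fastsafetensors import SafeTensorsFileLoader, SingleGroup

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
loader = SafeTensorsFileLoader(SingleGroup(), device, nogds=True)
# Placeholder shard name. With this PR, shards containing F4/F8_E8M0
# tensors no longer raise ValueError in add_filenames().
loader.add_filenames({0: ["model-shard.safetensors"]})
fb = loader.copy_files_to_device()
t = fb.get_tensor(tensor_name="some.tensor.name")  # placeholder tensor name
fb.close()
loader.close()
```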
Notes

- `get_storage_shape()` and `get_native_shape()` default to identity on `FrameworkOpBase`, so non-PyTorch frameworks (e.g. Paddle) are unaffected unless they add their own F4/F8_E8M0 dtype mappings.
- The new torch dtype mappings are guarded by `hasattr()`, so older PyTorch versions still work for all currently-supported dtypes.
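
The `hasattr()` guard mentioned above follows this general pattern (a sketch, not the literal mapping code in `_torch.py`):

```python
import torch

# Map DType names to torch dtypes only when this PyTorch build has them,
# so older versions keep working for all previously supported dtypes.
dtype_map = {}
if hasattr(torch, "float8_e8m0fnu"):      # PyTorch 2.5+
    dtype_map["F8_E8M0"] = torch.float8_e8m0fnu
if hasattr(torch, "float4_e2m1fn_x2"):    # PyTorch 2.10+
    dtype_map["F4"] = torch.float4_e2m1fn_x2
```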