64 changes: 61 additions & 3 deletions auto_round/__main__.py
@@ -37,6 +37,27 @@
}


def parse_hadamard_config_arg(value: str | None):
    """Parse --hadamard_config into a string shortcut or a dict."""
    if value is None:
        return None

    value = value.strip()
    if not value:
        return None

    if value.startswith("{"):
        try:
            parsed_value = json.loads(value)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Invalid JSON for --hadamard_config: {exc}") from exc
        if not isinstance(parsed_value, dict):
            raise ValueError("--hadamard_config JSON must be an object/dict")
        return parsed_value

    return value


class BasicArgumentParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
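
As a quick reference, a minimal sketch of the inputs `parse_hadamard_config_arg` accepts (illustrative values only):

```python
# Illustrative behavior of parse_hadamard_config_arg defined above.
assert parse_hadamard_config_arg(None) is None                      # flag omitted
assert parse_hadamard_config_arg("   ") is None                     # blank value
assert parse_hadamard_config_arg("llama_quarot") == "llama_quarot"  # shortcut passes through
assert parse_hadamard_config_arg('{"hadamard_type": "random_hadamard"}') == {
    "hadamard_type": "random_hadamard"
}  # a JSON object parses to a dict
# Malformed JSON raises ValueError, e.g. parse_hadamard_config_arg('{"oops"')
```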
@@ -303,6 +324,17 @@ def __init__(self, *args, **kwargs):
help="Group size for weight quantization.",
)
scheme.add_argument("--asym", action="store_true", help="Use asymmetric quantization instead of symmetric.")
act_sym_group = scheme.add_mutually_exclusive_group()
act_sym_group.add_argument(
"--act_sym",
action="store_true",
help="Use symmetric activation quantization. Overrides the activation default inherited from weight quantization.",
)
act_sym_group.add_argument(
"--act_asym",
action="store_true",
help="Use asymmetric activation quantization. Overrides the activation default inherited from weight quantization.",
)
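
The mutually exclusive group means argparse itself rejects contradictory flags; a self-contained sketch of that behavior (the parser here is a stand-in, not the real `BasicArgumentParser`):

```python
import argparse

# Stand-in parser mirroring the --act_sym/--act_asym group above.
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--act_sym", action="store_true")
group.add_argument("--act_asym", action="store_true")

args = parser.parse_args(["--act_asym"])
assert args.act_asym and not args.act_sym
# parser.parse_args(["--act_sym", "--act_asym"]) exits with:
# error: argument --act_asym: not allowed with argument --act_sym
```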
        scheme.add_argument(
            "--data_type",
            "--dtype",
@@ -373,6 +405,16 @@ def __init__(self, *args, **kwargs):
choices=["fp8", "float8_e4m3fn"],
help="Data type for static quantize attention. ",
)
scheme.add_argument(
"--hadamard_config",
default=None,
type=str,
help=(
"Optional hadamard/rotation config. "
"Supports shortcuts such as 'default', 'random_hadamard', 'llama_quarot', "
'or a JSON dict string such as {"placement_strategy":"llama_quarot","hadamard_type":"random_hadamard"}.'
),
)
gguf = self.add_argument_group("Double Quant Arguments")
gguf.add_argument(
"--super_group_size", default=None, type=int, help="Super group size for double quantization."
@@ -541,6 +583,20 @@ def start(recipe="default"):
    tune(args)


def resolve_symmetry_args(args):
    sym = None  # inherit the preset/default weight symmetry unless explicitly overridden
    if args.asym:
        sym = False

    act_sym = None  # inherit from weight symmetry unless explicitly overridden
    if getattr(args, "act_asym", False):
        act_sym = False
    elif getattr(args, "act_sym", False):
        act_sym = True

    return sym, act_sym
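
A quick sanity check of the mapping, using a stand-in namespace (`None` means "inherit the preset default"):

```python
from types import SimpleNamespace

# How the CLI flags map to (sym, act_sym); None means "inherit the default".
args = SimpleNamespace(asym=False, act_sym=False, act_asym=False)
assert resolve_symmetry_args(args) == (None, None)    # nothing overridden

args = SimpleNamespace(asym=True, act_sym=False, act_asym=True)
assert resolve_symmetry_args(args) == (False, False)  # both forced asymmetric

args = SimpleNamespace(asym=False, act_sym=True, act_asym=False)
assert resolve_symmetry_args(args) == (None, True)    # weights inherit, activations symmetric
```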


def tune(args):
    assert args.model or args.model_name, "[model] or --model MODEL_NAME should be set."
    if args.model is None:
@@ -612,9 +668,7 @@ def tune(args):
    )

    enable_torch_compile = True if "--enable_torch_compile" in sys.argv else False
    sym = None  # the default value should be None now
    if args.asym:  # if the scheme is asym, how to set it to sym is an issue
        sym = False
    sym, act_sym = resolve_symmetry_args(args)
    act_dynamic = None
    if args.disable_act_dynamic:
        act_dynamic = False
@@ -657,6 +711,7 @@ def tune(args):
        data_type=args.data_type,
        act_bits=args.act_bits,
        act_group_size=args.act_group_size,
        act_sym=act_sym,
        act_data_type=args.act_data_type,
        act_dynamic=act_dynamic,
        super_bits=args.super_bits,
@@ -705,6 +760,8 @@ def tune(args):
        low_cpu_mem_usage=low_cpu_mem_usage,
    )

    hadamard_config = parse_hadamard_config_arg(args.hadamard_config)

    autoround: BaseCompressor = AutoRound(
        model=model_name,
        platform=args.platform,
@@ -727,6 +784,7 @@ def tune(args):
        model_dtype=args.model_dtype,
        momentum=args.momentum,
        trust_remote_code=not args.disable_trust_remote_code,
        hadamard_config=hadamard_config,
    )

    model_name = args.model.rstrip("/")
22 changes: 19 additions & 3 deletions auto_round/compressors/base.py
@@ -562,13 +562,20 @@ def __init__(
            from auto_round.experimental.transform.apply import apply_hadamard_transform
            from auto_round.experimental.utils import check_supported_schemes, normalize_hadamard_config

            check_supported_schemes(self.scheme)
            normalized_hadamard_config = normalize_hadamard_config(hadamard_config)
            check_supported_schemes(self.scheme, normalized_hadamard_config)

            self.model = apply_hadamard_transform(
                self.model, hadamard_config, need_calibration=True if self.iters > 0 else False
                self.model,
                normalized_hadamard_config,
                need_calibration=True if self.iters > 0 else False,
                target_device=self.device,
            )
            rotation_device = getattr(self.model, "_autoround_llama_quarot_rotation_device", None)
            if rotation_device is not None:
                logger.info(f"Llama QuaRot offline rotation device: {rotation_device}")

            self.hadamard_config = normalize_hadamard_config(hadamard_config)
            self.hadamard_config = normalized_hadamard_config
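
(Side note: `configure_layer_config` below calls `.get("placement_strategy")` on `self.hadamard_config`, so `normalize_hadamard_config` evidently returns a plain dict, or `None`; normalizing once up front keeps `check_supported_schemes` and `apply_hadamard_transform` looking at the same object. A purely hypothetical sketch of that contract:)

```python
# Hypothetical stand-in for auto_round.experimental.utils.normalize_hadamard_config;
# only the dict-or-None return contract is observable from this diff.
def normalize_hadamard_config_sketch(config):
    if config is None:
        return None
    if isinstance(config, str):
        # Assumed: shortcut strings expand to a one-key dict.
        return {"placement_strategy": config} if config == "llama_quarot" else {"hadamard_type": config}
    return dict(config)  # already a mapping: copy defensively
```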

    def _gen_auto_scheme(self) -> dict[str, dict]:
        if self.mllm:
@@ -1695,6 +1702,15 @@ def configure_layer_config(self, enable_gguf_official_mixed: None | bool = True)
            fill_default_value=fill_default_value,
        )

        if (getattr(self, "hadamard_config", {}) or {}).get("placement_strategy") == "llama_quarot":
            from auto_round.experimental.transform.llama_quarot import apply_llama_quarot_layer_config_overrides

            self.layer_config = apply_llama_quarot_layer_config_overrides(
                self.model,
                self.layer_config,
                warn_fn=logger.warning,
            )
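
The guard expression deserves a note: `hadamard_config` may be unset or `None`, and `(getattr(...) or {})` covers both before `.get` runs. A minimal demonstration:

```python
# Why "(getattr(self, 'hadamard_config', {}) or {}).get(...)" rather than plain getattr:
class _Stub:
    pass

stub = _Stub()
assert (getattr(stub, "hadamard_config", {}) or {}).get("placement_strategy") is None  # attribute missing
stub.hadamard_config = None
assert (getattr(stub, "hadamard_config", {}) or {}).get("placement_strategy") is None  # attribute is None
stub.hadamard_config = {"placement_strategy": "llama_quarot"}
assert (getattr(stub, "hadamard_config", {}) or {}).get("placement_strategy") == "llama_quarot"
```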

    def _adjust_immediate_packing_and_saving(self):
        formats = getattr(self, "formats", [])
        if len(formats) == 1 and not formats[0].is_fake() and self.inplace:
5 changes: 3 additions & 2 deletions auto_round/eval/evaluation.py
@@ -427,8 +427,9 @@ def run_model_evaluation(model, tokenizer, autoround, folders, formats, device_s
    # Check if GGUF model
    eval_gguf_model = any(file.endswith("gguf") for file in os.listdir(eval_folder))

    # Determine if model instance evaluation is needed
    need_model_instance = (autoround.act_bits <= 8 and formats[-1] == "fake") or eval_gguf_model
    # Fake quant evaluation must use the in-memory model instance so that
    # runtime transforms such as online hadamard hooks remain attached.
    need_model_instance = (formats[-1] == "fake") or eval_gguf_model

    if need_model_instance:
        # Load or prepare model instance
39 changes: 39 additions & 0 deletions auto_round/experimental/transform/apply.py
@@ -7,6 +7,12 @@
from auto_round.experimental.qmodules.mx import MXQuantLinearBase
from auto_round.experimental.transform.hadamard_config import HadamardConfig
from auto_round.experimental.transform.hadamards import build_hadamard_transform
from auto_round.experimental.transform.llama_quarot import (
    LLAMA_QUAROT_STRATEGY,
    apply_llama_quarot_weight_transform,
    llama_quarot_online_transform,
    register_llama_quarot_online_transforms,
)
from auto_round.experimental.utils import is_triton_kernel_available, normalize_hadamard_config

__all__ = ["apply_hadamard_transform"]
@@ -17,6 +23,7 @@ def apply_hadamard_transform(
    config: str | dict | HadamardConfig | None,
    need_calibration: bool = False,
    location: str = "weight",
    target_device: str | torch.device | None = None,
    use_tqdm=True,
    desc=None,
):
@@ -57,6 +64,38 @@
    if not isinstance(config, HadamardConfig):
        config = HadamardConfig(**config)

    if config.placement_strategy == LLAMA_QUAROT_STRATEGY:
        if location == "weight":
            model = apply_llama_quarot_weight_transform(
                model, config, use_tqdm=use_tqdm, desc=desc, target_device=target_device
            )
            register_llama_quarot_online_transforms(
                model,
                use_tqdm=use_tqdm,
                desc="Register Llama QuaRot online transforms",
                force_fp32=config.llama_quarot_online_force_fp32,
            )

            from auto_round.experimental.transform.patch_modules import (
                patch_wrapperlinear_forward_to_apply_activation_transform,
                patch_wrapperwalayer_forward_to_apply_activation_transform,
            )

            # QuaRot keeps `o_proj`/`down_proj` online at runtime. Wrapper-backed
            # paths (both tuning and RTN/fake) bypass the original module call
            # sites, so we must patch wrapper forwards regardless of `iters`.
            patch_wrapperlinear_forward_to_apply_activation_transform(llama_quarot_online_transform)
            patch_wrapperwalayer_forward_to_apply_activation_transform(llama_quarot_online_transform)
        elif location == "input":
            register_llama_quarot_online_transforms(
                model, use_tqdm=use_tqdm, desc=desc, force_fp32=config.llama_quarot_online_force_fp32
            )
        else:
            raise NotImplementedError(f"Unsupported hadamard transform location: {location}")

        setattr(model, "hadamard_config", config)
        return model
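
To make the dispatch concrete, a hedged usage sketch (the model value is a placeholder for a loaded Llama-style module; the weight path folds offline rotations and registers the online `o_proj`/`down_proj` hooks, while the input path registers hooks only):

```python
# Illustrative only: exercising the llama_quarot branches above.
config = {"placement_strategy": "llama_quarot", "hadamard_type": "random_hadamard"}

# Offline weight rotation + online hooks + wrapper patching:
model = apply_hadamard_transform(model, config, location="weight", target_device="cuda")

# Hooks only (e.g. weights were already rotated elsewhere):
model = apply_hadamard_transform(model, config, location="input")

# Any other location is rejected:
# apply_hadamard_transform(model, config, location="output")  -> NotImplementedError
```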

    modules_config = [
        (name, module, config)
        for name, module in model.named_modules()
15 changes: 15 additions & 0 deletions auto_round/experimental/transform/hadamard_config.py
@@ -16,6 +16,13 @@ class HadamardConfig(BaseModel):

    hadamard_type: str = Field(default="hadamard")

    placement_strategy: str = Field(default="all_linears")

    # llama_quarot specific options
    llama_quarot_online_force_fp32: bool = Field(default=True)
    llama_quarot_strict: bool = Field(default=True)
    llama_quarot_center_embeddings: bool = Field(default=False)

    # for random hadamard transform
    random_seed: bool = Field(default=False, exclude=True)

@@ -26,3 +33,11 @@ def validate_hadamard_type(cls, v: str) -> str:
        if v not in allowed:
            raise ValueError(f"Unsupported hadamard_type: {v}. Supported values: {sorted(allowed)}")
        return v

    @field_validator("placement_strategy")
    @classmethod
    def validate_placement_strategy(cls, v: str) -> str:
        allowed = {"all_linears", "llama_quarot"}
        if v not in allowed:
            raise ValueError(f"Unsupported placement_strategy: {v}. Supported values: {sorted(allowed)}")
        return v
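
A short sketch of the validator in action; pydantic surfaces the `ValueError` as a `ValidationError` (assuming, per the CLI help above, that `"random_hadamard"` is among the allowed `hadamard_type` values):

```python
from pydantic import ValidationError

cfg = HadamardConfig(placement_strategy="llama_quarot", hadamard_type="random_hadamard")
assert cfg.llama_quarot_online_force_fp32  # defaults to True

try:
    HadamardConfig(placement_strategy="per_channel")  # not in {"all_linears", "llama_quarot"}
except ValidationError as exc:
    print(exc)  # Unsupported placement_strategy: per_channel. Supported values: ...
```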