Skip to content

Refactor: eliminate _compute_outer_product_list via _transform_grad_for_outer_product hook#262

Open
runame wants to merge 2 commits into
facebookresearch:mainfrom
runame:pr2/kl-outer-product-refactor
Open

Refactor: eliminate _compute_outer_product_list via _transform_grad_for_outer_product hook#262
runame wants to merge 2 commits into
facebookresearch:mainfrom
runame:pr2/kl-outer-product-refactor

Conversation

@runame
Copy link
Copy Markdown
Contributor

@runame runame commented May 7, 2026

Summary

  • Replace the _compute_outer_product_list override pattern with a _transform_grad_for_outer_product hook in ClassicShampooPreconditionerList.
  • The base class computes outer products uniformly; KL variants override the hook to precondition the gradient before the outer product, instead of overriding the entire outer product computation.
  • This makes it straightforward to add new variants (e.g., per-factor eigenvalue correction) that need custom outer product behavior without duplicating the outer product loop.

Stack

This PR is part of a stack adding per-factor eigenvalue correction to Distributed Shampoo:

  1. Refactor: extract shared EigendecompositionBasedShampooKroneckerFactorsUnwrapped base class #261 — extract shared base class
  2. This PR — add _transform_grad_for_outer_product hook (KL refactor)
  3. Per-factor eigenvalue correction (implementation + tests)
  4. Eigenvalue EMA over per-step outer products

Test plan

  • Existing tests pass (distributed_shampoo/tests/, distributed_shampoo/preconditioner/tests/)
  • mypy clean (make type-check)
  • ruff clean
  • No new tests needed — pure refactor with no behavior change.

Generated with Claude Code

runame and others added 2 commits May 7, 2026 09:34
…rsUnwrapped base class

Consolidate duplicated eigendecomposition logic from EigendecomposedShampooKroneckerFactorsUnwrapped
and EigenvalueCorrectedShampooKroneckerFactorsUnwrapped into a shared base class. The base class
provides _perform_eigendecomposition and _amortized_computation, with subclass behavior controlled
via hasattr checks on field presence.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…or_outer_product hook

Inline the outer product loop into BaseShampooPreconditionerList._update_factor_matrices
and introduce _transform_grad_for_outer_product as the single extension point. The base
returns grad unchanged; KL-Shampoo subclasses override it to precondition the gradient.
This eliminates _compute_outer_product_list from all three classes that defined it.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 7, 2026
@meta-codesync
Copy link
Copy Markdown

meta-codesync Bot commented May 12, 2026

@hjmshi has imported this pull request. If you are a Meta employee, you can view this in D104875279.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant