[Proposal] Update and discuss behavior of SVD #341

@diego898

Proposal

This proposal has two parts, both addressing TL's current SVD implementation.

Note, the SVD is typically written as:

$$ U,S,V^T = \ldots $$

The important part here is the $V^T$.

Currently, TL returns $V$, not $V^T$, because it uses torch.svd rather than torch.linalg.svd, and torch.svd returns $V$.

https://github.com/neelnanda-io/TransformerLens/blob/347e65235140d942301452a55505c9215b36806a/transformer_lens/FactoredMatrix.py#L118-L125
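To make the difference between the two conventions concrete, here is a minimal sketch on a random matrix. The matrix `M` and its shape are assumptions chosen for illustration; nothing here comes from TL itself. A wide matrix is used so that the shapes of `V` and `Vh` visibly differ.

```python
import torch

torch.manual_seed(0)
M = torch.randn(3, 5, dtype=torch.float64)  # illustrative 3x5 matrix

# Deprecated API: the third return value is V (not transposed), shape (5, 3).
U_old, S_old, V = torch.svd(M)

# Current API: the third return value is Vh (V^T for real inputs), shape (3, 5).
U_new, S_new, Vh = torch.linalg.svd(M, full_matrices=False)

# Reconstruction needs an explicit transpose only with the old convention.
recon_old = U_old @ torch.diag(S_old) @ V.T
recon_new = U_new @ torch.diag(S_new) @ Vh
```

Both reconstructions should match `M` up to floating-point error; only the orientation of the third factor differs.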

To address this, we can do one or both of:

  • small: torch.svd, which FactoredMatrix uses, is deprecated and should be replaced by torch.linalg.svd. This can be done while keeping all current behavior, as long as the $V^T$ that torch.linalg.svd returns is transposed back to $V$. Note that TL currently names this value Vh in FactoredMatrix, although it actually holds $V$, so the variable should be renamed to V as part of this change.
  • big: change the implementation to return V_T, as torch.linalg.svd does. This would be a breaking change.
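The "small" option could look roughly like the sketch below. The helper name `svd_returning_v` is hypothetical and is not part of TL or PyTorch; it only shows that switching to torch.linalg.svd while still returning $V$ amounts to one extra transpose.

```python
import torch

def svd_returning_v(M: torch.Tensor):
    """Hypothetical helper: torch.linalg.svd with torch.svd's (U, S, V) convention."""
    U, S, Vh = torch.linalg.svd(M, full_matrices=False)
    # Undo the (conjugate) transpose so callers keep receiving V, as with torch.svd.
    V = Vh.transpose(-2, -1).conj()
    return U, S, V

torch.manual_seed(0)
M = torch.randn(2, 5, 3, dtype=torch.float64)  # illustrative batch of two 5x3 matrices
U, S, V = svd_returning_v(M)

# With the V convention, reconstruction still needs the transpose.
recon = U @ torch.diag_embed(S) @ V.transpose(-2, -1)
```

The "big" option would simply drop the transpose and return `Vh` as is.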

Pitch

I'd like to make a pitch for doing both of the above. That is, switch to torch.linalg.svd, and return $V^T$ directly. This would let you do:

U, S, V_T = model.OV.svd()
torch.dist(model.OV, U @ torch.diag(S) @ V_T)

as in the docs and, at least for me, behaves more "as expected".

Note: the above is only conceptual. Right now, I don't think you can pass a FactoredMatrix to these functions directly; you must use model.OV.AB instead. The right call for building the diagonal also depends on the shapes of the other factors, and should probably be .diag_embed rather than torch.diag.
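The diag versus diag_embed point can be seen on a small batched example. This is a sketch under the assumption that a random tensor `M` stands in for a batched product such as model.OV.AB; no model is loaded here.

```python
import torch

torch.manual_seed(0)
M = torch.randn(2, 4, 3, dtype=torch.float64)  # stand-in for a batch of two 4x3 matrices
U, S, Vh = torch.linalg.svd(M, full_matrices=False)  # S has shape (2, 3)

# On a 2-D input, torch.diag extracts the main diagonal instead of building matrices.
wrong = torch.diag(S)        # shape (2,), not what we want

# diag_embed builds one diagonal matrix per batch entry.
S_mat = torch.diag_embed(S)  # shape (2, 3, 3)

recon = U @ S_mat @ Vh
err = torch.dist(M, recon)   # 2-norm of the difference, close to zero
```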

The code change would be very small. Something like:

# SVD of each factor separately
Ua, Sa, V_Ta = torch.linalg.svd(model.OV.A, full_matrices=False)
Ub, Sb, V_Tb = torch.linalg.svd(model.OV.B, full_matrices=False)

# SVD of the small middle matrix between Ua and V_Tb
middle = Sa.diag_embed() @ V_Ta @ Ub @ Sb.diag_embed()
Um, Sm, V_Tm = torch.linalg.svd(middle, full_matrices=False)

# Fold the outer orthonormal factors back in
Uf = Ua @ Um
Sf = Sm
V_Tf = V_Tm @ V_Tb

return Uf, Sf, V_Tf
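For anyone who wants to check the algebra without loading a model, the same steps can be wrapped in a standalone function. This is only a sketch: `factored_svd` is a hypothetical name, and the random tensors `A` and `B` (with assumed shapes) stand in for model.OV.A and model.OV.B.

```python
import torch

def factored_svd(A: torch.Tensor, B: torch.Tensor):
    """Hypothetical sketch: SVD of A @ B without forming the full product.

    Returns (U, S, Vh) in the torch.linalg.svd convention.
    """
    Ua, Sa, V_Ta = torch.linalg.svd(A, full_matrices=False)
    Ub, Sb, V_Tb = torch.linalg.svd(B, full_matrices=False)
    # A @ B = Ua @ (diag(Sa) @ V_Ta @ Ub @ diag(Sb)) @ V_Tb
    middle = Sa.diag_embed() @ V_Ta @ Ub @ Sb.diag_embed()
    Um, Sm, V_Tm = torch.linalg.svd(middle, full_matrices=False)
    return Ua @ Um, Sm, V_Tm @ V_Tb

torch.manual_seed(0)
# Assumed shapes: batch of 3, outer dimension 8, inner dimension 2.
A = torch.randn(3, 8, 2, dtype=torch.float64)
B = torch.randn(3, 2, 8, dtype=torch.float64)

Uf, Sf, V_Tf = factored_svd(A, B)
recon = Uf @ Sf.diag_embed() @ V_Tf
err = torch.dist(A @ B, recon)
```

Because the returned third factor is already $V^T$, the reconstruction is a plain product with no transpose.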

I can submit a PR, but thought it might be best to discuss first.

Alternatives

The first option (the small change) should be considered the 'safe and easy' alternative.

Checklist

  • I have checked that there is no similar issue in the repo (required)
