Proposal
This proposal is in two parts, both addressing TL's current SVD implementation.
Note, the SVD is typically written as:
$$
U,S,V^T = \ldots
$$
The important part here is the $V^T$.
Currently, TL returns $V$, not $V^T$, because it uses torch.svd rather than torch.linalg.svd, and torch.svd returns $V$.
https://github.com/neelnanda-io/TransformerLens/blob/347e65235140d942301452a55505c9215b36806a/transformer_lens/FactoredMatrix.py#L118-L125
To address this, we can do one or both of:
- small: torch.svd, as used in FactoredMatrix, is deprecated and should be replaced by torch.linalg.svd. This can be done while keeping all current behavior. Note that FactoredMatrix currently names this value Vh although it is actually V, so if we do this, we should also do the renaming.
- big: change the implementation to return V_T instead, as torch.linalg.svd does. This would be a breaking change.
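For the small option, here is a minimal sketch of how the torch.svd return convention could be kept while calling torch.linalg.svd. The helper name svd_returning_v is hypothetical and not part of TL, and the sketch assumes real-valued inputs (for complex inputs the conjugate transpose would be needed).

```python
import torch


def svd_returning_v(matrix: torch.Tensor):
    """Hypothetical helper: mimic the torch.svd return convention (U, S, V)."""
    # torch.linalg.svd returns Vh (V transposed), not V.
    U, S, Vh = torch.linalg.svd(matrix, full_matrices=False)
    # Swap the last two dims to recover V; this also works for batched inputs.
    V = Vh.transpose(-2, -1)
    return U, S, V


# Random batched input, only for illustration.
M = torch.randn(3, 5, 4, dtype=torch.float64)
U, S, V = svd_returning_v(M)

# With V (not V^T) returned, the caller has to transpose it to reconstruct M.
reconstructed = U @ torch.diag_embed(S) @ V.transpose(-2, -1)
```

The final line shows the extra transpose callers need today, which is what the big option would remove.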
Pitch
I'd like to make a pitch for doing both of the above. That is, switch to torch.linalg.svd, and return $V^T$ directly. This would let you do:
U, S, V_T = model.OV.svd()
torch.dist(model.OV, U @ torch.diag(S) @ V_T)
like in the docs, and, at least for me, behaves more "as expected".
Note: the above is only conceptual. Right now, I don't think you can use a FactoredMatrix directly; you must use model.OV.AB instead. Also, the call to diag depends on the shapes of the other factors, and should probably be .diag_embed.
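To illustrate the diag versus diag_embed point, here is a small sketch. The shape (12, 64) is an assumption standing in for one row of singular values per head; it is not taken from TL.

```python
import torch

# Batched singular values, one row per head (assumed shape, for illustration).
S = torch.rand(12, 64, dtype=torch.float64)

# torch.diag treats a 2D input as a matrix and extracts its main diagonal.
extracted = torch.diag(S)

# torch.diag_embed builds one diagonal matrix per row of S.
embedded = torch.diag_embed(S)

print(extracted.shape)  # torch.Size([12])
print(embedded.shape)   # torch.Size([12, 64, 64])
```

With a 1D S, torch.diag would build the diagonal matrix as intended, which is why the right call depends on the shapes involved.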
The code change would be very small. Something like:
Ua, Sa, V_Ta = torch.linalg.svd(model.OV.A,full_matrices=False)
Ub, Sb, V_Tb = torch.linalg.svd(model.OV.B,full_matrices=False)
middle = Sa.diag_embed() @ V_Ta @ Ub @ Sb.diag_embed()
Um, Sm, V_Tm = torch.linalg.svd(middle,full_matrices=False)
Uf = Ua @ Um
Sf = Sm
V_Tf = V_Tm @ V_Tb
return Uf, Sf, V_Tf
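As a rough check of the snippet above, here is a self-contained sketch that wraps it in a function and compares the result against the explicit product A @ B. The function name factored_svd and the random test shapes are assumptions for illustration, not TL's API.

```python
import torch


def factored_svd(A: torch.Tensor, B: torch.Tensor):
    """Sketch: SVD of A @ B computed from the factors, returning (U, S, V^T)."""
    Ua, Sa, V_Ta = torch.linalg.svd(A, full_matrices=False)
    Ub, Sb, V_Tb = torch.linalg.svd(B, full_matrices=False)
    # Small inner matrix whose SVD combines the two factor SVDs.
    middle = Sa.diag_embed() @ V_Ta @ Ub @ Sb.diag_embed()
    Um, Sm, V_Tm = torch.linalg.svd(middle, full_matrices=False)
    return Ua @ Um, Sm, V_Tm @ V_Tb


torch.manual_seed(0)
# Batched low-rank factors (assumed shapes: batch 2, outer dim 16, inner dim 4).
A = torch.randn(2, 16, 4, dtype=torch.float64)
B = torch.randn(2, 4, 16, dtype=torch.float64)

U, S, V_T = factored_svd(A, B)

# Reconstruction uses V_T directly, with no extra transpose.
error = torch.dist(A @ B, U @ S.diag_embed() @ V_T)
```

Here error should be close to zero (floating-point noise), and the reconstruction matches the pattern from the pitch.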
I can submit a PR, but thought it might be best to discuss first.
Alternatives
The first plan should be considered the 'safe and easy' alternative.
Checklist