-
Notifications
You must be signed in to change notification settings - Fork 95
Expand file tree
/
Copy pathcriterions.py
More file actions
370 lines (281 loc) · 12.2 KB
/
criterions.py
File metadata and controls
370 lines (281 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
#
# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables
# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+)
# Source code:
# https://github.com/AdaptiveMotorControlLab/CEBRA
#
# Please see LICENSE.md for the full license document:
# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Criterions for contrastive learning
Different criterions can be used for learning embeddings with CEBRA. The common
interface of criterions implementing the generalized InfoNCE metric is given by
:py:class:`BaseInfoNCE`.
Criterions are available for fixed and learnable temperatures, as well as different
similarity measures.
Note that criterions can have trainable parameters, which are automatically handled
by the training loops implemented in :py:class:`cebra.solver.base.Solver` classes.
"""
import math
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
def _compile(fn):
    """Apply ``torch.compile`` when available, falling back to uncompiled.

    ``torch.compile`` is the recommended replacement for ``torch.jit.script``
    starting from PyTorch 2.0. The function is returned unchanged, and a
    :class:`UserWarning` is emitted, when either

    * ``torch.compile`` does not exist (PyTorch versions before 2.0), or
    * calling ``torch.compile`` raises :class:`ImportError`,
      :class:`RuntimeError` or :class:`TypeError` (e.g. the compiler is not
      supported on the current platform or Python version).

    Note:
        ``torch.compile`` typically defers the actual compilation to the first
        call of the returned function. Errors raised by the compiler backend
        at that point are not intercepted by this helper.

    Args:
        fn: The function to compile.

    Returns:
        The compiled function, or ``fn`` itself if compilation is unavailable.
    """
    compile_fn = getattr(torch, "compile", None)
    if compile_fn is None:
        # torch.compile only exists from PyTorch 2.0 on; fall back instead of
        # letting an AttributeError abort the import of this module.
        warnings.warn(
            f"torch.compile is unavailable; falling back to uncompiled "
            f"{fn.__name__!r}. Reason: this PyTorch version does not "
            f"provide torch.compile",
            UserWarning,
            stacklevel=2,
        )
        return fn
    try:
        return compile_fn(fn)
    except (ImportError, RuntimeError, TypeError) as exc:
        warnings.warn(
            f"torch.compile is unavailable; falling back to uncompiled "
            f"{fn.__name__!r}. Reason: {exc}",
            UserWarning,
            stacklevel=2,
        )
        return fn
@_compile
def dot_similarity(ref: torch.Tensor, pos: torch.Tensor,
                   neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dot-product similarity between the ref, pos and negative samples.

    This is the plain (unnormalized) inner product. It coincides with the
    cosine similarity only if the inputs are L2-normalized, which this
    function neither checks nor enforces.

    Args:
        ref: The reference samples of shape `(n, d)`.
        pos: The positive samples of shape `(n, d)`.
        neg: The negative samples of shape `(m, d)`.

    Returns:
        The similarity between each reference sample and its positive sample,
        of shape `(n,)`, and the similarities between every reference sample
        and every negative sample, of shape `(n, m)`.
    """
    # Row-wise inner product: pos_dist[i] = <ref[i], pos[i]>.
    pos_dist = torch.einsum("ni,ni->n", ref, pos)
    # All pairs: neg_dist[i, j] = <ref[i], neg[j]>.
    neg_dist = torch.einsum("ni,mi->nm", ref, neg)
    return pos_dist, neg_dist
@_compile
def euclidean_similarity(
        ref: torch.Tensor, pos: torch.Tensor,
        neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Negative squared L2 distance between the ref, pos and negative samples.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 <a, b>`` so that all
    reference/negative pairs can be computed with a single matrix product.
    Because of floating-point cancellation, values for nearly identical
    points may come out slightly positive instead of exactly zero.

    Args:
        ref: The reference samples of shape `(n, d)`.
        pos: The positive samples of shape `(n, d)`.
        neg: The negative samples of shape `(m, d)`.

    Returns:
        The similarity between each reference sample and its positive sample,
        of shape `(n,)`, and the similarities between every reference sample
        and every negative sample, of shape `(n, m)`.
    """
    # Squared L2 norm of every row.
    ref_sq = torch.einsum("ni->n", ref**2)
    pos_sq = torch.einsum("ni->n", pos**2)
    neg_sq = torch.einsum("ni->n", neg**2)

    pos_cosine, neg_cosine = dot_similarity(ref, pos, neg)

    pos_dist = -(ref_sq + pos_sq - 2 * pos_cosine)
    # Broadcast (n, 1) + (1, m) to build the full (n, m) distance matrix.
    neg_dist = -(ref_sq[:, None] + neg_sq[None] - 2 * neg_cosine)
    return pos_dist, neg_dist
@_compile
def infonce(
    pos_dist: torch.Tensor, neg_dist: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Numerically stabilized InfoNCE loss.

    See :py:class:`BaseInfoNCE` for reference.

    Every row of similarities is shifted by the row-wise maximum of
    ``neg_dist`` before ``logsumexp``. The shift cancels in the total loss,
    so it changes only the numerics, not the value.

    Args:
        pos_dist: Positive similarities of shape `(n,)`.
        neg_dist: Negative similarities of shape `(n, m)`.

    Returns:
        A tuple ``(loss, align, uniform)``. ``loss`` is the scalar InfoNCE
        loss. ``align`` equals ``-pos_dist.mean()`` and ``uniform`` equals
        ``logsumexp(neg_dist, dim=1).mean()``. Both are reported with the
        stabilizing shift removed.

    Note:
        - The behavior of this function changed beginning in CEBRA 0.3.0.
          The InfoNCE implementation is numerically stabilized.
    """
    # The shift is a constant w.r.t. the gradient: compute it without autograd.
    with torch.no_grad():
        shift, _ = neg_dist.max(dim=1, keepdim=True)
    shift = shift.detach()

    shifted_pos = pos_dist - shift.squeeze(1)
    shifted_neg = neg_dist - shift

    align = (-shifted_pos).mean()
    uniform = torch.logsumexp(shifted_neg, dim=1).mean()

    # Undo the shift on the individual terms; in the sum it cancels exactly.
    shift_mean = shift.mean()
    return align + uniform, align - shift_mean, uniform + shift_mean
class ContrastiveLoss(nn.Module):
    """Abstract base class shared by all contrastive losses.

    Concrete losses subclass this module and override :py:meth:`forward`.

    Note:
        - Added in 0.0.2.
    """

    def forward(
        self, ref: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the contrastive loss.

        Args:
            ref: The reference samples of shape `(n, d)`.
            pos: The positive samples of shape `(n, d)`.
            neg: The negative samples of shape `(n, d)`.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError
class BaseInfoNCE(ContrastiveLoss):
    r"""Base class for all InfoNCE losses.

    Given a similarity measure :math:`\phi` which will be implemented by the subclasses
    of this class, the generalized InfoNCE loss is computed as

    .. math::

        \frac{1}{n} \sum_{i=1}^n \left( - \phi(x_i, y^{+}_i) + \log \sum_{j=1}^{n} e^{\phi(x_i, y^{-}_{ij})} \right)

    where :math:`n` is the batch size, :math:`x` are the reference samples (``ref``),
    :math:`y^{+}` are the positive samples (``pos``) and :math:`y^{-}` are the negative
    samples (``neg``). The loss is averaged, not summed, over the batch.

    Subclasses only need to implement :py:meth:`_distance`.
    """

    def _distance(self, ref: torch.Tensor, pos: torch.Tensor,
                  neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """The similarity measure.

        Args:
            ref: The reference samples of shape `(n, d)`.
            pos: The positive samples of shape `(n, d)`.
            neg: The negative samples of shape `(n, d)`.

        Returns:
            The distance between reference samples and positive samples of shape `(n,)`, and
            the distances between reference samples and negative samples of shape `(n, n)`.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError()

    def forward(
        self, ref: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the InfoNCE loss.

        Args:
            ref: The reference samples of shape `(n, d)`.
            pos: The positive samples of shape `(n, d)`.
            neg: The negative samples of shape `(n, d)`.

        Returns:
            The result of :py:func:`infonce` applied to the similarities from
            :py:meth:`_distance`: the scalar loss, followed by its alignment
            and uniformity terms.

        See Also:
            :py:class:`BaseInfoNCE`.
        """
        pos_dist, neg_dist = self._distance(ref, pos, neg)
        return infonce(pos_dist, neg_dist)
class FixedInfoNCE(BaseInfoNCE):
    """InfoNCE base loss whose temperature is a constant hyperparameter.

    Attributes:
        temperature:
            The softmax temperature. It is stored as a plain attribute, not
            as an ``nn.Parameter``, so no optimizer will update it.
    """

    def __init__(self, temperature: float = 1.0):
        super(FixedInfoNCE, self).__init__()
        self.temperature = temperature
class LearnableInfoNCE(BaseInfoNCE):
    """InfoNCE base loss whose temperature is trained with the model.

    The temperature is parametrized through the logarithm of its inverse,
    which keeps the effective temperature positive during optimization.

    Attributes:
        temperature:
            The current value of the learnable temperature parameter.
        min_temperature:
            The minimum temperature to use. Increase the minimum temperature
            if you encounter numerical issues during optimization.
    """

    def __init__(self,
                 temperature: float = 1.0,
                 min_temperature: Optional[float] = None):
        super().__init__()
        # A lower bound on the temperature is an upper bound on its inverse.
        self.max_inverse_temperature = (math.inf if min_temperature is None
                                        else 1.0 / min_temperature)
        self.log_inverse_temperature = nn.Parameter(
            torch.tensor(math.log(1.0 / float(temperature))))
        self.min_temperature = min_temperature

    @torch.jit.export
    def _prepare_inverse_temperature(self) -> torch.Tensor:
        """Return the inverse temperature, clamped to the configured maximum."""
        return torch.clamp(torch.exp(self.log_inverse_temperature),
                           max=self.max_inverse_temperature)

    @property
    def temperature(self) -> float:
        """The current (clamped) temperature as a Python float."""
        with torch.no_grad():
            inverse = self._prepare_inverse_temperature().item()
        return 1.0 / inverse
class FixedCosineInfoNCE(FixedInfoNCE):
    r"""Cosine similarity function with fixed temperature.

    The similarity metric is given as

    .. math ::

        \phi(x, y) =  x^\top y  / \tau

    with fixed temperature :math:`\tau > 0`.

    This is a true cosine similarity only for normalized embeddings, so this
    loss should typically only be used with normalized embeddings. The class
    itself does *not* perform any checks. Ensure that :math:`x` and
    :math:`y` are normalized.
    """

    @torch.jit.export
    def _distance(self, ref: torch.Tensor, pos: torch.Tensor,
                  neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Dot-product similarities divided by the fixed temperature."""
        tau = self.temperature
        pos_sim, neg_sim = dot_similarity(ref, pos, neg)
        return pos_sim / tau, neg_sim / tau
class FixedEuclideanInfoNCE(FixedInfoNCE):
    r"""L2 similarity function with fixed temperature.

    The similarity metric is given as

    .. math ::

        \phi(x, y) =  - \| x - y \|^2 / \tau

    with fixed temperature :math:`\tau > 0`, i.e. the negative *squared*
    Euclidean distance scaled by the inverse temperature.
    """

    @torch.jit.export
    def _distance(self, ref: torch.Tensor, pos: torch.Tensor,
                  neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Negative squared L2 distances divided by the fixed temperature."""
        pos_dist, neg_dist = euclidean_similarity(ref, pos, neg)
        return pos_dist / self.temperature, neg_dist / self.temperature
class LearnableCosineInfoNCE(LearnableInfoNCE):
    r"""Cosine similarity function with a learnable temperature.

    Like :py:class:`FixedCosineInfoNCE`, but with a learnable temperature
    parameter :math:`\tau`. Similarities are multiplied by the current
    (clamped) inverse temperature.
    """

    @torch.jit.export
    def _distance(self, ref: torch.Tensor, pos: torch.Tensor,
                  neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Dot-product similarities scaled by the learnable inverse temperature."""
        scale = self._prepare_inverse_temperature()
        pos_sim, neg_sim = dot_similarity(ref, pos, neg)
        return pos_sim * scale, neg_sim * scale
class LearnableEuclideanInfoNCE(LearnableInfoNCE):
    r"""L2 similarity function with a learnable temperature.

    Like :py:class:`FixedEuclideanInfoNCE`, but with a learnable temperature
    parameter :math:`\tau`:

    .. math ::

        \phi(x, y) =  - \| x - y \|^2 / \tau

    i.e. the negative *squared* Euclidean distance multiplied by the current
    (clamped) inverse temperature.
    """

    @torch.jit.export
    def _distance(self, ref: torch.Tensor, pos: torch.Tensor,
                  neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Negative squared L2 distances scaled by the learnable inverse temperature."""
        inverse_temperature = self._prepare_inverse_temperature()
        # Distinct names so the ``pos``/``neg`` inputs are not shadowed.
        pos_dist, neg_dist = euclidean_similarity(ref, pos, neg)
        return pos_dist * inverse_temperature, neg_dist * inverse_temperature
# NOTE(stes): old aliases used in various locations in the codebase. Should be
# deprecated at some point.
# InfoNCE -> FixedCosineInfoNCE, InfoMSE -> FixedEuclideanInfoNCE (same class
# objects, not subclasses).
InfoNCE, InfoMSE = FixedCosineInfoNCE, FixedEuclideanInfoNCE
class NCE(ContrastiveLoss):
    """Noise contrastive estimation (Gutmann & Hyvarinen, 2012).

    Similarities are dot products divided by ``temperature``. Positives enter
    through ``logsigmoid(s)``, negatives through ``logsigmoid(-s)``.

    Args:
        temperature (float): The softmax temperature.
        negative_weight (float): Relative weight of the negative samples.
        reduce (str): How to reduce the negative samples. Can be
            ``sum`` or ``mean``.

    Attributes:
        temperature (float): The softmax temperature.
        negative_weight (float): Relative weight of the negative samples.

    Raises:
        ValueError: If ``reduce`` is neither ``"mean"`` nor ``"sum"``.
    """

    def __init__(self, temperature=1.0, negative_weight=1.0, reduce="mean"):
        super().__init__()
        self.temperature = temperature
        self.negative_weight = negative_weight
        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, after which getattr would accept any torch attribute.
        if reduce not in ("mean", "sum"):
            raise ValueError(
                f"reduce must be either 'mean' or 'sum', got {reduce!r}.")
        # ``torch.mean`` or ``torch.sum``, applied over the negatives axis.
        self._reduce = getattr(torch, reduce)

    def forward(self, ref, pos, neg):
        """Compute the NCE loss.

        Args:
            ref: The reference samples of shape `(n, d)`.
            pos: The positive samples of shape `(n, d)`.
            neg: The negative samples of shape `(m, d)`.

        Returns:
            A tuple ``(align + negative_weight * uniform, align, uniform)``.
            ``align`` is ``logsigmoid`` of the positive similarities. ``uniform``
            is ``logsigmoid`` of the negated negative similarities, reduced
            over the negatives. Each has shape `(n,)`.

            NOTE(review): unlike the InfoNCE losses in this module, these
            values are not averaged over the batch and are log-likelihood
            terms (larger is better) rather than a loss to minimize. Confirm
            that callers reduce and negate them.

        See Also:
            :py:class:`NCE`.
        """
        pos_dist = torch.einsum("ni,ni->n", ref, pos) / self.temperature
        neg_dist = torch.einsum("ni,mi->nm", ref, neg) / self.temperature

        align = F.logsigmoid(pos_dist)
        uniform = self._reduce(F.logsigmoid(-neg_dist), dim=1)

        return align + self.negative_weight * uniform, align, uniform