Skip to content

Commit 00d7cb0

Browse files
CopilotMMathisLab
andauthored
Replace torch.jit.script with torch.compile (PyTorch 2.x best practices)
Agent-Logs-Url: https://github.com/AdaptiveMotorControlLab/CEBRA/sessions/289d0438-5f8a-480c-9428-c2d5586a8cea Co-authored-by: MMathisLab <28102185+MMathisLab@users.noreply.github.com>
1 parent 0aaf52e commit 00d7cb0

2 files changed

Lines changed: 32 additions & 7 deletions

File tree

cebra/models/criterions.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,36 @@
3333
"""
3434

3535
import math
36+
import warnings
3637
from typing import Optional, Tuple
3738

3839
import torch
3940
import torch.nn.functional as F
4041
from torch import nn
4142

4243

43-
@torch.jit.script
44+
def _compile(fn):
45+
"""Apply ``torch.compile`` when available, falling back to uncompiled.
46+
47+
``torch.compile`` is the recommended replacement for ``torch.jit.script``
48+
starting from PyTorch 2.0. In environments where the compiler backend is
49+
not available (e.g. certain CI configurations or incomplete installations),
50+
the function is returned unchanged so that correctness is preserved.
51+
A :class:`UserWarning` is emitted when the fallback path is taken.
52+
"""
53+
try:
54+
return torch.compile(fn)
55+
except (ImportError, RuntimeError, TypeError) as exc:
56+
warnings.warn(
57+
f"torch.compile is unavailable; falling back to uncompiled "
58+
f"{fn.__name__!r}. Reason: {exc}",
59+
UserWarning,
60+
stacklevel=2,
61+
)
62+
return fn
63+
64+
65+
@_compile
4466
def dot_similarity(ref: torch.Tensor, pos: torch.Tensor,
4567
neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
4668
"""Cosine similarity the ref, pos and negative pairs
@@ -59,7 +81,7 @@ def dot_similarity(ref: torch.Tensor, pos: torch.Tensor,
5981
return pos_dist, neg_dist
6082

6183

62-
@torch.jit.script
84+
@_compile
6385
def euclidean_similarity(
6486
ref: torch.Tensor, pos: torch.Tensor,
6587
neg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -85,7 +107,7 @@ def euclidean_similarity(
85107
return pos_dist, neg_dist
86108

87109

88-
@torch.jit.script
110+
@_compile
89111
def infonce(
90112
pos_dist: torch.Tensor, neg_dist: torch.Tensor
91113
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

tests/test_criterions.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,19 @@
2525

2626
import cebra.models.criterions as cebra_criterions
2727

28+
# Use the same _compile helper from criterions for consistency
29+
_compile = cebra_criterions._compile
2830

29-
@torch.jit.script
31+
32+
@_compile
3033
def ref_dot_similarity(ref: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor,
3134
temperature: float):
3235
pos_dist = torch.einsum("ni,ni->n", ref, pos) / temperature
3336
neg_dist = torch.einsum("ni,mi->nm", ref, neg) / temperature
3437
return pos_dist, neg_dist
3538

3639

37-
@torch.jit.script
40+
@_compile
3841
def ref_euclidean_similarity(ref: torch.Tensor, pos: torch.Tensor,
3942
neg: torch.Tensor, temperature: float):
4043
ref_sq = torch.einsum("ni->n", ref**2) / temperature
@@ -48,7 +51,7 @@ def ref_euclidean_similarity(ref: torch.Tensor, pos: torch.Tensor,
4851
return pos_dist, neg_dist
4952

5053

51-
@torch.jit.script
54+
@_compile
5255
def ref_infonce(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
5356
with torch.no_grad():
5457
c, _ = neg_dist.max(dim=1, keepdim=True)
@@ -61,7 +64,7 @@ def ref_infonce(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
6164
return align + uniform, align, uniform
6265

6366

64-
@torch.jit.script
67+
@_compile
6568
def ref_infonce_not_stable(pos_dist: torch.Tensor, neg_dist: torch.Tensor):
6669
pos_dist = pos_dist
6770
neg_dist = neg_dist

0 commit comments

Comments
 (0)