diff --git a/.gitignore b/.gitignore index da34449a..70e4cbd7 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,9 @@ CLAUDE.md # Profiling results traces/ +# Trajectories results +tests/trajectories/results/ + # uv uv.lock diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9e55d66a..4414f07e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -228,6 +228,39 @@ implementation of a mathematical aggregator. To deprecate some public functionality, make it raise a `DeprecationWarning`. A test should also be added in `tests/unit/test_deprecations.py`, ensuring that this warning is issued. +## Trajectories + +The `tests/trajectories/` directory contains scripts to generate and visualize optimization +trajectories using various aggregators on simple multi-objective problems. They require the `plot` +dependency group. + +Available objective keys: `EWQ`, `CQF`, `CQF2`, `HQF`, `MN2`, `MN20`. + +Available aggregator keys: `upgrad`, `mgda`, `cagrad`, `nashmtl`, `nashmtl20`, `graddrop`, +`imtl_g`, `aligned_mtl`, `dualproj`, `pcgrad`, `random`, `mean`. + +**Step 1 — Optimize:** run the optimization for an objective and a selection of aggregators: +```bash +uv run python tests/trajectories/optimize.py EWQ upgrad mean mgda cagrad dualproj graddrop imtl_g aligned_mtl nashmtl random +``` +This saves trajectory data under `tests/trajectories/results/` (gitignored). + +**Step 2 — Plot:** generate the plots from the saved trajectories: +```bash +export MPLBACKEND=Agg +uv run python tests/trajectories/plot_params.py EWQ +uv run python tests/trajectories/plot_values.py EWQ +uv run python tests/trajectories/plot_distance_to_pf.py EWQ +``` + +Replace `EWQ` with any other objective key. The three plot scripts produce PDFs saved to +`tests/trajectories/results//`. + +> [!NOTE] +> The plot scripts require a LaTeX installation for rendering: +> `sudo apt-get install texlive-latex-extra texlive-fonts-recommended dvipng cm-super` + + ## Release *This section is addressed to maintainers.* diff --git a/tests/trajectories/__init__.py b/tests/trajectories/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/trajectories/_constants.py b/tests/trajectories/_constants.py new file mode 100644 index 00000000..b916b081 --- /dev/null +++ b/tests/trajectories/_constants.py @@ -0,0 +1,183 @@ +from math import cos, sin + +import numpy as np +import torch + +from torchjd._linalg import QuadprogProjector +from torchjd.aggregation import ( + IMTLG, + MGDA, + AlignedMTL, + CAGrad, + DualProj, + GradDrop, + Mean, + NashMTL, + PCGrad, + Random, + UPGrad, +) +from trajectories._objectives import ( + ConvexQuadraticForm, + ElementWiseQuadratic, + HomogenousQuadraticForm, + Multinorm, + QuadraticForm, +) + +AGGREGATORS = { + "upgrad": UPGrad(projector=QuadprogProjector(reg_eps=1e-7, norm_eps=1e-9)), + "mgda": MGDA(), + "cagrad": CAGrad(c=0.5), + "nashmtl": NashMTL(n_tasks=2, optim_niter=1), + "nashmtl20": NashMTL(n_tasks=20, optim_niter=1), + "graddrop": GradDrop(), + "imtl_g": IMTLG(), + "aligned_mtl": AlignedMTL(), + "dualproj": DualProj(projector=QuadprogProjector(reg_eps=1e-7, norm_eps=1e-9)), + "pcgrad": PCGrad(), + "random": Random(), + "mean": Mean(), +} +LR_MULTIPLIERS = { + "upgrad": 1.0, + "mgda": 2.0, + "cagrad": 1.0, + "nashmtl": 2.0, + "nashmtl20": 2.0, + "graddrop": 0.5, + "imtl_g": 1.0, + "aligned_mtl": 4.0, + "dualproj": 1.0, + "pcgrad": 0.5, + "random": 1.0, + "mean": 1.0, +} +# Some methods have optimal LRs that are very problem-specific. This allows overriding the LR +# per-problem. +LR_MULTIPLIER_OVERRIDES = { + "HQF": { + "nashmtl": 20.0, + "imtl_g": 2.0, + }, + "CQF": {"nashmtl": 0.5}, + "CQF2": {"nashmtl": 0.5}, +} +AGGREGATOR_ORDER = { + "upgrad": 9, + "mgda": 1, + "cagrad": 5, + "nashmtl": 7, + "nashmtl20": 7, + "graddrop": 3, + "imtl_g": 4, + "aligned_mtl": 8, + "dualproj": 2, + "random": 6, + "mean": 0, + # No location for PCGrad as it's equivalent to UPGrad with 2 objectives +} +LATEX_NAMES = { + "upgrad": r"$\mathcal A_{\mathrm{UPGrad}}$ (ours)", + "mgda": r"$\mathcal A_{\mathrm{MGDA}}$", + "cagrad": r"$\mathcal A_{\mathrm{CAGrad}}$", + "nashmtl": r"$\mathcal A_{\mathrm{Nash-MTL}}$", + "nashmtl20": r"$\mathcal A_{\mathrm{Nash-MTL}}$", + "graddrop": r"$\mathcal A_{\mathrm{GradDrop}}$", + "imtl_g": r"$\mathcal A_{\mathrm{IMTL-G}}$", + "aligned_mtl": r"$\mathcal A_{\mathrm{Aligned-MTL}}$", + "dualproj": r"$\mathcal A_{\mathrm{DualProj}}$", + "pcgrad": r"$\mathcal A_{\mathrm{PCGrad}}$", + "random": r"$\mathcal A_{\mathrm{RGW}}$", + "mean": r"$\mathcal A_{\mathrm{Mean}}$", +} + +# Sometimes we need to override the xlim and ylim of the value plot to zoom enough +PLOT_VALUES_LIMS = { + "CQF": { + "xlim": (-0.125, 2.625), + "ylim": (-0.425, 8.925), + }, + "CQF2": { + "xlim": (-0.125, 2.625), + "ylim": (-0.425, 8.925), + }, +} + +THETA = np.pi / 16 + +OBJECTIVES = { + "EWQ": ElementWiseQuadratic(2), + "CQF": ConvexQuadraticForm( + Bs=[ + torch.tensor([[cos(THETA), -sin(THETA)], [sin(THETA), cos(THETA)]]) + @ torch.diag(torch.tensor([1.0, 0.1])), + torch.tensor([[cos(THETA), sin(THETA)], [-sin(THETA), cos(THETA)]]) + @ torch.diag(torch.tensor([torch.sqrt(torch.tensor(3.0)), 0.1])), + ], + us=[torch.tensor([1.0, 0.0]), torch.tensor([-1.0, 0.0])], + ), + "CQF2": QuadraticForm( + As=[torch.tensor([[1.0, 0.2], [0.2, 0.05]]), torch.tensor([[3.0, -0.6], [-0.6, 0.2]])], + us=[torch.tensor([1.0, 0.0]), torch.tensor([-1.0, 0.0])], + ), + "HQF": HomogenousQuadraticForm( + A=torch.tensor([[2.0, -1.0], [-1.0, 2.0]]), + scales=torch.tensor([1.0, 10.0]), + us=[torch.tensor([1.0, 0.0]), torch.tensor([-10.0, 0.0])], + ), + "MN2": Multinorm(torch.tensor([1.0, 10.0])), + "MN20": Multinorm(torch.arange(1, 21)), +} +BASE_LEARNING_RATES = { + "EWQ": 0.075, + "CQF": 0.125, + "CQF2": 0.125, + "HQF": 0.005, + "MN2": 0.02, + "MN20": 0.005, +} +INITIAL_POINTS = { + "EWQ": [ + [3.0, -2.0], + [0.0, -3.0], + [-4.0, 4.0], + [-3.0, 4.0], + [-3.5, -0.75], + ], + "CQF": [ + [0.5, 0.5], + [-1.0, 7.0], + [0.0, 0.0], + [1.0, 6.0], + ], + "CQF2": [ + [0.5, 0.5], + [-0.3, 7.0], + [0.0, 0.0], + ], + "HQF": [ + [-6.0, 4.0], + [-3.0, -1.5], + [1.5, 2.0], + [2.5, 5.5], + ], + "MN2": [ + [0.0, 0.0], + [-5.0, 5.0], + [10.0, 5.0], + [10.0, 0.0], + [20.0, 0.0], + ], + "MN20": [ + [0.0] * 20, + ], +} +N_ITERS = { + "EWQ": 50, + "CQF": 200, + "CQF2": 200, + "HQF": 100, + "MN2": 50, + "MN20": 500, +} diff --git a/tests/trajectories/_objectives.py b/tests/trajectories/_objectives.py new file mode 100644 index 00000000..e0c7de61 --- /dev/null +++ b/tests/trajectories/_objectives.py @@ -0,0 +1,162 @@ +from abc import ABC, abstractmethod + +import torch +from torch import Tensor + + +class Objective(ABC): + def __init__(self, n_params: int, n_values: int) -> None: + self.n_params = n_params + self.n_values = n_values + + @abstractmethod + def __call__(self, x: Tensor) -> Tensor: + """Compute the value of the objective function at x. It has to be a vector.""" + + @abstractmethod + def jacobian(self, x: Tensor) -> Tensor: + """ + Compute the value of the Jacobian of the objective function at x. It is a matrix of shape + [n_values, n_params]. + """ + + def __str__(self) -> str: + """Return a string representation of the objective function.""" + return self.__class__.__name__ + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.n_values})" + + +class WithSPSMappingMixin(ABC): + """Mixin adding the possibility to get the Strong Pareto stationary mapping.""" + + class SPSMapping(ABC): + @abstractmethod + def __call__(self, w: Tensor) -> Tensor: + """ + Map a vector with (strictly) positive coordinates to the corresponding strongly pareto + stationary point. + """ + + @property + @abstractmethod + def sps_mapping(self) -> "WithSPSMappingMixin.SPSMapping": + pass + + +class QuadraticForm(Objective, WithSPSMappingMixin): + def __init__(self, As: list[Tensor], us: list[Tensor]) -> None: + if len(As) != len(us): + raise ValueError("As and us must have the same length.") + + if len(As) < 1: + raise ValueError("As and us must have at least one element.") + + super().__init__(n_params=len(us[0]), n_values=len(As)) + # Note that if A is not PSD, the objective is not convex. + self.As = As + self.us = us + + def __call__(self, x: Tensor) -> Tensor: + objective_values = [self.quad(x, A, u) for A, u in zip(self.As, self.us, strict=False)] + return torch.stack(objective_values) + + def jacobian(self, x: Tensor) -> Tensor: + return torch.vstack([2 * (x - u) @ A for A, u in zip(self.As, self.us, strict=False)]) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(As={self.As}, us={self.us})" + + @staticmethod + def quad(x: Tensor, A: Tensor, u: Tensor) -> Tensor: + x_minus_u = x - u + return x_minus_u @ A @ x_minus_u + + class SPSMapping(WithSPSMappingMixin.SPSMapping): + def __init__(self, As: list[Tensor], us: list[Tensor]) -> None: + self.As = As + self.us = us + + def __call__(self, w: Tensor) -> Tensor: + G = torch.stack([weight * A for weight, A in zip(w, self.As, strict=False)]).sum(dim=0) + b = torch.stack( + [weight * A @ u for weight, A, u in zip(w, self.As, self.us, strict=False)] + ).sum(dim=0) + return torch.linalg.lstsq(G, b, driver="gelsd").solution + + @property + def sps_mapping(self) -> "QuadraticForm.SPSMapping": + return self.SPSMapping(self.As, self.us) + + +class HomogenousQuadraticForm(QuadraticForm): + def __init__(self, A: Tensor, scales: Tensor, us: list[Tensor]) -> None: + self.A = A + self.scales = scales + As = [A * scale for scale in scales] + super().__init__(As=As, us=us) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(A={self.A}, scales={self.scales}, us={self.us})" + + +class ConvexQuadraticForm(QuadraticForm): + def __init__(self, Bs: list[Tensor], us: list[Tensor]) -> None: + self.Bs = Bs + super().__init__(As=[B @ B.T for B in self.Bs], us=us) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(Bs={self.Bs}, us={self.us})" + + +class ElementWiseQuadratic(Objective, WithSPSMappingMixin): + def __init__(self, n_dim: int) -> None: + super().__init__(n_params=n_dim, n_values=n_dim) + + def __call__(self, x: Tensor) -> Tensor: + if len(x) != self.n_values: + raise ValueError("x must have the same length as the number of values.") + return x**2 + + def jacobian(self, x: Tensor) -> Tensor: + return torch.diag(torch.stack([2 * x[0], 2 * x[1]])) + + class SPSMapping(WithSPSMappingMixin.SPSMapping): + def __init__(self, n_values: int) -> None: + self.n_values = n_values + + def __call__(self, w: Tensor) -> Tensor: # noqa: ARG002 + return torch.zeros(self.n_values) + + @property + def sps_mapping(self) -> "ElementWiseQuadratic.SPSMapping": + return self.SPSMapping(self.n_values) + + +class Multinorm(Objective, WithSPSMappingMixin): + def __init__(self, a: Tensor) -> None: + n = len(a) + super().__init__(n_params=n, n_values=n) + self.a = a + + def __call__(self, x: Tensor) -> Tensor: + if len(x) != self.n_values: + raise ValueError("x must have the same length as the number of values.") + + # f_i(x) = a_i * || x - a_i * e_i ||² + return self.a * torch.norm(x.expand(len(x), len(x)) - torch.diag(self.a), dim=1) ** 2 + + def jacobian(self, x: Tensor) -> Tensor: + return self.a * 2 * (x.expand(len(x), len(x)) - torch.diag(self.a)) + + class SPSMapping(WithSPSMappingMixin.SPSMapping): + def __init__(self, a: Tensor) -> None: + self.a = a + + def __call__(self, w: Tensor) -> Tensor: + return w * self.a + + @property + def sps_mapping(self) -> "Multinorm.SPSMapping": + return self.SPSMapping(self.a) diff --git a/tests/trajectories/_optimization.py b/tests/trajectories/_optimization.py new file mode 100644 index 00000000..e7c0568f --- /dev/null +++ b/tests/trajectories/_optimization.py @@ -0,0 +1,64 @@ +import numpy as np +import torch +from torch import Tensor +from torch.nn import functional as F + +from torchjd import autojac +from torchjd.aggregation import Aggregator +from trajectories._objectives import Objective + + +def optimize( + objective: Objective, + initial_x: Tensor, + aggregator: Aggregator, + lr: float, + n_iters: int, +) -> tuple[list[Tensor], list[Tensor]]: + xs = [] + ys = [] + x = initial_x.clone().requires_grad_() + optimizer = torch.optim.SGD([x], lr=lr) + for i in range(n_iters): + print(f"{i + 1}/{n_iters}", end="\r") + xs.append(x.detach().clone()) + y = objective(x) + ys.append(y.detach().clone()) + optimizer.zero_grad() + autojac.backward(y) + autojac.jac_to_grad([x], aggregator) + optimizer.step() + print(" " * 50, end="\r") + + return xs, ys + + +def compute_gradient_cosine_similarities( + objective: Objective, x0_min: float, x0_max: float, x1_min: float, x1_max: float, n: int +) -> Tensor: + if objective.n_values != 2: + raise ValueError("Objective should have 2 values.") + + x0_len = x0_max - x0_min + x0_start = x0_min + x0_len / (n * 2) + x0_end = x0_max - x0_len / (n * 2) + + x1_len = x1_max - x1_min + x1_start = x1_min + x1_len / (n * 2) + x1_end = x1_max - x1_len / (n * 2) + + x0s = np.linspace(x0_start, x0_end, n, dtype=np.float32) + x1s = np.linspace(x1_start, x1_end, n, dtype=np.float32) + + similarities = torch.zeros(n, n) + for i, x0 in enumerate(x0s): + for j, x1 in enumerate(x1s): + x = torch.tensor([x0, x1]) + similarities[i][j] = compute_cosine_similarity(objective, x) + + return similarities + + +def compute_cosine_similarity(objective: Objective, x: Tensor) -> Tensor: + J = objective.jacobian(x) + return F.cosine_similarity(J[0].unsqueeze(0), J[1].unsqueeze(0), eps=1e-19).squeeze() diff --git a/tests/trajectories/_pareto_utils.py b/tests/trajectories/_pareto_utils.py new file mode 100644 index 00000000..3af690e6 --- /dev/null +++ b/tests/trajectories/_pareto_utils.py @@ -0,0 +1,96 @@ +from collections.abc import Callable +from functools import partial + +import numpy as np +import torch +from torch import Tensor, vmap + +from trajectories._objectives import ElementWiseQuadratic, Objective, WithSPSMappingMixin + + +def compute_normalized_2d_pf_distances( + objective: Objective, y0_min: float, y0_max: float, y1_min: float, y1_max: float, n: int +) -> Tensor: + y0_len = y0_max - y0_min + y0_start = y0_min + y0_len / (n * 2) + y0_end = y0_max - y0_len / (n * 2) + + y1_len = y1_max - y1_min + y1_start = y1_min + y1_len / (n * 2) + y1_end = y1_max - y1_len / (n * 2) + + y0s = torch.linspace(y0_start, y0_end, n, dtype=torch.float32) + y1s = torch.linspace(y1_start, y1_end, n, dtype=torch.float32) + + Y0, Y1 = torch.meshgrid(y0s, y1s, indexing="ij") # shape: (n, n) + Y = torch.stack([Y0, Y1], dim=-1) + + pf_dist = make_2d_pf_distance_fn(objective) + distances = vmap(vmap(pf_dist))(Y) + + max_distance = torch.max(distances[distances.isfinite()]) + distances = distances / max_distance + distances[distances.isnan()] = -1.0 + + return distances + + +def sample_2d_spss(objective: Objective) -> Tensor: + assert objective.n_values == 2 and isinstance(objective, WithSPSMappingMixin) + + eps = 1e-5 + + n_samples = 1 if isinstance(objective, ElementWiseQuadratic) else 1000 + + sps_mapping = objective.sps_mapping + + ws_np = np.linspace([0 + eps, 1 - eps], [1 - eps, 0 + eps], n_samples) + ws = torch.tensor(ws_np) + sps_points = torch.stack([sps_mapping(w) for w in ws]) + return sps_points + + +def sample_2d_pf(objective: Objective) -> Tensor: + sps_points = sample_2d_spss(objective) + pf_points = torch.stack([objective(x) for x in sps_points]) + return pf_points + + +def make_2d_pf_distance_fn(objective: Objective) -> Callable[[Tensor], Tensor]: + pf_points = sample_2d_pf(objective) + + def compute_2d_pf_distance(pf_points: Tensor, y: Tensor) -> Tensor: + """Compute the distance from a point y to a piecewise-linear Pareto front. + + The Pareto front is approximated as a polyline: the ordered sequence of + ``pf_points`` defines consecutive line segments, and the distance returned + is the minimum Euclidean distance from ``y`` to any of those segments. + + For each segment [A, B] the closest point on the segment to ``y`` is found + via orthogonal projection: + + t = dot(y - A, B - A) / ||B - A||² + + ``t`` is clamped to [0, 1] so that the closest point is constrained to the + segment rather than the infinite line through A and B. This ensures correct + distances when ``y`` lies "outside" the extent of the front (i.e. beyond + either endpoint). + + :param pf_points: Pareto front points of shape ``(k, n)``, ordered along the + front. Adjacent points define the segments of the polyline. + :param y: Query point of shape ``(n,)`` whose distance to the front is sought. + """ + if len(pf_points) == 1: + return (y - pf_points[0]).norm() + + pf_first = pf_points[:-1, :] + pf_second = pf_points[1:, :] + d = pf_second - pf_first + t = ((y - pf_first) * d).sum(dim=1) / (d * d).sum(dim=1) + closest = pf_first + t.clamp(0, 1).unsqueeze(1) * d + + # Clamp at 0 so that points below the PF have a distance of 0 to it. + distances = torch.clamp(y - closest, min=0).norm(dim=1) + return torch.min(distances) + + return partial(compute_2d_pf_distance, pf_points) diff --git a/tests/trajectories/_paths.py b/tests/trajectories/_paths.py new file mode 100644 index 00000000..e92f0514 --- /dev/null +++ b/tests/trajectories/_paths.py @@ -0,0 +1,23 @@ +from pathlib import Path + +RESULTS_DIR = Path(__file__).parent / "results" + + +def get_params_dir(objective_key: str) -> Path: + return RESULTS_DIR / objective_key / "X" + + +def get_values_dir(objective_key: str) -> Path: + return RESULTS_DIR / objective_key / "Y" + + +def get_param_plots_dir(objective_key: str) -> Path: + return RESULTS_DIR / objective_key / "param_plots" + + +def get_value_plots_dir(objective_key: str) -> Path: + return RESULTS_DIR / objective_key / "value_plots" + + +def get_distance_to_pf_plots_dir(objective_key: str) -> Path: + return RESULTS_DIR / objective_key / "distance_to_pf" diff --git a/tests/trajectories/_plotters.py b/tests/trajectories/_plotters.py new file mode 100644 index 00000000..5af3bd95 --- /dev/null +++ b/tests/trajectories/_plotters.py @@ -0,0 +1,388 @@ +from abc import ABC, abstractmethod +from typing import TypeAlias + +import matplotlib.patheffects as pe +import numpy as np +from matplotlib import cm as cm, colors as mcolors, pyplot as plt +from numpy.lib.stride_tricks import sliding_window_view + +Color: TypeAlias = str | tuple[float, float, float] | tuple[float, float, float, float] + + +class Plotter(ABC): + """Abstract base class to modify a matplotlib Axes object.""" + + @abstractmethod + def __call__(self, ax: plt.Axes) -> None: + pass + + def __add__(self, other: "Plotter") -> "Plotter": + return MultiPlotter((self, other)) + + +class EmptyPlotter(Plotter): + """Plotter that does nothing""" + + def __call__(self, ax: plt.Axes) -> None: + pass + + +class MultiPlotter(Plotter): + """Plotter applying several plotters.""" + + def __init__(self, plotters: tuple["Plotter", ...]) -> None: + self.plotters = plotters + + def __call__(self, ax: plt.Axes) -> None: + for plotter in self.plotters: + plotter(ax) + + +class PointPlotter(Plotter, ABC): + """Abstract plotter storing a single point.""" + + def __init__(self, x: float, y: float) -> None: + self.x = x + self.y = y + + +class InitialPointPlotter(PointPlotter): + """PointPlotter that can draw the initial point.""" + + def __init__(self, x: float, y: float, color: Color = "black") -> None: + super().__init__(x, y) + self.color = color + + def __call__(self, ax: plt.Axes) -> None: + ax.scatter( + self.x, self.y, color=self.color, edgecolors="black", s=30, linewidth=0.7, zorder=3 + ) + + +class OptimalPointPlotter(PointPlotter): + """PointPlotter that can draw the optimal point.""" + + def __init__(self, x: float, y: float, color: Color) -> None: + super().__init__(x, y) + self.color = color + + def __call__(self, ax: plt.Axes) -> None: + ax.scatter( + self.x, + self.y, + marker="*", + color=self.color, + zorder=3, + s=60, + ) + + +class OptimalLinePlotter(Plotter): + """ + Plotter that can draw a continuous path with uniform color linking the provided optimal points. + """ + + def __init__(self, points: np.ndarray, color: Color) -> None: + self.points = points + self.color = color + + def __call__(self, ax: plt.Axes) -> None: + ax.plot(self.points[:, 0], self.points[:, 1], color=self.color, linewidth=2.0) + + +class AxesPlotter(Plotter): + """Plotter that can draw the x=0 and y=0 axes.""" + + def __call__(self, ax: plt.Axes) -> None: + ax.axhline(y=0, color="black", linewidth=0.75, alpha=0.5) + ax.axvline(x=0, color="black", linewidth=0.75, alpha=0.5) + + +class CirclePlotter(Plotter): + """Plotter that can draw a circle.""" + + def __init__(self, radius: float, color: Color) -> None: + self.radius = radius + self.color = color + + def __call__(self, ax: plt.Axes) -> None: + circle = plt.Circle( + (0, 0), + self.radius, + color=self.color, + fill=False, + linestyle="--", + alpha=0.5, + linewidth=1.5, + ) + ax.add_patch(circle) + + +class ContourCirclesPlotter(MultiPlotter): + """ + MultiPlotter that can draw several circles of different radii and colors, to make contour + lines centered at zero. + """ + + def __init__(self) -> None: + radiuses = [1.0, 2.5, 4, 5.5, 7, 8.5] + colormap = cm.inferno_r # ty:ignore[unresolved-attribute] + norm = mcolors.Normalize(vmin=-1, vmax=max(radiuses)) + plotters = tuple(CirclePlotter(radius, colormap(norm(radius))) for radius in radiuses) + super().__init__(plotters) + + +class SegmentPlotter(Plotter): + """Plotter that can draw a single segment of a given color.""" + + def __init__(self, xp: np.ndarray, yp: np.ndarray, color: Color) -> None: + self.xp = xp + self.yp = yp + self.color = color + + def __call__(self, ax: plt.Axes) -> None: + ax.plot(self.xp, self.yp, color=self.color, solid_capstyle="round", linewidth=1.5) + + +class PathPlotter(MultiPlotter): + """Plotter that can draw a path of segments with colors varying along a gradient.""" + + def __init__(self, points: np.ndarray) -> None: + x_view = sliding_window_view(points[:, 0], window_shape=2) + y_view = sliding_window_view(points[:, 1], window_shape=2) + + colors = PathPlotter._get_color_gradient("#FF0000", "#FFEE00", len(points) - 1) + plotters = tuple( + SegmentPlotter(xp, yp, color) + for xp, yp, color in zip(x_view, y_view, colors, strict=False) + ) + super().__init__(plotters) + + @staticmethod + def _get_color_gradient(c1: str, c2: str, n: int) -> list[str]: + """Given two hex colors, returns a color gradient with n colors.""" + + assert n > 1 + c1_rgb = np.array(PathPlotter._hex_to_rgb(c1)) / 255 + c2_rgb = np.array(PathPlotter._hex_to_rgb(c2)) / 255 + mix_pcts = [x / (n - 1) for x in range(n)] + rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts] + return [ + "#" + "".join([format(round(val * 255), "02x") for val in item]) for item in rgb_colors + ] + + @staticmethod + def _hex_to_rgb(hex_str: str) -> list[int]: + """Map a hex color string to an [R, G, B] list of ints.""" + return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)] + + +class TrajPlotter(MultiPlotter): + """Plotter that can draw a trajectory: initial point + path.""" + + def __init__(self, points: np.ndarray, initial_point_color: Color) -> None: + x = points[0, 0] + y = points[0, 1] + plotters = (InitialPointPlotter(x, y, initial_point_color), PathPlotter(points)) + super().__init__(plotters) + + +class MultiTrajPlotter(MultiPlotter): + """Plotter that can draw several trajectories (one for each initial point).""" + + CMAP = plt.get_cmap("Set2") + + def __init__(self, points_matrix: np.ndarray) -> None: + plotters = tuple( + TrajPlotter(points, self.CMAP(i)) for i, points in enumerate(points_matrix) + ) + super().__init__(plotters) + + +class EvolutionPlotter(Plotter): + """Plotter that can draw an evolution over the discrete timesteps.""" + + def __init__(self, values: np.ndarray, color: Color) -> None: + self.x = np.arange(len(values)) + 1 + self.y = values + self.color = color + + def __call__(self, ax: plt.Axes) -> None: + (line,) = ax.plot(self.x, self.y, color=self.color, linewidth=1.5) + + # Add thin black outline around the lines + line.set_path_effects( + [ + pe.Stroke(linewidth=2.1, foreground="black"), # outline + pe.Normal(), # original line on top + ] + ) + ax.grid(linewidth=0.5) + + +class MultiEvolutionPlotter(MultiPlotter): + """ + Plotter that can draw the evolution of some value over timestamps for each initial point. + """ + + CMAP = plt.get_cmap("Set2") + + def __init__(self, values_vector: np.ndarray) -> None: + plotters = tuple( + EvolutionPlotter(values, self.CMAP(i)) for i, values in enumerate(values_vector) + ) + super().__init__(plotters) + + +class SetPlotter(Plotter): + """ + Plotter that can represent an optimal set. + + If the provided array of optimal points contains a single point, the set will be represented by + a star. Otherwise, it will be represented as a connected line plot. This does not necessarily + work for all optimal sets, but it should be fine for those that are convex. + """ + + def __init__(self, points: np.ndarray, color: Color) -> None: + self.points = points + self.color = color + + if len(points) == 1: + self.plotter: Plotter = OptimalPointPlotter(points[0, 0], points[0, 1], color=color) + else: + self.plotter = OptimalLinePlotter(points, color=color) + + def __call__(self, ax: plt.Axes) -> None: + self.plotter(ax) + + +class SPSPlotter(SetPlotter): + """Plotter that can represent the Strong Pareto stationary set: black SetPlotter""" + + def __init__(self, sps_points: np.ndarray) -> None: + super().__init__(points=sps_points, color="#282828") + + +class PFPlotter(SetPlotter): + """Plotter that can represent the Pareto front: black SetPlotter""" + + def __init__(self, pf_points: np.ndarray) -> None: + super().__init__(points=pf_points, color="#282828") + + +class HeatmapPlotter(Plotter): + """ + Plotter that can draw a heatmap with the given values extending between the provided + coordinates. + """ + + def __init__( + self, + values: np.ndarray, + x_min: float, + x_max: float, + y_min: float, + y_max: float, + vmin: float, + vmax: float, + cmap: str, + ) -> None: + self.values = values + self.x_min = x_min + self.x_max = x_max + self.y_min = y_min + self.y_max = y_max + self.cmap = cmap + self.vmin = vmin + self.vmax = vmax + + def __call__(self, ax: plt.Axes) -> None: + ax.imshow( + self.values.T, + origin="lower", + cmap=self.cmap, + aspect="auto", + vmin=self.vmin, + vmax=self.vmax, + extent=(self.x_min, self.x_max, self.y_min, self.y_max), + alpha=0.4, + interpolation="bicubic", + ) + + +class LimAdjuster(Plotter): + """Plotter that adjusts the xlim and ylim of the plot to the specified xlim and ylim.""" + + def __init__(self, xlim: tuple[float, float], ylim: tuple[float, float]) -> None: + self.xlim = xlim + self.ylim = ylim + + def __call__(self, ax: plt.Axes) -> None: + ax.set_xlim(self.xlim) + ax.set_ylim(self.ylim) + + +class ContentLimAdjuster(LimAdjuster): + """Plotter that adjusts the xlim and ylim of the plot to the coordinates of the content.""" + + def __init__(self, content: np.ndarray) -> None: + x_min, y_min = content.min(axis=0) + x_max, y_max = content.max(axis=0) + x_range = x_max - x_min + y_range = y_max - y_min + margin = 0.05 + super().__init__( + xlim=(x_min - margin * x_range, x_max + margin * x_range), + ylim=(y_min - margin * y_range, y_max + margin * y_range), + ) + + +class XTicksClearer(Plotter): + """Plotter that hides the xticks.""" + + def __call__(self, ax: plt.Axes) -> None: + ax.tick_params(bottom=False, labelbottom=False) + + +class YTicksClearer(Plotter): + """Plotter that hides the yticks.""" + + def __call__(self, ax: plt.Axes) -> None: + ax.tick_params(left=False, labelleft=False) + + +class XAxisLabeller(Plotter): + """Plotter that labels the x-axis.""" + + def __init__(self, xlabel: str) -> None: + self.xlabel = xlabel + + def __call__(self, ax: plt.Axes) -> None: + ax.set_xlabel(self.xlabel) + + +class YAxisLabeller(Plotter): + """Plotter that labels the y-axis.""" + + def __init__(self, ylabel: str) -> None: + self.ylabel = ylabel + + def __call__(self, ax: plt.Axes) -> None: + ax.set_ylabel(self.ylabel) + + +class TitleSetter(Plotter): + """Plotter that sets the title.""" + + def __init__(self, title: str) -> None: + self.title = title + + def __call__(self, ax: plt.Axes) -> None: + ax.set_title(self.title) + + +class SquareBoxAspectSetter(Plotter): + """Plotter that sets a square box aspect.""" + + def __call__(self, ax: plt.Axes) -> None: + ax.set_box_aspect(1) diff --git a/tests/trajectories/_plotting_utils.py b/tests/trajectories/_plotting_utils.py new file mode 100644 index 00000000..f326a485 --- /dev/null +++ b/tests/trajectories/_plotting_utils.py @@ -0,0 +1,83 @@ +"""Utility functions for plotting.""" + + +def map_orders_to_indices( + aggregator_keys: list[str], aggregator_order: dict[str, int] +) -> dict[str, int]: + """Map aggregator keys to their indices based on sorted order. + + This function takes the available aggregators and maps them to sequential indices + (0, 1, 2, ...) based on their order values. This ensures that subplots are positioned + correctly regardless of which aggregators are actually present. + + :param aggregator_keys: List of aggregator keys to map. + :param aggregator_order: Dictionary mapping aggregator keys to their order values. + + Example: if ``aggregator_keys = ["mean", "dualproj", "aligned_mtl"]`` with orders + ``[0, 2, 8]``, this returns ``{"mean": 0, "dualproj": 1, "aligned_mtl": 2}``. + """ + sorted_keys = sorted(aggregator_keys, key=lambda k: aggregator_order[k]) + return {key: idx for idx, key in enumerate(sorted_keys)} + + +def compute_subplot_layout(n_aggregators: int) -> tuple[int, int]: + """Compute subplot layout (n_rows, n_cols) based on number of aggregators. + + :param n_aggregators: Number of aggregators to plot. + :raises ValueError: If n_aggregators is not between 1 and 10. + """ + if n_aggregators <= 5: + return 1, n_aggregators + if n_aggregators == 6: + return 2, 3 + if n_aggregators == 7 or n_aggregators == 8: + return 2, 4 + if n_aggregators == 9 or n_aggregators == 10: + return 2, 5 + raise ValueError(f"Unsupported number of aggregators: {n_aggregators}") + + +def get_subplot_position( + order: int, n_aggregators: int, n_rows: int, n_cols: int +) -> tuple[int, int]: + """Convert order index to (row, col) position. + + :param order: The order index of the aggregator. + :param n_aggregators: Total number of aggregators. + :param n_rows: Number of rows in the subplot grid. + :param n_cols: Number of columns in the subplot grid. + """ + if n_rows == 1: + return 0, order + # For 2 rows: n_cols equals both n_aggregators//2 (even split) and (n_aggregators+1)//2 (odd). + if n_aggregators in [6, 8, 10]: + return order // n_cols, order % n_cols + if n_aggregators in [7, 9]: + if order < n_cols: + return 0, order + return 1, order - n_cols + raise ValueError(f"Unsupported combination of n_aggregators={n_aggregators}, n_rows={n_rows}") + + +def get_unused_subplot_positions( + n_aggregators: int, n_rows: int, n_cols: int +) -> list[tuple[int, int]]: + """Get list of unused subplot positions. + + For layouts where not all subplot positions are used (e.g., 7 aggregators in a 2x4 grid), + this function returns the positions that should be hidden. + + :param n_aggregators: Total number of aggregators. + :param n_rows: Number of rows in the subplot grid. + :param n_cols: Number of columns in the subplot grid. + """ + if n_rows == 1: + return [] + + used_positions = set() + for idx in range(n_aggregators): + pos = get_subplot_position(idx, n_aggregators, n_rows, n_cols) + used_positions.add(pos) + + all_positions = [(i, j) for i in range(n_rows) for j in range(n_cols)] + return [pos for pos in all_positions if pos not in used_positions] diff --git a/tests/trajectories/optimize.py b/tests/trajectories/optimize.py new file mode 100644 index 00000000..8f0b9f34 --- /dev/null +++ b/tests/trajectories/optimize.py @@ -0,0 +1,129 @@ +""" +Optimize the objective using various aggregators. Save the trajectories in the parameter and value +spaces. + +Usage: + uv run python tests/trajectories/optimize.py ... + +Arguments: + The key of the objective function (e.g., EWQ, CQF, HQF, MN2, MN20). + ... The keys of the aggregators to use (e.g., upgrad, mean, mgda). +""" + +import argparse +import json +import random +import warnings + +import numpy as np +import torch + +from torchjd.aggregation import Stateful +from trajectories._constants import ( + AGGREGATORS, + BASE_LEARNING_RATES, + INITIAL_POINTS, + LR_MULTIPLIER_OVERRIDES, + LR_MULTIPLIERS, + N_ITERS, + OBJECTIVES, +) +from trajectories._optimization import optimize +from trajectories._paths import RESULTS_DIR, get_params_dir, get_values_dir + +warnings.filterwarnings("ignore") + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "Optimize the objective using various aggregators. Save the trajectories in the" + " parameter and value spaces." + ) + ) + parser.add_argument( + "objective", + help=f"Key of the objective function. Choices: {list(OBJECTIVES)}", + ) + parser.add_argument( + "aggregators", + nargs="+", + metavar="aggregator", + help=f"Keys of the aggregators to use. Choices: {list(AGGREGATORS)}", + ) + args = parser.parse_args() + + objective_key = args.objective + if objective_key not in OBJECTIVES: + raise ValueError(f"Unknown objective key: {objective_key}") + + aggregator_keys = args.aggregators + for aggregator_key in aggregator_keys: + if aggregator_key not in AGGREGATORS: + raise ValueError(f"Unknown aggregator key: {aggregator_key}") + + objective = OBJECTIVES[objective_key] + initial_points = INITIAL_POINTS[objective_key] + + learning_rates = {} + lr_multiplier_overrides = LR_MULTIPLIER_OVERRIDES.get(objective_key, {}) + base_lr = BASE_LEARNING_RATES[objective_key] + for key in LR_MULTIPLIERS: + mult = lr_multiplier_overrides.get(key, LR_MULTIPLIERS[key]) + learning_rates[key] = mult * base_lr + n_iters = N_ITERS[objective_key] + + torch.use_deterministic_algorithms(True) + + params_dir = get_params_dir(objective_key) + values_dir = get_values_dir(objective_key) + params_dir.mkdir(exist_ok=True, parents=True) + values_dir.mkdir(exist_ok=True, parents=True) + + metadata = { + "objective_key": objective_key, + "objective_repr": repr(objective), + "aggregator_keys": aggregator_keys, + "aggregator_reprs": {key: repr(AGGREGATORS[key]) for key in aggregator_keys}, + "learning_rates": learning_rates, + "initial_points": initial_points, + } + with open(RESULTS_DIR / objective_key / "metadata.json", "w") as f: + json.dump(metadata, f) + + for aggregator_key in aggregator_keys: + aggregator = AGGREGATORS[aggregator_key] + lr = learning_rates[aggregator_key] + print(aggregator) + xs_list = [] + ys_list = [] + for initial_point in initial_points: + print(initial_point) + + if isinstance(aggregator, Stateful): + aggregator.reset() + _reset_seed() + + initial_x = torch.tensor(initial_point) + xs, ys = optimize( + objective, initial_x=initial_x, aggregator=aggregator, lr=lr, n_iters=n_iters + ) + + xs_list.append(torch.stack(xs)) + ys_list.append(torch.stack(ys)) + + X = torch.stack(xs_list).numpy() + Y = torch.stack(ys_list).numpy() + np.save(params_dir / f"{aggregator_key}.npy", X) + np.save(values_dir / f"{aggregator_key}.npy", Y) + print() + + +def _reset_seed() -> None: + torch.manual_seed(0) + np.random.seed(0) + random.seed(0) + + +if __name__ == "__main__": + main() diff --git a/tests/trajectories/plot_distance_to_pf.py b/tests/trajectories/plot_distance_to_pf.py new file mode 100644 index 00000000..01c060ec --- /dev/null +++ b/tests/trajectories/plot_distance_to_pf.py @@ -0,0 +1,116 @@ +""" +Plot the evolution of the distance to the Pareto front of an objective function over time. + +Usage: + uv run python tests/trajectories/plot_distance_to_pf.py + +Arguments: + The key of the objective function (e.g., EWQ, CQF, HQF, MN2, MN20). +""" + +import argparse +import json + +import matplotlib.pyplot as plt +import numpy as np +import torch + +from trajectories._constants import AGGREGATOR_ORDER, AGGREGATORS, LATEX_NAMES, OBJECTIVES +from trajectories._pareto_utils import make_2d_pf_distance_fn +from trajectories._paths import RESULTS_DIR, get_distance_to_pf_plots_dir, get_values_dir +from trajectories._plotters import ( + MultiEvolutionPlotter, + SquareBoxAspectSetter, + TitleSetter, + XAxisLabeller, + XTicksClearer, + YAxisLabeller, + YTicksClearer, +) +from trajectories._plotting_utils import ( + compute_subplot_layout, + get_subplot_position, + get_unused_subplot_positions, + map_orders_to_indices, +) + + +def main() -> None: + print("Plotting distance to Pareto front...") + + parser = argparse.ArgumentParser( + description=( + "Plot the evolution of the distance to the Pareto front of an objective function" + " over time." + ) + ) + parser.add_argument( + "objective", + help=f"Key of the objective function. Choices: {list(OBJECTIVES)}", + ) + args = parser.parse_args() + objective_key = args.objective + + with open(RESULTS_DIR / objective_key / "metadata.json") as f: + metadata = json.load(f) + + values_dir = get_values_dir(objective_key) + distance_to_pf_plots_dir = get_distance_to_pf_plots_dir(objective_key) + distance_to_pf_plots_dir.mkdir(parents=True, exist_ok=True) + + # This seems to be the only way to make the font be Type1, which is the only font type supported + # by ICML. + plt.rcParams.update({"text.usetex": True}) + objective_key = metadata["objective_key"] + objective = OBJECTIVES[objective_key] + + common_plotter = SquareBoxAspectSetter() + + aggregator_keys = metadata["aggregator_keys"] + aggregator_to_Y = {key: np.load(values_dir / f"{key}.npy") for key in aggregator_keys} + + n_aggregators = len(aggregator_keys) + n_rows, n_cols = compute_subplot_layout(n_aggregators) + key_to_index = map_orders_to_indices(aggregator_keys, AGGREGATOR_ORDER) + fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2.5), sharey="all") + if n_rows == n_cols == 1: + axes = np.array([[axes]]) + elif n_rows == 1: + axes = axes.reshape(1, -1) + + unused_positions = get_unused_subplot_positions(n_aggregators, n_rows, n_cols) + for i, j in unused_positions: + axes[i][j].axis("off") + + save_path = distance_to_pf_plots_dir / "all.pdf" + pf_distance_fn = make_2d_pf_distance_fn(objective) + + for aggregator_key, Y in aggregator_to_Y.items(): + aggregator = AGGREGATORS[aggregator_key] + print(aggregator) + + # Y has shape [n_initial_points, n_iter, n_values] + pfd = torch.vmap(torch.vmap(pf_distance_fn))(torch.from_numpy(Y).to(dtype=torch.float64)) + + index = key_to_index[aggregator_key] + i, j = get_subplot_position(index, n_aggregators, n_rows, n_cols) + + plotter = ( + common_plotter + + MultiEvolutionPlotter(pfd.numpy()) + + TitleSetter(LATEX_NAMES[aggregator_key]) + ) + plotter += XAxisLabeller("Iteration") if i == n_rows - 1 else XTicksClearer() + plotter += YAxisLabeller("Distance to Pareto front") if j == 0 else YTicksClearer() + + plotter(axes[i][j]) + + fig.tight_layout(h_pad=-2.5) + + print("Saving figure") + plt.savefig(save_path, bbox_inches="tight") + print() + + +if __name__ == "__main__": + main() diff --git a/tests/trajectories/plot_params.py b/tests/trajectories/plot_params.py new file mode 100644 index 00000000..190d8dbe --- /dev/null +++ b/tests/trajectories/plot_params.py @@ -0,0 +1,157 @@ +""" +Plot the trajectories of an objective function in the parameter space. + +Usage: + uv run python tests/trajectories/plot_params.py + +Arguments: + The key of the objective function (e.g., EWQ, CQF, HQF, MN2, MN20). +""" + +import argparse +import json + +import matplotlib.pyplot as plt +import numpy as np + +from trajectories._constants import ( + AGGREGATOR_ORDER, + AGGREGATORS, + INITIAL_POINTS, + LATEX_NAMES, + OBJECTIVES, +) +from trajectories._objectives import ElementWiseQuadratic, WithSPSMappingMixin +from trajectories._optimization import compute_gradient_cosine_similarities +from trajectories._pareto_utils import sample_2d_spss +from trajectories._paths import RESULTS_DIR, get_param_plots_dir, get_params_dir +from trajectories._plotters import ( + AxesPlotter, + ContentLimAdjuster, + ContourCirclesPlotter, + HeatmapPlotter, + LimAdjuster, + MultiTrajPlotter, + SPSPlotter, + SquareBoxAspectSetter, + TitleSetter, + XAxisLabeller, + XTicksClearer, + YAxisLabeller, + YTicksClearer, +) +from trajectories._plotting_utils import ( + compute_subplot_layout, + get_subplot_position, + get_unused_subplot_positions, + map_orders_to_indices, +) + + +def main() -> None: + print("Plotting in parameter space...") + + parser = argparse.ArgumentParser( + description="Plot the trajectories of an objective function in the parameter space." + ) + parser.add_argument( + "objective", + help=f"Key of the objective function. Choices: {list(OBJECTIVES)}", + ) + args = parser.parse_args() + objective_key = args.objective + + with open(RESULTS_DIR / objective_key / "metadata.json") as f: + metadata = json.load(f) + + params_dir = get_params_dir(objective_key) + param_plots_dir = get_param_plots_dir(objective_key) + param_plots_dir.mkdir(parents=True, exist_ok=True) + + # This seems to be the only way to make the font be Type1, which is the only font type supported + # by ICML. + plt.rcParams.update({"text.usetex": True}) + objective_key = metadata["objective_key"] + objective = OBJECTIVES[objective_key] + + if objective.n_params != 2: + raise ValueError("Can only plot param trajectories for objectives with 2 params.") + + initial_points = INITIAL_POINTS[objective_key] + initial_points_array = np.stack([np.array(point) for point in initial_points]) + main_content = initial_points_array + + common_plotter = SquareBoxAspectSetter() + + if objective.n_values == 2 and isinstance(objective, WithSPSMappingMixin): + sps_points = sample_2d_spss(objective).numpy() + main_content = np.concatenate([main_content, sps_points]) + common_plotter += SPSPlotter(sps_points) + + if isinstance(objective, ElementWiseQuadratic): + common_plotter += AxesPlotter() + common_plotter += ContourCirclesPlotter() + common_plotter += LimAdjuster(xlim=(-5.0, 5.0), ylim=(-5.0, 5.0)) + else: + adjust_plotter = ContentLimAdjuster(main_content) + common_plotter += adjust_plotter + + if objective.n_values == 2: + similarities = compute_gradient_cosine_similarities( + objective, + x0_min=adjust_plotter.xlim[0], + x0_max=adjust_plotter.xlim[1], + x1_min=adjust_plotter.ylim[0], + x1_max=adjust_plotter.ylim[1], + n=200, + ) + common_plotter += HeatmapPlotter( + values=similarities.numpy() ** 3, + x_min=adjust_plotter.xlim[0], + x_max=adjust_plotter.xlim[1], + y_min=adjust_plotter.ylim[0], + y_max=adjust_plotter.ylim[1], + vmin=-1, + vmax=1, + cmap="PiYG", + ) + + aggregator_keys = metadata["aggregator_keys"] + aggregator_to_X = {key: np.load(params_dir / f"{key}.npy") for key in aggregator_keys} + + n_aggregators = len(aggregator_keys) + n_rows, n_cols = compute_subplot_layout(n_aggregators) + key_to_index = map_orders_to_indices(aggregator_keys, AGGREGATOR_ORDER) + fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2.5)) + if n_rows == n_cols == 1: + axes = np.array([[axes]]) + elif n_rows == 1: + axes = axes.reshape(1, -1) + + unused_positions = get_unused_subplot_positions(n_aggregators, n_rows, n_cols) + for i, j in unused_positions: + axes[i][j].axis("off") + + save_path = param_plots_dir / "all.pdf" + + for aggregator_key, X in aggregator_to_X.items(): + aggregator = AGGREGATORS[aggregator_key] + print(aggregator) + + index = key_to_index[aggregator_key] + i, j = get_subplot_position(index, n_aggregators, n_rows, n_cols) + + plotter = common_plotter + MultiTrajPlotter(X) + TitleSetter(LATEX_NAMES[aggregator_key]) + plotter += XAxisLabeller("$x_1$") if i == n_rows - 1 else XTicksClearer() + plotter += YAxisLabeller("$x_2$") if j == 0 else YTicksClearer() + + plotter(axes[i][j]) + + fig.tight_layout(h_pad=-2.5) + print("Saving figure") + plt.savefig(save_path, bbox_inches="tight") + print() + + +if __name__ == "__main__": + main() diff --git a/tests/trajectories/plot_values.py b/tests/trajectories/plot_values.py new file mode 100644 index 00000000..22b5f8a3 --- /dev/null +++ b/tests/trajectories/plot_values.py @@ -0,0 +1,155 @@ +""" +Plot the trajectories of an objective function in the value space. + +Usage: + uv run python tests/trajectories/plot_values.py + +Arguments: + The key of the objective function (e.g., EWQ, CQF, HQF, MN2, MN20). +""" + +import argparse +import json + +import matplotlib.pyplot as plt +import numpy as np + +from trajectories._constants import ( + AGGREGATOR_ORDER, + AGGREGATORS, + LATEX_NAMES, + OBJECTIVES, + PLOT_VALUES_LIMS, +) +from trajectories._objectives import WithSPSMappingMixin +from trajectories._pareto_utils import compute_normalized_2d_pf_distances, sample_2d_pf +from trajectories._paths import RESULTS_DIR, get_value_plots_dir, get_values_dir +from trajectories._plotters import ( + ContentLimAdjuster, + HeatmapPlotter, + LimAdjuster, + MultiTrajPlotter, + PFPlotter, + SquareBoxAspectSetter, + TitleSetter, + XAxisLabeller, + XTicksClearer, + YAxisLabeller, + YTicksClearer, +) +from trajectories._plotting_utils import ( + compute_subplot_layout, + get_subplot_position, + get_unused_subplot_positions, + map_orders_to_indices, +) + + +def main() -> None: + print("Plotting in value space...") + + parser = argparse.ArgumentParser( + description="Plot the trajectories of an objective function in the value space." + ) + parser.add_argument( + "objective", + help=f"Key of the objective function. Choices: {list(OBJECTIVES)}", + ) + args = parser.parse_args() + objective_key = args.objective + + with open(RESULTS_DIR / objective_key / "metadata.json") as f: + metadata = json.load(f) + + values_dir = get_values_dir(objective_key) + value_plots_dir = get_value_plots_dir(objective_key) + value_plots_dir.mkdir(parents=True, exist_ok=True) + + # This seems to be the only way to make the font be Type1, which is the only font type supported + # by ICML. + plt.rcParams.update({"text.usetex": True}) + objective_key = metadata["objective_key"] + objective = OBJECTIVES[objective_key] + + if objective.n_values != 2: + raise ValueError("Can only plot values trajectories for objectives with 2 values.") + + common_plotter = SquareBoxAspectSetter() + aggregator_keys = metadata["aggregator_keys"] + aggregator_to_Y = {key: np.load(values_dir / f"{key}.npy") for key in aggregator_keys} + first_agg_Y = next(iter(aggregator_to_Y.values())) + initial_values = first_agg_Y[:, 0, :] + main_content = initial_values + + if isinstance(objective, WithSPSMappingMixin): + pf_points_array = sample_2d_pf(objective).numpy() + common_plotter += PFPlotter(pf_points_array) + main_content = np.concatenate([main_content, pf_points_array]) + + if objective_key in PLOT_VALUES_LIMS: + lims = PLOT_VALUES_LIMS[objective_key] + xlim = lims["xlim"] + ylim = lims["ylim"] + common_plotter += LimAdjuster(xlim=xlim, ylim=ylim) + else: + adjust_plotter = ContentLimAdjuster(main_content) + common_plotter += adjust_plotter + xlim = adjust_plotter.xlim + ylim = adjust_plotter.ylim + + if isinstance(objective, WithSPSMappingMixin): + distances = compute_normalized_2d_pf_distances( + objective, + y0_min=xlim[0], + y0_max=xlim[1], + y1_min=ylim[0], + y1_max=ylim[1], + n=200, + ) + common_plotter += HeatmapPlotter( + values=distances.numpy(), + x_min=xlim[0], + x_max=xlim[1], + y_min=ylim[0], + y_max=ylim[1], + vmin=0, + vmax=1, + cmap="Reds", + ) + + n_aggregators = len(aggregator_keys) + n_rows, n_cols = compute_subplot_layout(n_aggregators) + key_to_index = map_orders_to_indices(aggregator_keys, AGGREGATOR_ORDER) + fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2.5)) + if n_rows == n_cols == 1: + axes = np.array([[axes]]) + elif n_rows == 1: + axes = axes.reshape(1, -1) + + unused_positions = get_unused_subplot_positions(n_aggregators, n_rows, n_cols) + for i, j in unused_positions: + axes[i][j].axis("off") + + save_path = value_plots_dir / "all.pdf" + + for aggregator_key, Y in aggregator_to_Y.items(): + aggregator = AGGREGATORS[aggregator_key] + print(aggregator) + + index = key_to_index[aggregator_key] + i, j = get_subplot_position(index, n_aggregators, n_rows, n_cols) + + plotter = common_plotter + MultiTrajPlotter(Y) + TitleSetter(LATEX_NAMES[aggregator_key]) + plotter += XAxisLabeller("Objective $1$") if i == n_rows - 1 else XTicksClearer() + plotter += YAxisLabeller("Objective $2$") if j == 0 else YTicksClearer() + + plotter(axes[i][j]) + + fig.tight_layout(h_pad=-2.5) + print("Saving figure") + plt.savefig(save_path, bbox_inches="tight") + print() + + +if __name__ == "__main__": + main()