-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy path yx.py
More file actions
160 lines (139 loc) · 5 KB
/
yx.py
File metadata and controls
160 lines (139 loc) · 5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
Standalone function for plotting 1D y-vs-x data.
Replaces ``MatPlot1D.plot_yx`` / ``MatWrap`` system.
"""
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np

from autoarray.plot.utils import (
    _conf_ticks,
    _inward_ticks,
    _round_ticks,
    apply_labels,
    conf_figsize,
    save_figure,
)
def _to_numpy(value):
    """Unwrap ``value`` into something matplotlib can plot.

    ``None`` is returned unchanged. Objects exposing an ``array`` attribute
    (presumably autoarray structures) are unwrapped via that attribute;
    anything else is converted with ``np.asarray``.
    """
    if value is None:
        return None
    if hasattr(value, "array"):
        return value.array
    return np.asarray(value)


def _apply_axis_scale(ax: plt.Axes, plot_axis_type: str) -> None:
    """Set the axis scales implied by ``plot_axis_type`` on ``ax``.

    ``"log"`` / ``"semilogy"`` make the y-axis logarithmic, ``"loglog"``
    makes both axes logarithmic and ``"symlog"`` makes the y-axis
    symmetric-log. Any other value leaves the current scales untouched.
    """
    if plot_axis_type in ("log", "semilogy", "loglog"):
        ax.set_yscale("log")
    if plot_axis_type == "loglog":
        ax.set_xscale("log")
    elif plot_axis_type == "symlog":
        ax.set_yscale("symlog")


def _apply_three_point_ticks(
    ax: plt.Axes, xtick_suffix: str = "", ytick_suffix: str = ""
) -> None:
    """Replace the ticks of ``ax`` with three ticks per axis plus unit suffixes.

    Tick positions come from ``_inward_ticks`` over the current axis limits,
    with the ``extent_factor_2d`` tick-config value (0.75 passed as fallback)
    as its factor, and are then passed through ``_round_ticks``. Each label
    is the tick value formatted with ``:g`` followed by the axis suffix.
    """
    # NOTE(review): on log-scaled axes these ticks may be poorly placed if
    # _inward_ticks spaces them linearly -- confirm against its implementation.
    factor = _conf_ticks("extent_factor_2d", 0.75)
    # Both limits are read before any ticks are set.
    xlo, xhi = ax.get_xlim()
    ylo, yhi = ax.get_ylim()
    xticks = _round_ticks(_inward_ticks(xlo, xhi, factor, 3))
    yticks = _round_ticks(_inward_ticks(ylo, yhi, factor, 3))
    ax.set_xticks(xticks)
    ax.set_xticklabels([f"{v:g}{xtick_suffix}" for v in xticks])
    ax.set_yticks(yticks)
    ax.set_yticklabels([f"{v:g}{ytick_suffix}" for v in yticks])


def plot_yx(
    y,
    x=None,
    ax: Optional[plt.Axes] = None,
    # --- errors / extras --------------------------------------------------------
    y_errors: Optional[np.ndarray] = None,
    x_errors: Optional[np.ndarray] = None,
    y_extra: Optional[np.ndarray] = None,
    shaded_region: Optional[Tuple[np.ndarray, np.ndarray]] = None,
    # --- cosmetics --------------------------------------------------------------
    title: str = "",
    xlabel: str = "",
    ylabel: str = "",
    xtick_suffix: str = "",
    ytick_suffix: str = "",
    label: Optional[str] = None,
    color: str = "k",
    linestyle: str = "-",
    plot_axis_type: str = "linear",
    # --- figure control (used only when ax is None) -----------------------------
    figsize: Optional[Tuple[int, int]] = None,
    output_path: Optional[str] = None,
    output_filename: str = "yx",
    output_format: str = "png",
) -> None:
    """
    Plot 1D y versus x data.

    Replaces ``MatPlot1D.plot_yx`` with direct matplotlib calls. Nothing is
    drawn (and no figure is created) when ``y`` is ``None``, empty, or
    entirely NaN.

    Parameters
    ----------
    y
        1D array-like of y values. Objects with an ``array`` attribute are
        unwrapped via it; if ``x`` is ``None`` and ``y`` has a
        ``grid_radial`` attribute, that is used as ``x``.
    x
        1D array-like of x values. When ``None`` (and ``y`` has no
        ``grid_radial``) integer indices are used.
    ax
        Existing ``Axes`` to draw onto. ``None`` creates a new figure, which
        is then passed to ``save_figure``.
    y_errors, x_errors
        Per-point error values; either one switches the main series to
        ``Axes.errorbar`` (line with small markers).
    y_extra
        Optional second y series overlaid as a red dashed line.
    shaded_region
        Tuple ``(y1, y2)`` of arrays; the region between them is filled with
        alpha 0.3.
    title
        Figure title.
    xlabel, ylabel
        Axis labels.
    xtick_suffix, ytick_suffix
        Strings appended to every x / y tick label (e.g. units).
    label
        Legend label for the main series; a legend is drawn only when set.
    color
        Line / marker colour.
    linestyle
        Line style string (ignored for ``"scatter"`` and error-bar plots).
    plot_axis_type
        One of ``"linear"``, ``"scatter"``, ``"log"`` / ``"semilogy"``
        (log y-axis), ``"loglog"`` (log x and y) or ``"symlog"``
        (symmetric-log y-axis). Axis scaling is applied for error-bar plots
        as well. ``"scatter"`` only applies when no errors are given.
    figsize
        Figure size in inches; defaults to ``conf_figsize("figures")``.
    output_path
        Directory for saving. Empty / ``None`` calls ``plt.show()``.
    output_filename
        Base file name without extension.
    output_format
        File format, e.g. ``"png"``.
    """
    # Check for None before any conversion: np.asarray(None) produces a 0-d
    # object array, on which len() raises instead of returning early.
    if y is None:
        return

    # --- autoarray extraction --------------------------------------------------
    # Objects exposing grid_radial supply their own x-coordinates.
    if x is None and hasattr(y, "grid_radial"):
        x = y.grid_radial
    y = _to_numpy(y)
    x = _to_numpy(x)

    # guard: nothing to draw
    if len(y) == 0 or np.isnan(y).all():
        return

    owns_figure = ax is None
    if owns_figure:
        figsize = figsize or conf_figsize("figures")
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        fig = ax.get_figure()

    if x is None:
        x = np.arange(len(y))

    # --- main line / scatter ---------------------------------------------------
    if y_errors is not None or x_errors is not None:
        ax.errorbar(
            x,
            y,
            yerr=y_errors,
            xerr=x_errors,
            fmt="-o",
            color=color,
            label=label,
            markersize=3,
        )
    elif plot_axis_type == "scatter":
        ax.scatter(x, y, s=2, c=color, label=label)
    else:
        ax.plot(x, y, color=color, linestyle=linestyle, label=label)
    # Scales are set on every branch so error-bar plots honour log axes too.
    _apply_axis_scale(ax, plot_axis_type)

    # --- extras ----------------------------------------------------------------
    if y_extra is not None:
        ax.plot(x, y_extra, color="r", linestyle="--", alpha=0.7)
    if shaded_region is not None:
        y1, y2 = shaded_region
        ax.fill_between(x, y1, y2, alpha=0.3)

    # --- labels ----------------------------------------------------------------
    apply_labels(
        ax, title=title, xlabel=xlabel, ylabel=ylabel, is_subplot=not owns_figure
    )

    # --- 3-point ticks with optional unit suffixes ----------------------------
    _apply_three_point_ticks(ax, xtick_suffix, ytick_suffix)

    if label is not None:
        ax.legend(fontsize=12)

    # --- output ----------------------------------------------------------------
    # Only a figure created here is saved/shown; a caller-supplied ax is left
    # for the caller to finish.
    if owns_figure:
        save_figure(
            fig,
            path=output_path or "",
            filename=output_filename,
            format=output_format,
        )