Skip to content

Commit d5763ce

Browse files
committed
Add interactive fitting example with Gaussian components and widgets.
1 parent 3198c31 commit d5763ce

6 files changed

Lines changed: 1305 additions & 55 deletions

File tree

Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
"""
2+
Interactive 1-D Gaussian Fitting
3+
=================================
4+
5+
A noisy composite signal built from two Gaussians is displayed. Two
6+
additional overlay lines show the individual **component** curves and a
7+
white **sum** curve that always equals the current manual model.
8+
9+
**Interaction**
10+
11+
Click any coloured component line to reveal its control widgets:
12+
13+
* **Circular handle** — drag to move the peak centre (μ) and amplitude (A).
14+
* **Shaded range** — drag either edge to widen or narrow the width (σ).
15+
16+
The sum curve updates on every drag frame.
17+
Press **f** (with the plot canvas focused) to run a least-squares fit.
18+
The components — and all active widgets — will snap to the fitted values,
19+
and the sum curve will jump to the optimal fit.
20+
Click a component line again to hide its widgets.
21+
"""
22+
import numpy as np
23+
from scipy.optimize import curve_fit
24+
import anyplotlib as apl
25+
26+
# ── Gaussian helpers ───────────────────────────────────────────────────────
27+
28+
def gaussian(x, amp, mu, sigma):
29+
return amp * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
30+
31+
# Half-width at half-maximum = sigma * _FWHM_K (full FWHM = 2 * sigma * _FWHM_K)
32+
_FWHM_K = np.sqrt(2.0 * np.log(2.0))
33+
34+
# ── Data ───────────────────────────────────────────────────────────────────
35+
36+
x = np.linspace(0, 10, 500)
37+
38+
TRUE_P = [
39+
dict(amp=1.0, mu=3.2, sigma=0.55),
40+
dict(amp=0.75, mu=6.8, sigma=0.80),
41+
]
42+
COLORS = ["#ff6b6b", "#69db7c"]
43+
44+
rng = np.random.default_rng(42)
45+
signal = sum(gaussian(x, **p) for p in TRUE_P) + rng.normal(0, 0.03, len(x))
46+
47+
# Initial component guesses (slightly off from truth)
48+
INIT_P = [
49+
dict(amp=1.0, mu=3.0, sigma=0.6),
50+
dict(amp=0.7, mu=7.0, sigma=0.9),
51+
]
52+
53+
# ── Figure ─────────────────────────────────────────────────────────────────
54+
55+
fig, ax = apl.subplots(1, 1, figsize=(720, 380))
56+
plot = ax.plot(signal, axes=[x], color="#adb5bd", linewidth=1.5,
57+
alpha=0.6, label="data")
58+
59+
comp_lines = [
60+
plot.add_line(gaussian(x, **p), x_axis=x,
61+
color=c, linewidth=2.0,
62+
label=f"comp {i+1}")
63+
for i, (p, c) in enumerate(zip(INIT_P, COLORS))
64+
]
65+
66+
# Live sum of all components — this IS the fit after pressing 'f'
67+
sum_line = plot.add_line(
68+
sum(gaussian(x, **p) for p in INIT_P), x_axis=x,
69+
color="#e0e0e0", linewidth=1.5, linestyle="dashed", label="sum",
70+
)
71+
72+
# ── GaussianComponent ──────────────────────────────────────────────────────
73+
74+
class GaussianComponent:
75+
"""Manages a PointWidget (peak) + RangeWidget (σ) for one component.
76+
77+
Assign ``.model`` after constructing the ``Model`` so the component
78+
can notify it on every drag frame.
79+
"""
80+
81+
def __init__(self, line, p, color):
82+
self.line = line
83+
self.amp = p["amp"]
84+
self.mu = p["mu"]
85+
self.sigma = p["sigma"]
86+
self.color = color
87+
self.model = None # injected after Model is constructed
88+
self._active = False
89+
self._syncing = False # guard against callback loops
90+
self._pt = None # PointWidget — created once on first toggle
91+
self._rng_w = None # RangeWidget
92+
93+
def component_y(self):
94+
return gaussian(x, self.amp, self.mu, self.sigma)
95+
96+
def toggle(self):
97+
if self._active:
98+
self._pt.hide()
99+
self._rng_w.hide()
100+
self._active = False
101+
else:
102+
if self._pt is None:
103+
self._pt = plot.add_point_widget(self.mu, self.amp,
104+
color=self.color,
105+
show_crosshair=False)
106+
self._rng_w = plot.add_range_widget(
107+
self.mu - self.sigma * _FWHM_K,
108+
self.mu + self.sigma * _FWHM_K,
109+
y=self.amp / 2.0,
110+
color=self.color,
111+
style="fwhm",
112+
)
113+
self._wire()
114+
else:
115+
self._pt.show()
116+
self._rng_w.show()
117+
self._active = True
118+
119+
def _wire(self):
120+
@self._pt.on_changed
121+
def _peak_moved(event):
122+
if self._syncing:
123+
return
124+
self._syncing = True
125+
try:
126+
self.amp = event.data["y"]
127+
self.mu = event.data["x"]
128+
self._rng_w.set(x0=self.mu - self.sigma * _FWHM_K,
129+
x1=self.mu + self.sigma * _FWHM_K,
130+
y=self.amp / 2.0)
131+
self.line.set_data(self.component_y())
132+
if self.model:
133+
self.model.update()
134+
finally:
135+
self._syncing = False
136+
137+
@self._rng_w.on_changed
138+
def _range_moved(event):
139+
if self._syncing:
140+
return
141+
self._syncing = True
142+
try:
143+
x0, x1 = event.data["x0"], event.data["x1"]
144+
self.mu = (x0 + x1) / 2.0
145+
self.sigma = abs(x1 - x0) / (2.0 * _FWHM_K)
146+
self._pt.set(x=self.mu)
147+
self.line.set_data(self.component_y())
148+
if self.model:
149+
self.model.update()
150+
finally:
151+
self._syncing = False
152+
153+
def snap(self, amp: float, mu: float, sigma: float) -> None:
154+
"""Update parameters and snap **all** widgets to the new values.
155+
156+
Creates and shows the point and FWHM range widgets if they do not
157+
exist yet (so pressing **f** always reveals the fitted widths), then
158+
updates their positions. Uses the ``_syncing`` guard so widget
159+
callbacks do not fire during the programmatic update.
160+
"""
161+
self._syncing = True
162+
try:
163+
self.amp = amp
164+
self.mu = mu
165+
self.sigma = sigma
166+
self.line.set_data(self.component_y())
167+
if self._pt is None:
168+
# First fit — create widgets at the fitted position and show them.
169+
self._pt = plot.add_point_widget(self.mu, self.amp,
170+
color=self.color,
171+
show_crosshair=False)
172+
self._rng_w = plot.add_range_widget(
173+
self.mu - self.sigma * _FWHM_K,
174+
self.mu + self.sigma * _FWHM_K,
175+
y=self.amp / 2.0,
176+
color=self.color,
177+
style="fwhm",
178+
)
179+
self._wire()
180+
self._active = True
181+
else:
182+
# Widgets already exist — move them to the new fitted position.
183+
self._pt.set(x=self.mu, y=self.amp)
184+
self._rng_w.set(x0=self.mu - self.sigma * _FWHM_K,
185+
x1=self.mu + self.sigma * _FWHM_K,
186+
y=self.amp / 2.0)
187+
# If the user had hidden the widgets, bring them back.
188+
if not self._active:
189+
self._pt.show()
190+
self._rng_w.show()
191+
self._active = True
192+
finally:
193+
self._syncing = False
194+
195+
# ── Model ──────────────────────────────────────────────────────────────────
196+
197+
class Model:
198+
"""A list of GaussianComponents with a live sum line.
199+
200+
``update()`` redraws the sum line from the current component state and
201+
is called on every drag frame.
202+
203+
``fit()`` runs a least-squares fit, snaps every component (and its
204+
widgets) to the optimal parameters, then calls ``update()`` so the sum
205+
line jumps to the best fit. It is also triggered by pressing **f**.
206+
207+
Parameters
208+
----------
209+
components : list[GaussianComponent]
210+
sum_line : Line1D
211+
Always-live manual-sum / fit-result overlay.
212+
x_data, y_data : ndarray
213+
Observed signal to fit against.
214+
"""
215+
216+
def __init__(self, components, sum_line, x_data, y_data):
217+
self.components = list(components)
218+
self.sum_line = sum_line
219+
self.x_data = x_data
220+
self.y_data = y_data
221+
222+
def update(self):
223+
"""Redraw the sum line as the manual sum of all components."""
224+
self.sum_line.set_data(
225+
sum(c.component_y() for c in self.components)
226+
)
227+
228+
def fit(self):
229+
"""Least-squares fit; snaps components to the result.
230+
231+
Builds a generic n-Gaussian model from the component list and uses
232+
their current state as the initial guess. On success every component
233+
snaps to the fitted (amp, μ, σ) and the sum redraws as the best fit.
234+
On failure the components are left unchanged.
235+
"""
236+
n = len(self.components)
237+
p0 = [v for c in self.components for v in (c.amp, c.mu, c.sigma)]
238+
lo = [v for c in self.components for v in (0, self.x_data[0], 1e-3)]
239+
hi = [v for c in self.components
240+
for v in (np.inf, self.x_data[-1],
241+
self.x_data[-1] - self.x_data[0])]
242+
243+
def _model_fn(x, *params):
244+
return sum(
245+
gaussian(x, params[3 * i], params[3 * i + 1], params[3 * i + 2])
246+
for i in range(n)
247+
)
248+
249+
try:
250+
popt, _ = curve_fit(
251+
_model_fn, self.x_data, self.y_data,
252+
p0=p0, bounds=(lo, hi), maxfev=3000 * n,
253+
)
254+
for i, comp in enumerate(self.components):
255+
comp.snap(popt[3 * i], popt[3 * i + 1], popt[3 * i + 2])
256+
self.update()
257+
except RuntimeError:
258+
pass # leave components unchanged if fit did not converge
259+
260+
# ── Assemble ───────────────────────────────────────────────────────────────
261+
262+
components = [
263+
GaussianComponent(comp_lines[i], INIT_P[i], COLORS[i])
264+
for i in range(2)
265+
]
266+
267+
model = Model(components, sum_line, x, signal)
268+
for comp in components:
269+
comp.model = model
270+
271+
# ── Key binding — press 'f' to fit ─────────────────────────────────────────
272+
273+
@plot.on_key('f')
274+
def _on_fit(event):
275+
model.fit()
276+
277+
# ── Click handlers — toggle widgets per component ─────────────────────────
278+
279+
for comp, line in zip(components, comp_lines):
280+
@line.on_click
281+
def _clicked(event, c=comp):
282+
c.toggle()
283+
284+
fig

0 commit comments

Comments
 (0)