Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion autofit/config/general.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
updates:
iterations_per_quick_update: 1e99 # Non-linear search iterations between every quick update, which just displays the maximum likelihood model fit.
iterations_per_full_update: 1e99 # Non-linear search iterations between every full update, which outputs all visuals and result fits (e.g. model.result, search.summary), this exits the search and can be slow.
iterations_per_full_update: 1e99 # Non-linear search iterations between every full update, which outputs all visuals and result fits (e.g. model.result, search.summary), this exits the search and can be slow.
quick_update_background: false # If True, quick-update visualization runs on a background thread so sampling is not blocked.
hpc:
hpc_mode: false # If True, use HPC mode, which disables GUI visualization, logging to screen and other settings which are not suited to running on a super computer.
iterations_per_quick_update: 1e99 # Non-linear search iterations between every quick update, which just displays the maximum likelihood model fit.
Expand Down
10 changes: 10 additions & 0 deletions autofit/non_linear/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,16 @@ def make_result(
analysis=analysis,
)

@property
def supports_background_update(self) -> bool:
"""Whether this analysis supports background quick updates."""
return False

@property
def supports_jax_visualization(self) -> bool:
"""Whether the visualizer can work directly with JAX arrays."""
return False

def perform_quick_update(self, paths, instance):
raise NotImplementedError

Expand Down
34 changes: 30 additions & 4 deletions autofit/non_linear/fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
use_jax_vmap : bool = False,
batch_size : Optional[int] = None,
iterations_per_quick_update: Optional[int] = None,
background_quick_update: bool = False,
):
"""
Interfaces with any non-linear search to fit the model to the data and return a log likelihood via
Expand Down Expand Up @@ -129,6 +130,20 @@ def __init__(
self.quick_update_max_lh = -self._xp.inf
self.quick_update_count = 0

self._background_quick_update = None

if background_quick_update and self.iterations_per_quick_update is not None:
from autofit.non_linear.quick_update import BackgroundQuickUpdate

convert_jax = (
getattr(self.analysis, "_use_jax", False)
and not getattr(self.analysis, "supports_jax_visualization", False)
)

self._background_quick_update = BackgroundQuickUpdate(
convert_jax=convert_jax,
)

if self.paths is not None:
self.check_log_likelihood(fitness=self)

Expand Down Expand Up @@ -314,10 +329,15 @@ def manage_quick_update(self, parameters, log_likelihood):

instance = self.model.instance_from_vector(vector=self.quick_update_max_lh_parameters, xp=self._xp)

try:
self.analysis.perform_quick_update(self.paths, instance)
except NotImplementedError:
pass
if self._background_quick_update is not None:
self._background_quick_update.submit(
self.analysis, self.paths, instance,
)
else:
try:
self.analysis.perform_quick_update(self.paths, instance)
except NotImplementedError:
pass

result_info = text_util.result_max_lh_info_from(
max_log_likelihood_sample=self.quick_update_max_lh_parameters.tolist(),
Expand All @@ -333,6 +353,12 @@ def manage_quick_update(self, parameters, log_likelihood):

logger.info(f"Quick update complete in {time.time() - start_time} seconds.")

def shutdown_quick_update(self):
"""Shut down the background quick-update worker, if one is running."""
if self._background_quick_update is not None:
self._background_quick_update.shutdown()
self._background_quick_update = None

@timeout(timeout_seconds)
def __call__(self, parameters, *kwargs):
"""
Expand Down
107 changes: 107 additions & 0 deletions autofit/non_linear/quick_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import copy
import logging
import threading

import numpy as np

logger = logging.getLogger(__name__)


def _convert_jax_to_numpy(instance):
"""
Return a deep copy of *instance* with every JAX array replaced by a
NumPy array. Plain NumPy values and non-array attributes are left
unchanged.

This is used so that the background visualisation thread never
touches JAX / GPU state, which is not thread-safe.
"""
instance = copy.deepcopy(instance)

for attr in vars(instance):
value = getattr(instance, attr)
if hasattr(value, "device"):
setattr(instance, attr, np.asarray(value))

return instance


class BackgroundQuickUpdate:
"""
Runs ``analysis.perform_quick_update`` on a background daemon thread so
that the sampler is not blocked while matplotlib renders and saves plots.

Uses a **latest-only** pattern: if a new best-fit arrives before the
previous visualisation finishes, the stale request is silently replaced.

Parameters
----------
convert_jax
If ``True``, JAX arrays on the model instance are converted to
NumPy before handing them to the worker thread.
"""

def __init__(self, convert_jax: bool = False):
self._convert_jax = convert_jax

self._lock = threading.Lock()
self._pending = None
self._has_work = threading.Event()
self._stop = threading.Event()

self._thread = threading.Thread(
target=self._worker,
daemon=True,
name="quick-update-worker",
)
self._thread.start()

def submit(self, analysis, paths, instance):
"""
Enqueue a quick-update request. If a previous request is still
pending (not yet picked up by the worker), it is replaced.
"""

if self._convert_jax:
instance = _convert_jax_to_numpy(instance)

with self._lock:
self._pending = (analysis, paths, instance)

self._has_work.set()

def shutdown(self, timeout: float = 10.0):
"""Signal the worker to stop after draining pending work."""
self._stop.set()
self._has_work.set()
self._thread.join(timeout=timeout)

def _process_pending(self):
with self._lock:
work = self._pending
self._pending = None

if work is None:
return

analysis, paths, instance = work

try:
analysis.perform_quick_update(paths, instance)
except NotImplementedError:
pass
except Exception:
logger.exception(
"Background quick-update raised an exception (ignored)."
)

def _worker(self):
while True:
self._has_work.wait()
self._has_work.clear()

self._process_pending()

if self._stop.is_set():
self._process_pending()
break
9 changes: 9 additions & 0 deletions autofit/non_linear/search/abstract_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,12 @@ def __init__(
self.iterations_per_full_update = float((iterations_per_full_update or
conf.instance["general"]["updates"]["iterations_per_full_update"]))

self.quick_update_background = bool(
conf.instance["general"]["updates"].get(
"quick_update_background", False,
)
)

if conf.instance["general"]["hpc"]["hpc_mode"]:
self.iterations_per_quick_update = float(conf.instance["general"]["hpc"][
"iterations_per_quick_update"
Expand Down Expand Up @@ -664,6 +670,9 @@ def start_resume_fit(self, analysis: Analysis, model: AbstractPriorModel) -> Res
analysis=analysis,
)

if hasattr(fitness, "shutdown_quick_update"):
fitness.shutdown_quick_update()

samples = self.perform_update(
model=model,
analysis=analysis,
Expand Down
4 changes: 3 additions & 1 deletion autofit/non_linear/search/nest/nautilus/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def _fit(self, model: AbstractPriorModel, analysis):
fom_is_log_likelihood=True,
resample_figure_of_merit=-1.0e99,
iterations_per_quick_update=self.iterations_per_quick_update,
background_quick_update=self.quick_update_background,
use_jax_vmap=self.use_jax_vmap,
batch_size=self.n_batch,
)
Expand All @@ -210,7 +211,8 @@ def _fit(self, model: AbstractPriorModel, analysis):
paths=self.paths,
fom_is_log_likelihood=True,
resample_figure_of_merit=-1.0e99,
iterations_per_quick_update=self.iterations_per_quick_update
iterations_per_quick_update=self.iterations_per_quick_update,
background_quick_update=self.quick_update_background,
)

search_internal = self.fit_multiprocessing(
Expand Down
Loading
Loading