Skip to content

Commit f27b456

Browse files
authored
Merge pull request #1207 from PyAutoLabs/feature/search-update-refactor
refactor: extract SearchUpdater class from NonLinearSearch
2 parents c0b5f17 + be1bb75 commit f27b456

2 files changed

Lines changed: 373 additions & 205 deletions

File tree

autofit/non_linear/search/abstract_search.py

Lines changed: 32 additions & 205 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,23 @@ def output_search_internal(self, search_internal):
927927
obj=search_internal,
928928
)
929929

930+
@property
931+
def _updater(self):
932+
if not hasattr(self, "_search_updater") or self._search_updater is None:
933+
from autofit.non_linear.search.updater import SearchUpdater
934+
935+
self._search_updater = SearchUpdater(
936+
paths=self.paths,
937+
timer=self.timer,
938+
search_logger=self.logger,
939+
plot_results_func=self.plot_results,
940+
samples_from_func=self.samples_from,
941+
should_profile=self.should_profile,
942+
disable_output=self.disable_output,
943+
iterations_per_full_update=self.iterations_per_full_update,
944+
)
945+
return self._search_updater
946+
930947
def perform_update(
931948
self,
932949
model: AbstractPriorModel,
@@ -938,153 +955,18 @@ def perform_update(
938955
"""
939956
Perform an update of the non-linear search's model-fitting results.
940957
941-
This occurs every `iterations_per_full_update` of the non-linear search and once it is complete.
942-
943-
The update performs the following tasks (if the settings indicate they should be performed):
944-
945-
1) Visualize the search results.
946-
2) Visualize the maximum log likelihood model using model-specific visualization implented via the `Analysis`
947-
object.
948-
3) Perform profiling of the analysis object `log_likelihood_function` and ouptut run-time information.
949-
4) Output the `search.summary` file which contains information on model-fitting so far.
950-
5) Output the `model.results` file which contains a concise text summary of the model results so far.
951-
952-
Parameters
953-
----------
954-
model
955-
The model which generates instances for different points in parameter space.
956-
analysis
957-
Contains the data and the log likelihood function which fits an instance of the model to the data, returning
958-
the log likelihood the `NonLinearSearch` maximizes.
959-
during_analysis
960-
If the update is during a non-linear search, in which case tasks are only performed after a certain number
961-
of updates and only a subset of visualization may be performed.
958+
Delegates to :class:`SearchUpdater` which separates each output
959+
concern (samples, latent variables, visualization, profiling,
960+
summary) into its own method.
962961
"""
963-
self.iterations += self.iterations_per_full_update
964-
965-
if not self.disable_output:
966-
self.logger.info(
967-
f"""Fit Running: Updating results (see output folder)."""
968-
)
969-
970-
if not isinstance(self.paths, DatabasePaths) and not isinstance(
971-
self.paths, NullPaths
972-
):
973-
self.timer.update()
974-
975-
samples = self.samples_from(model=model, search_internal=search_internal)
976-
samples_summary = samples.summary()
977-
978-
try:
979-
instance = samples_summary.instance
980-
except exc.FitException:
981-
return samples
982-
983-
self.paths.save_samples_summary(samples_summary=samples_summary)
984-
985-
samples_save = samples
986-
987-
log_message = True
988-
989-
if during_analysis:
990-
log_message = False
991-
elif self.disable_output:
992-
log_message = False
993-
994-
samples_save = samples_save.samples_above_weight_threshold_from(
995-
log_message=log_message
996-
)
997-
self.paths.save_samples(samples=samples_save)
998-
999-
latent_samples = None
1000-
1001-
if (during_analysis and conf.instance["output"]["latent_during_fit"]) or (
1002-
not during_analysis and conf.instance["output"]["latent_after_fit"]
1003-
):
1004-
1005-
if conf.instance["output"]["latent_draw_via_pdf"]:
1006-
1007-
total_draws = conf.instance["output"]["latent_draw_via_pdf_size"]
1008-
1009-
logger.info(f"Creating latent samples by drawing {total_draws} from the PDF.")
1010-
1011-
try:
1012-
latent_samples = samples.samples_drawn_randomly_via_pdf_from(total_draws=total_draws)
1013-
except AttributeError:
1014-
latent_samples = samples_save
1015-
logger.info(
1016-
"Drawing via PDF not available for this search, "
1017-
"using all samples above the samples weight threshold instead."
1018-
"")
1019-
1020-
else:
1021-
1022-
logger.info(f"Creating latent samples using all samples above the samples weight threshold.")
1023-
1024-
latent_samples = samples_save
1025-
1026-
latent_samples = analysis.compute_latent_samples(
1027-
latent_samples,
1028-
batch_size=fitness.batch_size
1029-
)
1030-
1031-
if latent_samples:
1032-
if not conf.instance["output"]["latent_draw_via_pdf"]:
1033-
self.paths.save_latent_samples(latent_samples)
1034-
self.paths.save_samples_summary(
1035-
latent_samples.summary(),
1036-
"latent/latent_summary",
1037-
)
1038-
1039-
start = time.time()
1040-
1041-
self.perform_visualization(
962+
return self._updater.update(
1042963
model=model,
1043964
analysis=analysis,
1044-
samples_summary=samples_summary,
1045965
during_analysis=during_analysis,
966+
fitness=fitness,
1046967
search_internal=search_internal,
1047968
)
1048969

1049-
visualization_time = time.time() - start
1050-
1051-
if self.should_profile:
1052-
1053-
self.logger.debug("Profiling Maximum Likelihood Model")
1054-
1055-
analysis.profile_log_likelihood_function(
1056-
paths=self.paths,
1057-
instance=instance,
1058-
)
1059-
1060-
self.logger.debug("Outputting model result")
1061-
1062-
try:
1063-
1064-
parameters = samples.max_log_likelihood(as_instance=False)
1065-
1066-
start = time.time()
1067-
figure_of_merit = fitness.call_wrap(parameters)
1068-
1069-
# account for asynchronous JAX calls
1070-
np.array(figure_of_merit)
1071-
1072-
log_likelihood_function_time = time.time() - start
1073-
1074-
self.paths.save_summary(
1075-
samples=samples,
1076-
latent_samples=latent_samples,
1077-
log_likelihood_function_time=log_likelihood_function_time,
1078-
visualization_time=visualization_time,
1079-
)
1080-
1081-
except exc.FitException:
1082-
pass
1083-
1084-
self._log_process_state()
1085-
1086-
return samples
1087-
1088970
def perform_visualization(
1089971
self,
1090972
model: AbstractPriorModel,
@@ -1098,72 +980,17 @@ def perform_visualization(
1098980
"""
1099981
Perform visualization of the non-linear search's model-fitting results.
1100982
1101-
This occurs every `iterations_per_full_update` of the non-linear search, when the search is complete and can
1102-
also be forced to occur even though a search is completed on a rerun, to update the visualization
1103-
with different `matplotlib` settings.
1104-
1105-
The update performs the following tasks (if the settings indicate they should be performed):
1106-
1107-
1) Visualize the maximum log likelihood model using model-specific visualization implented via the `Analysis`
1108-
object.
1109-
2) Visualize the search results.
1110-
1111-
Parameters
1112-
----------
1113-
model
1114-
The model which generates instances for different points in parameter space.
1115-
analysis
1116-
Contains the data and the log likelihood function which fits an instance of the model to the data, returning
1117-
the log likelihood the `NonLinearSearch` maximizes.
1118-
samples_summary
1119-
The summary of the samples of the non-linear search, which are used for visualization.
1120-
during_analysis
1121-
If the update is during a non-linear search, in which case tasks are only performed after a certain number
1122-
of updates and only a subset of visualization may be performed.
1123-
instance
1124-
The instance of the model that is used for visualization. If not input, the maximum log likelihood
1125-
instance from the samples is used.
983+
Delegates to :class:`SearchUpdater.visualize`.
1126984
"""
1127-
1128-
self.logger.debug("Visualizing")
1129-
1130-
paths = paths_override or self.paths
1131-
1132-
if instance is None and samples_summary is None:
1133-
raise AssertionError(
1134-
"""
1135-
The search's perform_visualization method has been called without an input instance or
1136-
samples_summary.
1137-
1138-
This should not occur, please ensure one of these inputs is provided.
1139-
"""
1140-
)
1141-
1142-
if instance is None:
1143-
instance = samples_summary.instance
1144-
1145-
if analysis.should_visualize(paths=paths, during_analysis=during_analysis):
1146-
analysis.visualize(
1147-
paths=paths,
1148-
instance=instance,
1149-
during_analysis=during_analysis,
1150-
)
1151-
analysis.visualize_combined(
1152-
paths=paths,
1153-
instance=instance,
1154-
during_analysis=during_analysis,
1155-
)
1156-
1157-
if analysis.should_visualize(paths=paths, during_analysis=during_analysis):
1158-
if not isinstance(paths, NullPaths):
1159-
try:
1160-
samples = self.samples_from(
1161-
model=model, search_internal=search_internal
1162-
)
1163-
1164-
self.plot_results(samples=samples)
1165-
except FileNotFoundError:
1166-
pass
985+
self._updater.visualize(
986+
model=model,
987+
analysis=analysis,
988+
during_analysis=during_analysis,
989+
samples_summary=samples_summary,
990+
instance=instance,
991+
paths_override=paths_override,
992+
search_internal=search_internal,
993+
)
1167994

1168995
@property
1169996
def should_plot_start_point(self) -> bool:

0 commit comments

Comments
 (0)