@@ -927,6 +927,23 @@ def output_search_internal(self, search_internal):
927927 obj = search_internal ,
928928 )
929929
930+ @property
931+ def _updater (self ):
932+ if not hasattr (self , "_search_updater" ) or self ._search_updater is None :
933+ from autofit .non_linear .search .updater import SearchUpdater
934+
935+ self ._search_updater = SearchUpdater (
936+ paths = self .paths ,
937+ timer = self .timer ,
938+ search_logger = self .logger ,
939+ plot_results_func = self .plot_results ,
940+ samples_from_func = self .samples_from ,
941+ should_profile = self .should_profile ,
942+ disable_output = self .disable_output ,
943+ iterations_per_full_update = self .iterations_per_full_update ,
944+ )
945+ return self ._search_updater
946+
930947 def perform_update (
931948 self ,
932949 model : AbstractPriorModel ,
@@ -938,153 +955,18 @@ def perform_update(
938955 """
939956 Perform an update of the non-linear search's model-fitting results.
940957
941- This occurs every `iterations_per_full_update` of the non-linear search and once it is complete.
942-
943- The update performs the following tasks (if the settings indicate they should be performed):
944-
945- 1) Visualize the search results.
946- 2) Visualize the maximum log likelihood model using model-specific visualization implented via the `Analysis`
947- object.
948- 3) Perform profiling of the analysis object `log_likelihood_function` and ouptut run-time information.
949- 4) Output the `search.summary` file which contains information on model-fitting so far.
950- 5) Output the `model.results` file which contains a concise text summary of the model results so far.
951-
952- Parameters
953- ----------
954- model
955- The model which generates instances for different points in parameter space.
956- analysis
957- Contains the data and the log likelihood function which fits an instance of the model to the data, returning
958- the log likelihood the `NonLinearSearch` maximizes.
959- during_analysis
960- If the update is during a non-linear search, in which case tasks are only performed after a certain number
961- of updates and only a subset of visualization may be performed.
958+ Delegates to :class:`SearchUpdater` which separates each output
959+ concern (samples, latent variables, visualization, profiling,
960+ summary) into its own method.
962961 """
963- self .iterations += self .iterations_per_full_update
964-
965- if not self .disable_output :
966- self .logger .info (
967- f"""Fit Running: Updating results (see output folder)."""
968- )
969-
970- if not isinstance (self .paths , DatabasePaths ) and not isinstance (
971- self .paths , NullPaths
972- ):
973- self .timer .update ()
974-
975- samples = self .samples_from (model = model , search_internal = search_internal )
976- samples_summary = samples .summary ()
977-
978- try :
979- instance = samples_summary .instance
980- except exc .FitException :
981- return samples
982-
983- self .paths .save_samples_summary (samples_summary = samples_summary )
984-
985- samples_save = samples
986-
987- log_message = True
988-
989- if during_analysis :
990- log_message = False
991- elif self .disable_output :
992- log_message = False
993-
994- samples_save = samples_save .samples_above_weight_threshold_from (
995- log_message = log_message
996- )
997- self .paths .save_samples (samples = samples_save )
998-
999- latent_samples = None
1000-
1001- if (during_analysis and conf .instance ["output" ]["latent_during_fit" ]) or (
1002- not during_analysis and conf .instance ["output" ]["latent_after_fit" ]
1003- ):
1004-
1005- if conf .instance ["output" ]["latent_draw_via_pdf" ]:
1006-
1007- total_draws = conf .instance ["output" ]["latent_draw_via_pdf_size" ]
1008-
1009- logger .info (f"Creating latent samples by drawing { total_draws } from the PDF." )
1010-
1011- try :
1012- latent_samples = samples .samples_drawn_randomly_via_pdf_from (total_draws = total_draws )
1013- except AttributeError :
1014- latent_samples = samples_save
1015- logger .info (
1016- "Drawing via PDF not available for this search, "
1017- "using all samples above the samples weight threshold instead."
1018- "" )
1019-
1020- else :
1021-
1022- logger .info (f"Creating latent samples using all samples above the samples weight threshold." )
1023-
1024- latent_samples = samples_save
1025-
1026- latent_samples = analysis .compute_latent_samples (
1027- latent_samples ,
1028- batch_size = fitness .batch_size
1029- )
1030-
1031- if latent_samples :
1032- if not conf .instance ["output" ]["latent_draw_via_pdf" ]:
1033- self .paths .save_latent_samples (latent_samples )
1034- self .paths .save_samples_summary (
1035- latent_samples .summary (),
1036- "latent/latent_summary" ,
1037- )
1038-
1039- start = time .time ()
1040-
1041- self .perform_visualization (
962+ return self ._updater .update (
1042963 model = model ,
1043964 analysis = analysis ,
1044- samples_summary = samples_summary ,
1045965 during_analysis = during_analysis ,
966+ fitness = fitness ,
1046967 search_internal = search_internal ,
1047968 )
1048969
1049- visualization_time = time .time () - start
1050-
1051- if self .should_profile :
1052-
1053- self .logger .debug ("Profiling Maximum Likelihood Model" )
1054-
1055- analysis .profile_log_likelihood_function (
1056- paths = self .paths ,
1057- instance = instance ,
1058- )
1059-
1060- self .logger .debug ("Outputting model result" )
1061-
1062- try :
1063-
1064- parameters = samples .max_log_likelihood (as_instance = False )
1065-
1066- start = time .time ()
1067- figure_of_merit = fitness .call_wrap (parameters )
1068-
1069- # account for asynchronous JAX calls
1070- np .array (figure_of_merit )
1071-
1072- log_likelihood_function_time = time .time () - start
1073-
1074- self .paths .save_summary (
1075- samples = samples ,
1076- latent_samples = latent_samples ,
1077- log_likelihood_function_time = log_likelihood_function_time ,
1078- visualization_time = visualization_time ,
1079- )
1080-
1081- except exc .FitException :
1082- pass
1083-
1084- self ._log_process_state ()
1085-
1086- return samples
1087-
1088970 def perform_visualization (
1089971 self ,
1090972 model : AbstractPriorModel ,
@@ -1098,72 +980,17 @@ def perform_visualization(
1098980 """
1099981 Perform visualization of the non-linear search's model-fitting results.
1100982
1101- This occurs every `iterations_per_full_update` of the non-linear search, when the search is complete and can
1102- also be forced to occur even though a search is completed on a rerun, to update the visualization
1103- with different `matplotlib` settings.
1104-
1105- The update performs the following tasks (if the settings indicate they should be performed):
1106-
1107- 1) Visualize the maximum log likelihood model using model-specific visualization implented via the `Analysis`
1108- object.
1109- 2) Visualize the search results.
1110-
1111- Parameters
1112- ----------
1113- model
1114- The model which generates instances for different points in parameter space.
1115- analysis
1116- Contains the data and the log likelihood function which fits an instance of the model to the data, returning
1117- the log likelihood the `NonLinearSearch` maximizes.
1118- samples_summary
1119- The summary of the samples of the non-linear search, which are used for visualization.
1120- during_analysis
1121- If the update is during a non-linear search, in which case tasks are only performed after a certain number
1122- of updates and only a subset of visualization may be performed.
1123- instance
1124- The instance of the model that is used for visualization. If not input, the maximum log likelihood
1125- instance from the samples is used.
983+ Delegates to :class:`SearchUpdater.visualize`.
1126984 """
1127-
1128- self .logger .debug ("Visualizing" )
1129-
1130- paths = paths_override or self .paths
1131-
1132- if instance is None and samples_summary is None :
1133- raise AssertionError (
1134- """
1135- The search's perform_visualization method has been called without an input instance or
1136- samples_summary.
1137-
1138- This should not occur, please ensure one of these inputs is provided.
1139- """
1140- )
1141-
1142- if instance is None :
1143- instance = samples_summary .instance
1144-
1145- if analysis .should_visualize (paths = paths , during_analysis = during_analysis ):
1146- analysis .visualize (
1147- paths = paths ,
1148- instance = instance ,
1149- during_analysis = during_analysis ,
1150- )
1151- analysis .visualize_combined (
1152- paths = paths ,
1153- instance = instance ,
1154- during_analysis = during_analysis ,
1155- )
1156-
1157- if analysis .should_visualize (paths = paths , during_analysis = during_analysis ):
1158- if not isinstance (paths , NullPaths ):
1159- try :
1160- samples = self .samples_from (
1161- model = model , search_internal = search_internal
1162- )
1163-
1164- self .plot_results (samples = samples )
1165- except FileNotFoundError :
1166- pass
985+ self ._updater .visualize (
986+ model = model ,
987+ analysis = analysis ,
988+ during_analysis = during_analysis ,
989+ samples_summary = samples_summary ,
990+ instance = instance ,
991+ paths_override = paths_override ,
992+ search_internal = search_internal ,
993+ )
1167994
1168995 @property
1169996 def should_plot_start_point (self ) -> bool :
0 commit comments