diff --git a/NAM/container.cpp b/NAM/container.cpp index ee7d9f1..2d25f04 100644 --- a/NAM/container.cpp +++ b/NAM/container.cpp @@ -61,6 +61,13 @@ void ContainerModel::prewarm() _submodels[active_index].model->prewarm(); } +void ContainerModel::SetPrewarmOnReset(const bool prewarmOnReset) +{ + DSP::SetPrewarmOnReset(prewarmOnReset); + for (auto& submodel : _submodels) + submodel.model->SetPrewarmOnReset(prewarmOnReset); +} + void ContainerModel::Reset(const double sampleRate, const int maxBufferSize) { std::lock_guard lock(_slim_set_mutex); diff --git a/NAM/container.h b/NAM/container.h index dccc914..36aee5d 100644 --- a/NAM/container.h +++ b/NAM/container.h @@ -37,6 +37,7 @@ class ContainerModel : public DSP, public SlimmableModel void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override; void prewarm() override; void Reset(const double sampleRate, const int maxBufferSize) override; + void SetPrewarmOnReset(const bool prewarmOnReset) override; void SetSlimmableSize(const double val) override; protected: diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index a4040c3..57d370f 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -14,8 +14,27 @@ constexpr const long _INPUT_BUFFER_SAFETY_FACTOR = 32; +namespace +{ + +thread_local bool gPrewarmOnResetDefault = true; + +} // namespace + +nam::ScopedPrewarmOnResetDefault::ScopedPrewarmOnResetDefault(const bool prewarmOnReset) +: mPreviousPrewarmOnReset(gPrewarmOnResetDefault) +{ + gPrewarmOnResetDefault = prewarmOnReset; +} + +nam::ScopedPrewarmOnResetDefault::~ScopedPrewarmOnResetDefault() +{ + gPrewarmOnResetDefault = mPreviousPrewarmOnReset; +} + nam::DSP::DSP(const int in_channels, const int out_channels, const double expected_sample_rate) : mExpectedSampleRate(expected_sample_rate) +, mPrewarmOnReset(gPrewarmOnResetDefault) , mInChannels(in_channels) , mOutChannels(out_channels) { @@ -96,7 +115,18 @@ void nam::DSP::Reset(const double sampleRate, const int maxBufferSize) mHaveExternalSampleRate = true; SetMaxBufferSize(maxBufferSize); - prewarm(); + if (GetPrewarmOnReset()) + prewarm(); +} + +void nam::DSP::SetPrewarmOnReset(const bool prewarmOnReset) +{ + mPrewarmOnReset.store(prewarmOnReset, std::memory_order_release); +} + +bool nam::DSP::GetPrewarmOnReset() const +{ + return mPrewarmOnReset.load(std::memory_order_acquire); } void nam::DSP::SetLoudness(const double loudness) diff --git a/NAM/dsp.h b/NAM/dsp.h index c714a19..16b6fb0 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -39,6 +40,25 @@ namespace wavenet class WaveNet; } // namespace wavenet +/// \brief Temporarily change the thread-local prewarm-on-reset default for newly constructed DSP objects +/// +/// Existing DSP objects are not affected. DSP instances constructed while this object is alive on the current thread +/// copy the scoped default into their instance-level prewarm-on-reset setting. +class ScopedPrewarmOnResetDefault +{ +public: + explicit ScopedPrewarmOnResetDefault(const bool prewarmOnReset); + ~ScopedPrewarmOnResetDefault(); + + ScopedPrewarmOnResetDefault(const ScopedPrewarmOnResetDefault&) = delete; + ScopedPrewarmOnResetDefault& operator=(const ScopedPrewarmOnResetDefault&) = delete; + + bool PreviousPrewarmOnReset() const { return mPreviousPrewarmOnReset; } + +private: + bool mPreviousPrewarmOnReset; +}; + /// \brief Base class for all DSP models /// /// DSP provides the common interface for all neural network-based audio processing models. @@ -133,20 +153,19 @@ class DSP /// \brief General function for resetting the DSP unit /// - /// This doesn't call prewarm(). If you want to do that, then you might want to use ResetAndPrewarm(). - /// See https://github.com/sdatkinson/NeuralAmpModelerCore/issues/96 for the reasoning. + /// By default, this calls prewarm() after updating the sample rate and buffer size. Use SetPrewarmOnReset() to + /// disable or re-enable that behavior for a DSP instance. /// \param sampleRate Current sample rate /// \param maxBufferSize Maximum buffer size to process virtual void Reset(const double sampleRate, const int maxBufferSize); - /// \brief Reset the DSP unit, then prewarm - /// \param sampleRate Current sample rate - /// \param maxBufferSize Maximum buffer size to process - void ResetAndPrewarm(const double sampleRate, const int maxBufferSize) - { - Reset(sampleRate, maxBufferSize); - prewarm(); - } + /// \brief Control whether Reset() calls prewarm() + /// \param prewarmOnReset true for Reset() to call prewarm(), false to skip prewarm() + virtual void SetPrewarmOnReset(const bool prewarmOnReset); + + /// \brief Check whether Reset() calls prewarm() + /// \return true if Reset() calls prewarm() + bool GetPrewarmOnReset() const; /// \brief Set the input level /// \param inputLevel Input level in dBu @@ -177,6 +196,7 @@ class DSP double mExternalSampleRate = -1.0; // The largest buffer I expect to be told to process: int mMaxBufferSize = 0; + std::atomic mPrewarmOnReset; /// \brief Get how many samples should be processed for the model to be considered "warmed up" /// diff --git a/NAM/get_dsp.cpp b/NAM/get_dsp.cpp index 3aa8592..6439364 100644 --- a/NAM/get_dsp.cpp +++ b/NAM/get_dsp.cpp @@ -139,26 +139,42 @@ std::vector GetWeights(nlohmann::json const& j) throw std::runtime_error("Corrupted model file is missing weights."); } -std::unique_ptr get_dsp(const std::filesystem::path config_filename) +void populate_dsp_data(const nlohmann::json& config, dspData& returnedConfig) +{ + verify_config_version(config["version"].get()); + + nlohmann::json config_json = config["config"]; + std::vector weights = GetWeights(config); + + returnedConfig.version = config["version"].get(); + returnedConfig.architecture = config["architecture"].get(); + returnedConfig.config = config_json; + returnedConfig.metadata = config.value("metadata", nlohmann::json()); + returnedConfig.weights = weights; + returnedConfig.expected_sample_rate = nam::get_sample_rate_from_nam_file(config); +} + +std::unique_ptr get_dsp(const std::filesystem::path config_filename, DspLoadOptions options) { dspData temp; - return get_dsp(config_filename, temp); + return get_dsp(config_filename, temp, options); } -std::unique_ptr get_dsp(const nlohmann::json& config) +std::unique_ptr get_dsp(const nlohmann::json& config, DspLoadOptions options) { dspData temp; - return get_dsp(config, temp); + return get_dsp(config, temp, options); } -std::unique_ptr get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig) +std::unique_ptr get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig, + DspLoadOptions options) { if (!std::filesystem::exists(config_filename)) throw std::runtime_error("Config file doesn't exist!\n"); std::ifstream i(config_filename); nlohmann::json j; i >> j; - get_dsp(j, returnedConfig); + populate_dsp_data(j, returnedConfig); /*Copy to a new dsp_config object for get_dsp below, since not sure if weights actually get modified as being non-const references on some @@ -166,24 +182,12 @@ std::unique_ptr get_dsp(const std::filesystem::path config_filename, dspDat We need to return unmodified version of dsp_config via returnedConfig.*/ dspData conf = returnedConfig; - return get_dsp(conf); + return get_dsp(conf, options); } -std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConfig) +std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConfig, DspLoadOptions options) { - verify_config_version(config["version"].get()); - - auto architecture = config["architecture"]; - nlohmann::json config_json = config["config"]; - std::vector weights = GetWeights(config); - - // Assign values to returnedConfig - returnedConfig.version = config["version"].get(); - returnedConfig.architecture = config["architecture"].get(); - returnedConfig.config = config_json; - returnedConfig.metadata = config.value("metadata", nlohmann::json()); - returnedConfig.weights = weights; - returnedConfig.expected_sample_rate = nam::get_sample_rate_from_nam_file(config); + populate_dsp_data(config, returnedConfig); /*Copy to a new dsp_config object for get_dsp below, since not sure if weights actually get modified as being non-const references on some @@ -191,7 +195,7 @@ std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConf We need to return unmodified version of dsp_config via returnedConfig.*/ dspData conf = returnedConfig; - return get_dsp(conf); + return get_dsp(conf, options); } // ============================================================================= @@ -224,9 +228,6 @@ std::unique_ptr create_dsp(std::unique_ptr config, std::vector { auto out = config->create(std::move(weights), metadata.sample_rate); apply_metadata(*out, metadata); - // "pre-warm" the model to settle initial conditions - // Can this be removed now that it's part of Reset()? - out->prewarm(); return out; } @@ -234,7 +235,10 @@ std::unique_ptr create_dsp(std::unique_ptr config, std::vector // get_dsp(dspData&) — now uses unified path // ============================================================================= -std::unique_ptr get_dsp(dspData& conf) +namespace +{ + +std::unique_ptr get_dsp_with_current_prewarm_default(dspData& conf) { verify_config_version(conf.version); @@ -259,6 +263,20 @@ std::unique_ptr get_dsp(dspData& conf) return create_dsp(std::move(model_config), std::move(conf.weights), metadata); } +} // anonymous namespace + +std::unique_ptr get_dsp(dspData& conf, DspLoadOptions options) +{ + if (!options.prewarm.has_value()) + return get_dsp_with_current_prewarm_default(conf); + + ScopedPrewarmOnResetDefault scoped_prewarm_default(*options.prewarm); + auto dsp = get_dsp_with_current_prewarm_default(conf); + if (dsp != nullptr) + dsp->SetPrewarmOnReset(scoped_prewarm_default.PreviousPrewarmOnReset()); + return dsp; +} + double get_sample_rate_from_nam_file(const nlohmann::json& j) { if (j.find("sample_rate") != j.end()) diff --git a/NAM/get_dsp.h b/NAM/get_dsp.h index da874fe..c9e7941 100644 --- a/NAM/get_dsp.h +++ b/NAM/get_dsp.h @@ -2,6 +2,7 @@ #include #include +#include #include #include "dsp.h" @@ -65,34 +66,52 @@ void verify_config_version(const std::string versionStr); const std::string LATEST_FULLY_SUPPORTED_NAM_FILE_VERSION = "0.7.0"; const std::string EARLIEST_SUPPORTED_NAM_FILE_VERSION = "0.5.0"; +/// \brief Options that control DSP loading behavior +struct DspLoadOptions +{ + /// \brief Whether to override the current prewarm-on-reset context during loading + /// + /// std::nullopt leaves the current thread-local context unchanged. Set this to false to avoid expensive prewarm work + /// during get_dsp(), or true to force prewarm during get_dsp(). When an override is provided, the returned model is + /// restored to the caller's previous prewarm-on-reset default before get_dsp() returns. + std::optional prewarm = std::nullopt; +}; + /// \brief Get NAM from a .nam file at the provided location /// \param config_filename Path to the .nam model file +/// \param options Loading options /// \return Unique pointer to a DSP object -std::unique_ptr get_dsp(const std::filesystem::path config_filename); +std::unique_ptr get_dsp(const std::filesystem::path config_filename, DspLoadOptions options = DspLoadOptions()); /// \brief Get NAM from a provided configuration struct /// \param conf DSP data structure containing model configuration and weights +/// \param options Loading options /// \return Unique pointer to a DSP object -std::unique_ptr get_dsp(dspData& conf); +std::unique_ptr get_dsp(dspData& conf, DspLoadOptions options = DspLoadOptions()); /// \brief Get NAM from a .nam file and store its configuration /// /// Creates an instance of DSP and also returns a dspData struct that holds the data of the model. /// \param config_filename Path to the .nam model file /// \param returnedConfig Output parameter that will be filled with the model data +/// \param options Loading options /// \return Unique pointer to a DSP object -std::unique_ptr get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig); +std::unique_ptr get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig, + DspLoadOptions options = DspLoadOptions()); /// \brief Get NAM from a provided configuration JSON object /// \param config JSON configuration object /// \param returnedConfig Output parameter that will be filled with the model data +/// \param options Loading options /// \return Unique pointer to a DSP object -std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConfig); +std::unique_ptr get_dsp(const nlohmann::json& config, dspData& returnedConfig, + DspLoadOptions options = DspLoadOptions()); /// \brief Get NAM from a provided configuration JSON object (convenience overload) /// \param config JSON configuration object +/// \param options Loading options /// \return Unique pointer to a DSP object -std::unique_ptr get_dsp(const nlohmann::json& config); +std::unique_ptr get_dsp(const nlohmann::json& config, DspLoadOptions options = DspLoadOptions()); /// \brief Get sample rate from a .nam file /// \param j JSON object from the .nam file diff --git a/NAM/model_config.h b/NAM/model_config.h index 32825a5..d9653da 100644 --- a/NAM/model_config.h +++ b/NAM/model_config.h @@ -106,7 +106,7 @@ struct ConfigParserHelper /// \brief Construct a DSP object from a typed config, weights, and metadata /// /// This is the single construction path used by both JSON and binary loaders. -/// Handles construction, metadata application, and prewarm. +/// Handles construction and metadata application. /// \param config Architecture-specific configuration (abstract base) /// \param weights Model weights (taken by value to allow move for WaveNet) /// \param metadata Model metadata (version, sample rate, loudness, levels) diff --git a/NAM/wavenet/model.cpp b/NAM/wavenet/model.cpp index eaf74ad..9437e1e 100644 --- a/NAM/wavenet/model.cpp +++ b/NAM/wavenet/model.cpp @@ -689,6 +689,13 @@ void nam::wavenet::WaveNet::SetMaxBufferSize(const int maxBufferSize) } } +void nam::wavenet::WaveNet::SetPrewarmOnReset(const bool prewarmOnReset) +{ + DSP::SetPrewarmOnReset(prewarmOnReset); + if (this->_condition_dsp != nullptr) + this->_condition_dsp->SetPrewarmOnReset(prewarmOnReset); +} + void nam::wavenet::WaveNet::_process_condition(const int num_frames) { if (this->_condition_dsp == nullptr) diff --git a/NAM/wavenet/model.h b/NAM/wavenet/model.h index 968baf8..508f909 100644 --- a/NAM/wavenet/model.h +++ b/NAM/wavenet/model.h @@ -58,6 +58,8 @@ class WaveNet : public DSP /// \param num_frames Number of frames to process void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override; + void SetPrewarmOnReset(const bool prewarmOnReset) override; + /// \brief Set model weights from a vector /// \param weights Vector containing all model weights void set_weights_(std::vector& weights); diff --git a/NAM/wavenet/slimmable.cpp b/NAM/wavenet/slimmable.cpp index 19b019e..69f077b 100644 --- a/NAM/wavenet/slimmable.cpp +++ b/NAM/wavenet/slimmable.cpp @@ -410,8 +410,10 @@ std::unique_ptr SlimmableWavenet::_create_wavenet_for_channels(const std::v condition_dsp = get_dsp(_condition_dsp_json); double sampleRate = _current_sample_rate > 0 ? _current_sample_rate : GetExpectedSampleRate(); - return std::make_unique(_in_channels, *params_ptr, _head_scale, _with_head, std::nullopt, - std::move(weights), std::move(condition_dsp), sampleRate); + auto model = std::make_unique(_in_channels, *params_ptr, _head_scale, _with_head, std::nullopt, + std::move(weights), std::move(condition_dsp), sampleRate); + model->SetPrewarmOnReset(GetPrewarmOnReset()); + return model; } void SlimmableWavenet::_rebuild_model(const std::vector& target_channels) @@ -481,6 +483,15 @@ void SlimmableWavenet::Reset(const double sampleRate, const int maxBufferSize) pending->model->Reset(sampleRate, maxBufferSize); } +void SlimmableWavenet::SetPrewarmOnReset(const bool prewarmOnReset) +{ + DSP::SetPrewarmOnReset(prewarmOnReset); + if (_active_model) + _active_model->SetPrewarmOnReset(prewarmOnReset); + if (auto pending = _pending_load_acquire()) + pending->model->SetPrewarmOnReset(prewarmOnReset); +} + void SlimmableWavenet::SetSlimmableSize(const double val) { const size_t num_arrays = _original_params.size(); diff --git a/NAM/wavenet/slimmable.h b/NAM/wavenet/slimmable.h index f0687c0..bc416ac 100644 --- a/NAM/wavenet/slimmable.h +++ b/NAM/wavenet/slimmable.h @@ -48,6 +48,7 @@ class SlimmableWavenet : public DSP, public SlimmableModel void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override; void prewarm() override; void Reset(const double sampleRate, const int maxBufferSize) override; + void SetPrewarmOnReset(const bool prewarmOnReset) override; void SetSlimmableSize(const double val) override; protected: diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 0f9d50a..e765ddf 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -84,6 +84,9 @@ int main() test_dsp::test_has_output_level(); test_dsp::test_set_input_level(); test_dsp::test_set_output_level(); + test_dsp::test_reset_prewarm_on_reset_default(); + test_dsp::test_set_prewarm_on_reset(); + test_dsp::test_scoped_prewarm_on_reset_default(); test_linear::test_direct_known_values(); test_linear::test_fft_matches_direct_irregular_chunks(); @@ -285,6 +288,11 @@ int main() test_get_dsp::test_version_too_early(); test_get_dsp::test_is_version_supported_core_behavior(); test_get_dsp::test_register_custom_version_support_checker(); + test_get_dsp::test_get_dsp_default_allows_constructor_reset_prewarm(); + test_get_dsp::test_get_dsp_default_inherits_scoped_prewarm_default(); + test_get_dsp::test_get_dsp_prewarm_option_suppresses_constructor_reset_prewarm(); + test_get_dsp::test_get_dsp_prewarm_option_forces_constructor_reset_prewarm(); + test_get_dsp::test_get_dsp_with_returned_config_constructs_once(); // Finally, some end-to-end tests. test_get_dsp::test_load_and_process_nam_files(); diff --git a/tools/test/test_container.cpp b/tools/test/test_container.cpp index a89e8a5..46b1752 100644 --- a/tools/test/test_container.cpp +++ b/tools/test/test_container.cpp @@ -380,7 +380,7 @@ void test_container_default_is_max_size() NAM_SAMPLE* out_ptr; // Ensure both predictions start from identical model state. - dsp->ResetAndPrewarm(sample_rate, buffer_size); + dsp->Reset(sample_rate, buffer_size); // Process with default (should be max size) out_ptr = out_default.data(); dsp->process(&in_ptr, &out_ptr, buffer_size); @@ -389,7 +389,7 @@ void test_container_default_is_max_size() auto* slimmable = dynamic_cast(dsp.get()); assert(slimmable != nullptr); slimmable->SetSlimmableSize(1.0); - dsp->ResetAndPrewarm(sample_rate, buffer_size); + dsp->Reset(sample_rate, buffer_size); out_ptr = out_max.data(); dsp->process(&in_ptr, &out_ptr, buffer_size); diff --git a/tools/test/test_dsp.cpp b/tools/test/test_dsp.cpp index d019a87..334461d 100644 --- a/tools/test/test_dsp.cpp +++ b/tools/test/test_dsp.cpp @@ -1,10 +1,24 @@ // Tests for dsp #include "NAM/dsp.h" +#include #include namespace test_dsp { +class PrewarmCountingDSP : public nam::DSP +{ +public: + PrewarmCountingDSP() + : nam::DSP(1, 1, 48000.0) + { + } + + void prewarm() override { prewarm_count++; } + + int prewarm_count = 0; +}; + // Simplest test: can I construct something! void test_construct() { @@ -91,6 +105,46 @@ void test_set_output_level() myDsp.SetOutputLevel(19.0); } +void test_reset_prewarm_on_reset_default() +{ + PrewarmCountingDSP dsp; + assert(dsp.GetPrewarmOnReset()); + + dsp.Reset(48000.0, 64); + + assert(dsp.prewarm_count == 1); +} + +void test_set_prewarm_on_reset() +{ + PrewarmCountingDSP dsp; + + dsp.SetPrewarmOnReset(false); + dsp.Reset(48000.0, 64); + assert(dsp.prewarm_count == 0); + + dsp.SetPrewarmOnReset(true); + dsp.Reset(48000.0, 64); + assert(dsp.prewarm_count == 1); +} + +void test_scoped_prewarm_on_reset_default() +{ + PrewarmCountingDSP before_scope; + assert(before_scope.GetPrewarmOnReset()); + + { + nam::ScopedPrewarmOnResetDefault scoped_default(false); + PrewarmCountingDSP in_scope; + assert(!in_scope.GetPrewarmOnReset()); + in_scope.Reset(48000.0, 64); + assert(in_scope.prewarm_count == 0); + } + + PrewarmCountingDSP after_scope; + assert(after_scope.GetPrewarmOnReset()); +} + void test_process_multi_channel() { const int in_channels = 2; diff --git a/tools/test/test_get_dsp.cpp b/tools/test/test_get_dsp.cpp index 150f50f..3cfd368 100644 --- a/tools/test/test_get_dsp.cpp +++ b/tools/test/test_get_dsp.cpp @@ -10,6 +10,7 @@ #include "json.hpp" #include "NAM/get_dsp.h" +#include "NAM/registry.h" namespace test_get_dsp { @@ -46,6 +47,50 @@ nam::dspData _GetConfig(const std::string& configStr = basicConfigStr) return returnedConfig; } +namespace +{ + +constexpr const char* kConstructorResetArchitecture = "ConstructorResetPrewarmPolicyTest"; +int gConstructorResetConstructCount = 0; + +class ConstructorResetDSP : public nam::DSP +{ +public: + explicit ConstructorResetDSP(const double expected_sample_rate) + : nam::DSP(1, 1, expected_sample_rate) + { + gConstructorResetConstructCount++; + Reset(expected_sample_rate, 16); + } + + void prewarm() override { prewarm_count++; } + + int prewarm_count = 0; +}; + +std::unique_ptr ConstructorResetFactory(const nlohmann::json& config, std::vector& weights, + const double expectedSampleRate) +{ + (void)config; + (void)weights; + return std::make_unique(expectedSampleRate); +} + +static nam::factory::Helper _register_ConstructorResetPrewarmPolicyTest(kConstructorResetArchitecture, + ConstructorResetFactory); + +nlohmann::json build_constructor_reset_config() +{ + return nlohmann::json{{"version", "0.7.0"}, + {"metadata", nlohmann::json::object()}, + {"architecture", kConstructorResetArchitecture}, + {"config", nlohmann::json::object()}, + {"weights", nlohmann::json::array()}, + {"sample_rate", 48000}}; +} + +} // namespace + void test_gets_input_level() { nam::dspData config = _GetConfig(); @@ -273,4 +318,80 @@ void test_register_custom_version_support_checker() assert(nam::is_version_supported("DEMO::1.0.3") == nam::Supported::PARTIAL); assert(nam::is_version_supported("DEMO::2.0.0") == nam::Supported::NO); } -}; // namespace test_get_dsp \ No newline at end of file + +void test_get_dsp_default_allows_constructor_reset_prewarm() +{ + gConstructorResetConstructCount = 0; + + auto dsp = nam::get_dsp(build_constructor_reset_config()); + auto* typed = dynamic_cast(dsp.get()); + + assert(typed != nullptr); + assert(gConstructorResetConstructCount == 1); + assert(typed->prewarm_count == 1); + assert(typed->GetPrewarmOnReset()); +} + +void test_get_dsp_default_inherits_scoped_prewarm_default() +{ + gConstructorResetConstructCount = 0; + + nam::ScopedPrewarmOnResetDefault scoped_default(false); + auto dsp = nam::get_dsp(build_constructor_reset_config()); + auto* typed = dynamic_cast(dsp.get()); + + assert(typed != nullptr); + assert(gConstructorResetConstructCount == 1); + assert(typed->prewarm_count == 0); + assert(!typed->GetPrewarmOnReset()); +} + +void test_get_dsp_prewarm_option_suppresses_constructor_reset_prewarm() +{ + gConstructorResetConstructCount = 0; + + nam::DspLoadOptions options; + options.prewarm = false; + auto dsp = nam::get_dsp(build_constructor_reset_config(), options); + auto* typed = dynamic_cast(dsp.get()); + + assert(typed != nullptr); + assert(gConstructorResetConstructCount == 1); + assert(typed->prewarm_count == 0); + assert(typed->GetPrewarmOnReset()); + + typed->Reset(48000.0, 16); + assert(typed->prewarm_count == 1); +} + +void test_get_dsp_prewarm_option_forces_constructor_reset_prewarm() +{ + gConstructorResetConstructCount = 0; + + nam::ScopedPrewarmOnResetDefault scoped_default(false); + nam::DspLoadOptions options; + options.prewarm = true; + auto dsp = nam::get_dsp(build_constructor_reset_config(), options); + auto* typed = dynamic_cast(dsp.get()); + + assert(typed != nullptr); + assert(gConstructorResetConstructCount == 1); + assert(typed->prewarm_count == 1); + assert(!typed->GetPrewarmOnReset()); + + typed->Reset(48000.0, 16); + assert(typed->prewarm_count == 1); +} + +void test_get_dsp_with_returned_config_constructs_once() +{ + gConstructorResetConstructCount = 0; + + nam::dspData returned_config; + auto dsp = nam::get_dsp(build_constructor_reset_config(), returned_config); + + assert(dsp != nullptr); + assert(gConstructorResetConstructCount == 1); + assert(returned_config.architecture == kConstructorResetArchitecture); +} +}; // namespace test_get_dsp