Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions NAM/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ void ContainerModel::prewarm()
_submodels[active_index].model->prewarm();
}

void ContainerModel::SetPrewarmOnReset(const bool prewarmOnReset)
{
DSP::SetPrewarmOnReset(prewarmOnReset);
for (auto& submodel : _submodels)
submodel.model->SetPrewarmOnReset(prewarmOnReset);
}

void ContainerModel::Reset(const double sampleRate, const int maxBufferSize)
{
std::lock_guard<std::mutex> lock(_slim_set_mutex);
Expand Down
1 change: 1 addition & 0 deletions NAM/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ContainerModel : public DSP, public SlimmableModel
void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override;
void prewarm() override;
void Reset(const double sampleRate, const int maxBufferSize) override;
void SetPrewarmOnReset(const bool prewarmOnReset) override;
void SetSlimmableSize(const double val) override;

protected:
Expand Down
32 changes: 31 additions & 1 deletion NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,27 @@

constexpr const long _INPUT_BUFFER_SAFETY_FACTOR = 32;

namespace
{

thread_local bool gPrewarmOnResetDefault = true;

} // namespace

nam::ScopedPrewarmOnResetDefault::ScopedPrewarmOnResetDefault(const bool prewarmOnReset)
: mPreviousPrewarmOnReset(gPrewarmOnResetDefault)
{
gPrewarmOnResetDefault = prewarmOnReset;
}

nam::ScopedPrewarmOnResetDefault::~ScopedPrewarmOnResetDefault()
{
gPrewarmOnResetDefault = mPreviousPrewarmOnReset;
}

nam::DSP::DSP(const int in_channels, const int out_channels, const double expected_sample_rate)
: mExpectedSampleRate(expected_sample_rate)
, mPrewarmOnReset(gPrewarmOnResetDefault)
, mInChannels(in_channels)
, mOutChannels(out_channels)
{
Expand Down Expand Up @@ -96,7 +115,18 @@ void nam::DSP::Reset(const double sampleRate, const int maxBufferSize)
mHaveExternalSampleRate = true;
SetMaxBufferSize(maxBufferSize);

prewarm();
if (GetPrewarmOnReset())
prewarm();
}

void nam::DSP::SetPrewarmOnReset(const bool prewarmOnReset)
{
mPrewarmOnReset.store(prewarmOnReset, std::memory_order_release);
}

bool nam::DSP::GetPrewarmOnReset() const
{
return mPrewarmOnReset.load(std::memory_order_acquire);
}

void nam::DSP::SetLoudness(const double loudness)
Expand Down
40 changes: 30 additions & 10 deletions NAM/dsp.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <atomic>
#include <filesystem>
#include <iterator>
#include <memory>
Expand Down Expand Up @@ -39,6 +40,25 @@ namespace wavenet
class WaveNet;
} // namespace wavenet

/// \brief Temporarily change the thread-local prewarm-on-reset default for newly constructed DSP objects
///
/// Existing DSP objects are not affected. DSP instances constructed while this object is alive on the current thread
/// copy the scoped default into their instance-level prewarm-on-reset setting.
class ScopedPrewarmOnResetDefault
{
public:
explicit ScopedPrewarmOnResetDefault(const bool prewarmOnReset);
~ScopedPrewarmOnResetDefault();

ScopedPrewarmOnResetDefault(const ScopedPrewarmOnResetDefault&) = delete;
ScopedPrewarmOnResetDefault& operator=(const ScopedPrewarmOnResetDefault&) = delete;

bool PreviousPrewarmOnReset() const { return mPreviousPrewarmOnReset; }

private:
bool mPreviousPrewarmOnReset;
};

/// \brief Base class for all DSP models
///
/// DSP provides the common interface for all neural network-based audio processing models.
Expand Down Expand Up @@ -133,20 +153,19 @@ class DSP

/// \brief General function for resetting the DSP unit
///
/// This doesn't call prewarm(). If you want to do that, then you might want to use ResetAndPrewarm().
/// See https://github.com/sdatkinson/NeuralAmpModelerCore/issues/96 for the reasoning.
/// By default, this calls prewarm() after updating the sample rate and buffer size. Use SetPrewarmOnReset() to
/// disable or re-enable that behavior for a DSP instance.
/// \param sampleRate Current sample rate
/// \param maxBufferSize Maximum buffer size to process
virtual void Reset(const double sampleRate, const int maxBufferSize);

/// \brief Reset the DSP unit, then prewarm
/// \param sampleRate Current sample rate
/// \param maxBufferSize Maximum buffer size to process
void ResetAndPrewarm(const double sampleRate, const int maxBufferSize)
{
Reset(sampleRate, maxBufferSize);
prewarm();
}
/// \brief Control whether Reset() calls prewarm()
/// \param prewarmOnReset true for Reset() to call prewarm(), false to skip prewarm()
virtual void SetPrewarmOnReset(const bool prewarmOnReset);

/// \brief Check whether Reset() calls prewarm()
/// \return true if Reset() calls prewarm()
bool GetPrewarmOnReset() const;

/// \brief Set the input level
/// \param inputLevel Input level in dBu
Expand Down Expand Up @@ -177,6 +196,7 @@ class DSP
double mExternalSampleRate = -1.0;
// The largest buffer I expect to be told to process:
int mMaxBufferSize = 0;
std::atomic<bool> mPrewarmOnReset;

/// \brief Get how many samples should be processed for the model to be considered "warmed up"
///
Expand Down
70 changes: 44 additions & 26 deletions NAM/get_dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,59 +139,63 @@ std::vector<float> GetWeights(nlohmann::json const& j)
throw std::runtime_error("Corrupted model file is missing weights.");
}

std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename)
void populate_dsp_data(const nlohmann::json& config, dspData& returnedConfig)
{
verify_config_version(config["version"].get<std::string>());

nlohmann::json config_json = config["config"];
std::vector<float> weights = GetWeights(config);

returnedConfig.version = config["version"].get<std::string>();
returnedConfig.architecture = config["architecture"].get<std::string>();
returnedConfig.config = config_json;
returnedConfig.metadata = config.value("metadata", nlohmann::json());
returnedConfig.weights = weights;
returnedConfig.expected_sample_rate = nam::get_sample_rate_from_nam_file(config);
}

std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, DspLoadOptions options)
{
dspData temp;
return get_dsp(config_filename, temp);
return get_dsp(config_filename, temp, options);
}

std::unique_ptr<DSP> get_dsp(const nlohmann::json& config)
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, DspLoadOptions options)
{
dspData temp;
return get_dsp(config, temp);
return get_dsp(config, temp, options);
}

std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig)
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig,
DspLoadOptions options)
{
if (!std::filesystem::exists(config_filename))
throw std::runtime_error("Config file doesn't exist!\n");
std::ifstream i(config_filename);
nlohmann::json j;
i >> j;
get_dsp(j, returnedConfig);
populate_dsp_data(j, returnedConfig);

/*Copy to a new dsp_config object for get_dsp below,
since not sure if weights actually get modified as being non-const references on some
model constructors inside get_dsp(dsp_config& conf).
We need to return unmodified version of dsp_config via returnedConfig.*/
dspData conf = returnedConfig;

return get_dsp(conf);
return get_dsp(conf, options);
}

std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, dspData& returnedConfig)
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, dspData& returnedConfig, DspLoadOptions options)
{
verify_config_version(config["version"].get<std::string>());

auto architecture = config["architecture"];
nlohmann::json config_json = config["config"];
std::vector<float> weights = GetWeights(config);

// Assign values to returnedConfig
returnedConfig.version = config["version"].get<std::string>();
returnedConfig.architecture = config["architecture"].get<std::string>();
returnedConfig.config = config_json;
returnedConfig.metadata = config.value("metadata", nlohmann::json());
returnedConfig.weights = weights;
returnedConfig.expected_sample_rate = nam::get_sample_rate_from_nam_file(config);
populate_dsp_data(config, returnedConfig);

/*Copy to a new dsp_config object for get_dsp below,
since not sure if weights actually get modified as being non-const references on some
model constructors inside get_dsp(dsp_config& conf).
We need to return unmodified version of dsp_config via returnedConfig.*/
dspData conf = returnedConfig;

return get_dsp(conf);
return get_dsp(conf, options);
}

// =============================================================================
Expand Down Expand Up @@ -224,17 +228,17 @@ std::unique_ptr<DSP> create_dsp(std::unique_ptr<ModelConfig> config, std::vector
{
auto out = config->create(std::move(weights), metadata.sample_rate);
apply_metadata(*out, metadata);
// "pre-warm" the model to settle initial conditions
// Can this be removed now that it's part of Reset()?
out->prewarm();
return out;
}

// =============================================================================
// get_dsp(dspData&) — now uses unified path
// =============================================================================

std::unique_ptr<DSP> get_dsp(dspData& conf)
namespace
{

std::unique_ptr<DSP> get_dsp_with_current_prewarm_default(dspData& conf)
{
verify_config_version(conf.version);

Expand All @@ -259,6 +263,20 @@ std::unique_ptr<DSP> get_dsp(dspData& conf)
return create_dsp(std::move(model_config), std::move(conf.weights), metadata);
}

} // anonymous namespace

std::unique_ptr<DSP> get_dsp(dspData& conf, DspLoadOptions options)
{
if (!options.prewarm.has_value())
return get_dsp_with_current_prewarm_default(conf);

ScopedPrewarmOnResetDefault scoped_prewarm_default(*options.prewarm);
auto dsp = get_dsp_with_current_prewarm_default(conf);
if (dsp != nullptr)
dsp->SetPrewarmOnReset(scoped_prewarm_default.PreviousPrewarmOnReset());
return dsp;
}

double get_sample_rate_from_nam_file(const nlohmann::json& j)
{
if (j.find("sample_rate") != j.end())
Expand Down
29 changes: 24 additions & 5 deletions NAM/get_dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <fstream>
#include <memory>
#include <optional>
#include <vector>

#include "dsp.h"
Expand Down Expand Up @@ -65,34 +66,52 @@ void verify_config_version(const std::string versionStr);
const std::string LATEST_FULLY_SUPPORTED_NAM_FILE_VERSION = "0.7.0";
const std::string EARLIEST_SUPPORTED_NAM_FILE_VERSION = "0.5.0";

/// \brief Options that control DSP loading behavior
struct DspLoadOptions
{
/// \brief Whether to override the current prewarm-on-reset context during loading
///
/// std::nullopt leaves the current thread-local context unchanged. Set this to false to avoid expensive prewarm work
/// during get_dsp(), or true to force prewarm during get_dsp(). When an override is provided, the returned model is
/// restored to the caller's previous prewarm-on-reset default before get_dsp() returns.
std::optional<bool> prewarm = std::nullopt;
};

/// \brief Get NAM from a .nam file at the provided location
/// \param config_filename Path to the .nam model file
/// \param options Loading options
/// \return Unique pointer to a DSP object
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename);
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, DspLoadOptions options = DspLoadOptions());

/// \brief Get NAM from a provided configuration struct
/// \param conf DSP data structure containing model configuration and weights
/// \param options Loading options
/// \return Unique pointer to a DSP object
std::unique_ptr<DSP> get_dsp(dspData& conf);
std::unique_ptr<DSP> get_dsp(dspData& conf, DspLoadOptions options = DspLoadOptions());

/// \brief Get NAM from a .nam file and store its configuration
///
/// Creates an instance of DSP and also returns a dspData struct that holds the data of the model.
/// \param config_filename Path to the .nam model file
/// \param returnedConfig Output parameter that will be filled with the model data
/// \param options Loading options
/// \return Unique pointer to a DSP object
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig);
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig,
DspLoadOptions options = DspLoadOptions());

/// \brief Get NAM from a provided configuration JSON object
/// \param config JSON configuration object
/// \param returnedConfig Output parameter that will be filled with the model data
/// \param options Loading options
/// \return Unique pointer to a DSP object
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, dspData& returnedConfig);
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, dspData& returnedConfig,
DspLoadOptions options = DspLoadOptions());

/// \brief Get NAM from a provided configuration JSON object (convenience overload)
/// \param config JSON configuration object
/// \param options Loading options
/// \return Unique pointer to a DSP object
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config);
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, DspLoadOptions options = DspLoadOptions());

/// \brief Get sample rate from a .nam file
/// \param j JSON object from the .nam file
Expand Down
2 changes: 1 addition & 1 deletion NAM/model_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ struct ConfigParserHelper
/// \brief Construct a DSP object from a typed config, weights, and metadata
///
/// This is the single construction path used by both JSON and binary loaders.
/// Handles construction, metadata application, and prewarm.
/// Handles construction and metadata application.
/// \param config Architecture-specific configuration (abstract base)
/// \param weights Model weights (taken by value to allow move for WaveNet)
/// \param metadata Model metadata (version, sample rate, loudness, levels)
Expand Down
7 changes: 7 additions & 0 deletions NAM/wavenet/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,13 @@ void nam::wavenet::WaveNet::SetMaxBufferSize(const int maxBufferSize)
}
}

void nam::wavenet::WaveNet::SetPrewarmOnReset(const bool prewarmOnReset)
{
DSP::SetPrewarmOnReset(prewarmOnReset);
if (this->_condition_dsp != nullptr)
this->_condition_dsp->SetPrewarmOnReset(prewarmOnReset);
}

void nam::wavenet::WaveNet::_process_condition(const int num_frames)
{
if (this->_condition_dsp == nullptr)
Expand Down
2 changes: 2 additions & 0 deletions NAM/wavenet/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class WaveNet : public DSP
/// \param num_frames Number of frames to process
void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override;

void SetPrewarmOnReset(const bool prewarmOnReset) override;

/// \brief Set model weights from a vector
/// \param weights Vector containing all model weights
void set_weights_(std::vector<float>& weights);
Expand Down
15 changes: 13 additions & 2 deletions NAM/wavenet/slimmable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,10 @@ std::unique_ptr<DSP> SlimmableWavenet::_create_wavenet_for_channels(const std::v
condition_dsp = get_dsp(_condition_dsp_json);

double sampleRate = _current_sample_rate > 0 ? _current_sample_rate : GetExpectedSampleRate();
return std::make_unique<wavenet::WaveNet>(_in_channels, *params_ptr, _head_scale, _with_head, std::nullopt,
std::move(weights), std::move(condition_dsp), sampleRate);
auto model = std::make_unique<wavenet::WaveNet>(_in_channels, *params_ptr, _head_scale, _with_head, std::nullopt,
std::move(weights), std::move(condition_dsp), sampleRate);
model->SetPrewarmOnReset(GetPrewarmOnReset());
return model;
}

void SlimmableWavenet::_rebuild_model(const std::vector<int>& target_channels)
Expand Down Expand Up @@ -481,6 +483,15 @@ void SlimmableWavenet::Reset(const double sampleRate, const int maxBufferSize)
pending->model->Reset(sampleRate, maxBufferSize);
}

void SlimmableWavenet::SetPrewarmOnReset(const bool prewarmOnReset)
{
DSP::SetPrewarmOnReset(prewarmOnReset);
if (_active_model)
_active_model->SetPrewarmOnReset(prewarmOnReset);
if (auto pending = _pending_load_acquire())
pending->model->SetPrewarmOnReset(prewarmOnReset);
}

void SlimmableWavenet::SetSlimmableSize(const double val)
{
const size_t num_arrays = _original_params.size();
Expand Down
Loading
Loading