Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions NAM/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ void ContainerModel::SetSlimmableSize(const double val)
{
return;
}

if (!mHaveExternalSampleRate && GetMaxBufferSize() == 0)
{
_active_index.store(active_index, std::memory_order_release);
return;
}

// Setting _active_index puts the model in the RT path, so reset before doing that.
const double sr = mHaveExternalSampleRate ? mExternalSampleRate : mExpectedSampleRate;
_submodels[active_index].model->Reset(sr, GetMaxBufferSize());
Expand Down
8 changes: 4 additions & 4 deletions NAM/dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ class DSP

/// \brief General function for resetting the DSP unit
///
/// This doesn't call prewarm(). If you want to do that, then you might want to use ResetAndPrewarm().
/// See https://github.com/sdatkinson/NeuralAmpModelerCore/issues/96 for the reasoning.
/// This calls prewarm() after applying the sample rate and max buffer size.
/// \param sampleRate Current sample rate
/// \param maxBufferSize Maximum buffer size to process
virtual void Reset(const double sampleRate, const int maxBufferSize);
Expand Down Expand Up @@ -327,8 +326,9 @@ struct dspData
nlohmann::json config; ///< Model configuration JSON
nlohmann::json metadata; ///< Model metadata JSON
std::vector<float> weights; ///< Model weights
double expected_sample_rate; ///< Expected sample rate in Hz. Most NAM models implicitly assume data at some sample
///< rate. Use -1.0 for "I don't know".
double expected_sample_rate = NAM_UNKNOWN_EXPECTED_SAMPLE_RATE; ///< Expected sample rate in Hz. Most NAM models
///< implicitly assume data at some sample rate. Use
///< -1.0 for "I don't know".
};

/// \brief Verify that the config version is supported by this plugin version
Expand Down
151 changes: 97 additions & 54 deletions NAM/get_dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,27 @@
#include "json.hpp"
#include "get_dsp.h"
#include "model_config.h"
#include "slimmable.h"

namespace nam
{
std::vector<float> GetWeights(nlohmann::json const& j);

namespace
{

struct LoadOptions
{
std::optional<double> expectedSampleRate;
std::optional<int> maxBufferSize;
std::optional<double> slimmableSize;

bool requires_initial_reset() const
{
return expectedSampleRate.has_value() || maxBufferSize.has_value() || slimmableSize.has_value();
}
};

class CoreVersionSupportChecker : public IVersionSupportChecker
{
public:
Expand Down Expand Up @@ -52,6 +67,65 @@ std::mutex& version_support_registry_mutex()
return registry_mutex;
}

dspData parse_dsp_data(const nlohmann::json& config, std::optional<double> expectedSampleRate)
{
verify_config_version(config["version"].get<std::string>());

dspData out;
out.version = config["version"].get<std::string>();
out.architecture = config["architecture"].get<std::string>();
out.config = config["config"];
out.metadata = config.value("metadata", nlohmann::json());
out.weights = GetWeights(config);
out.expected_sample_rate = expectedSampleRate.value_or(nam::get_sample_rate_from_nam_file(config));
return out;
}

void apply_initial_slimmable_size(DSP& dsp, const double slimmableSize)
{
auto* slimmable = dynamic_cast<SlimmableModel*>(&dsp);
if (slimmable == nullptr)
throw std::runtime_error("Cannot set slimmable size on a model that is not slimmable.");
slimmable->SetSlimmableSize(slimmableSize);
}

void apply_metadata(DSP& dsp, const ModelMetadata& metadata)
{
if (metadata.loudness.has_value())
dsp.SetLoudness(metadata.loudness.value());
if (metadata.input_level.has_value())
dsp.SetInputLevel(metadata.input_level.value());
if (metadata.output_level.has_value())
dsp.SetOutputLevel(metadata.output_level.value());
}

void configure_initial_state(DSP& dsp, const ModelMetadata& metadata, const LoadOptions& options)
{
if (options.slimmableSize.has_value())
apply_initial_slimmable_size(dsp, options.slimmableSize.value());

if (options.requires_initial_reset())
{
const double sampleRate = options.expectedSampleRate.value_or(metadata.sample_rate);
const int maxBufferSize = options.maxBufferSize.value_or(NAM_DEFAULT_MAX_BUFFER_SIZE);
dsp.Reset(sampleRate, maxBufferSize);
}
else
{
// Preserve the historical load behavior when no load-time configuration is requested.
dsp.prewarm();
}
}

std::unique_ptr<DSP> create_dsp_with_options(std::unique_ptr<ModelConfig> config, std::vector<float> weights,
const ModelMetadata& metadata, const LoadOptions& options)
{
auto out = config->create(std::move(weights), metadata.sample_rate);
apply_metadata(*out, metadata);
configure_initial_state(*out, metadata, options);
return out;
}

} // namespace

Version ParseVersion(const std::string& versionStr)
Expand Down Expand Up @@ -139,59 +213,45 @@ std::vector<float> GetWeights(nlohmann::json const& j)
throw std::runtime_error("Corrupted model file is missing weights.");
}

std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename)
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, std::optional<double> expectedSampleRate,
std::optional<int> maxBufferSize, std::optional<double> slimmableSize)
{
dspData temp;
return get_dsp(config_filename, temp);
return get_dsp(config_filename, temp, expectedSampleRate, maxBufferSize, slimmableSize);
}

std::unique_ptr<DSP> get_dsp(const nlohmann::json& config)
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, std::optional<double> expectedSampleRate,
std::optional<int> maxBufferSize, std::optional<double> slimmableSize)
{
dspData temp;
return get_dsp(config, temp);
return get_dsp(config, temp, expectedSampleRate, maxBufferSize, slimmableSize);
}

std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig)
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig,
std::optional<double> expectedSampleRate, std::optional<int> maxBufferSize,
std::optional<double> slimmableSize)
{
if (!std::filesystem::exists(config_filename))
throw std::runtime_error("Config file doesn't exist!\n");
std::ifstream i(config_filename);
nlohmann::json j;
i >> j;
get_dsp(j, returnedConfig);

/*Copy to a new dsp_config object for get_dsp below,
since not sure if weights actually get modified as being non-const references on some
model constructors inside get_dsp(dsp_config& conf).
We need to return unmodified version of dsp_config via returnedConfig.*/
dspData conf = returnedConfig;

return get_dsp(conf);
return get_dsp(j, returnedConfig, expectedSampleRate, maxBufferSize, slimmableSize);
}

std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, dspData& returnedConfig)
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, dspData& returnedConfig,
std::optional<double> expectedSampleRate, std::optional<int> maxBufferSize,
std::optional<double> slimmableSize)
{
verify_config_version(config["version"].get<std::string>());

auto architecture = config["architecture"];
nlohmann::json config_json = config["config"];
std::vector<float> weights = GetWeights(config);

// Assign values to returnedConfig
returnedConfig.version = config["version"].get<std::string>();
returnedConfig.architecture = config["architecture"].get<std::string>();
returnedConfig.config = config_json;
returnedConfig.metadata = config.value("metadata", nlohmann::json());
returnedConfig.weights = weights;
returnedConfig.expected_sample_rate = nam::get_sample_rate_from_nam_file(config);
returnedConfig = parse_dsp_data(config, expectedSampleRate);

/*Copy to a new dsp_config object for get_dsp below,
since not sure if weights actually get modified as being non-const references on some
model constructors inside get_dsp(dsp_config& conf).
We need to return unmodified version of dsp_config via returnedConfig.*/
dspData conf = returnedConfig;

return get_dsp(conf);
return get_dsp(conf, expectedSampleRate, maxBufferSize, slimmableSize);
}

// =============================================================================
Expand All @@ -204,44 +264,27 @@ std::unique_ptr<ModelConfig> parse_model_config_json(const std::string& architec
return ConfigParserRegistry::instance().parse(architecture, config, sample_rate);
}

namespace
{

void apply_metadata(DSP& dsp, const ModelMetadata& metadata)
{
if (metadata.loudness.has_value())
dsp.SetLoudness(metadata.loudness.value());
if (metadata.input_level.has_value())
dsp.SetInputLevel(metadata.input_level.value());
if (metadata.output_level.has_value())
dsp.SetOutputLevel(metadata.output_level.value());
}

} // anonymous namespace

std::unique_ptr<DSP> create_dsp(std::unique_ptr<ModelConfig> config, std::vector<float> weights,
const ModelMetadata& metadata)
{
auto out = config->create(std::move(weights), metadata.sample_rate);
apply_metadata(*out, metadata);
// "pre-warm" the model to settle initial conditions
// Can this be removed now that it's part of Reset()?
out->prewarm();
return out;
return create_dsp_with_options(std::move(config), std::move(weights), metadata, LoadOptions{});
}

// =============================================================================
// get_dsp(dspData&) — now uses unified path
// =============================================================================

std::unique_ptr<DSP> get_dsp(dspData& conf)
std::unique_ptr<DSP> get_dsp(dspData& conf, std::optional<double> expectedSampleRate, std::optional<int> maxBufferSize,
std::optional<double> slimmableSize)
{
verify_config_version(conf.version);
const double effectiveSampleRate = expectedSampleRate.value_or(conf.expected_sample_rate);
const LoadOptions options{expectedSampleRate, maxBufferSize, slimmableSize};

// Extract metadata from JSON
ModelMetadata metadata;
metadata.version = conf.version;
metadata.sample_rate = conf.expected_sample_rate;
metadata.sample_rate = effectiveSampleRate;

if (!conf.metadata.is_null())
{
Expand All @@ -255,8 +298,8 @@ std::unique_ptr<DSP> get_dsp(dspData& conf)
metadata.output_level = extract("output_level_dbu");
}

auto model_config = ConfigParserRegistry::instance().parse(conf.architecture, conf.config, conf.expected_sample_rate);
return create_dsp(std::move(model_config), std::move(conf.weights), metadata);
auto model_config = ConfigParserRegistry::instance().parse(conf.architecture, conf.config, effectiveSampleRate);
return create_dsp_with_options(std::move(model_config), std::move(conf.weights), metadata, options);
}

double get_sample_rate_from_nam_file(const nlohmann::json& j)
Expand Down
39 changes: 34 additions & 5 deletions NAM/get_dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <fstream>
#include <memory>
#include <optional>
#include <vector>

#include "dsp.h"
Expand Down Expand Up @@ -67,32 +68,60 @@ const std::string EARLIEST_SUPPORTED_NAM_FILE_VERSION = "0.5.0";

/// \brief Get NAM from a .nam file at the provided location
/// \param config_filename Path to the .nam model file
/// \param expectedSampleRate Expected sample rate to configure the model with; std::nullopt uses the file default
/// \param maxBufferSize Maximum buffer size to configure the model with; std::nullopt uses NAM_DEFAULT_MAX_BUFFER_SIZE
/// \param slimmableSize Slimmable size to configure the model with; std::nullopt uses the model default
/// \return Unique pointer to a DSP object
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename);
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename,
std::optional<double> expectedSampleRate = std::nullopt,
std::optional<int> maxBufferSize = std::nullopt,
std::optional<double> slimmableSize = std::nullopt);

/// \brief Get NAM from a provided configuration struct
/// \param conf DSP data structure containing model configuration and weights
/// \param expectedSampleRate Expected sample rate to configure the model with; std::nullopt uses the config default
/// \param maxBufferSize Maximum buffer size to configure the model with; std::nullopt uses NAM_DEFAULT_MAX_BUFFER_SIZE
/// \param slimmableSize Slimmable size to configure the model with; std::nullopt uses the model default
/// \return Unique pointer to a DSP object
std::unique_ptr<DSP> get_dsp(dspData& conf);
std::unique_ptr<DSP> get_dsp(dspData& conf, std::optional<double> expectedSampleRate = std::nullopt,
std::optional<int> maxBufferSize = std::nullopt,
std::optional<double> slimmableSize = std::nullopt);

/// \brief Get NAM from a .nam file and store its configuration
///
/// Creates an instance of DSP and also returns a dspData struct that holds the data of the model.
/// \param config_filename Path to the .nam model file
/// \param returnedConfig Output parameter that will be filled with the model data
/// \param expectedSampleRate Expected sample rate to configure the model with; std::nullopt uses the file default
/// \param maxBufferSize Maximum buffer size to configure the model with; std::nullopt uses NAM_DEFAULT_MAX_BUFFER_SIZE
/// \param slimmableSize Slimmable size to configure the model with; std::nullopt uses the model default
/// \return Unique pointer to a DSP object
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig);
std::unique_ptr<DSP> get_dsp(const std::filesystem::path config_filename, dspData& returnedConfig,
std::optional<double> expectedSampleRate = std::nullopt,
std::optional<int> maxBufferSize = std::nullopt,
std::optional<double> slimmableSize = std::nullopt);

/// \brief Get NAM from a provided configuration JSON object
/// \param config JSON configuration object
/// \param returnedConfig Output parameter that will be filled with the model data
/// \param expectedSampleRate Expected sample rate to configure the model with; std::nullopt uses the file default
/// \param maxBufferSize Maximum buffer size to configure the model with; std::nullopt uses NAM_DEFAULT_MAX_BUFFER_SIZE
/// \param slimmableSize Slimmable size to configure the model with; std::nullopt uses the model default
/// \return Unique pointer to a DSP object
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, dspData& returnedConfig);
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, dspData& returnedConfig,
std::optional<double> expectedSampleRate = std::nullopt,
std::optional<int> maxBufferSize = std::nullopt,
std::optional<double> slimmableSize = std::nullopt);

/// \brief Get NAM from a provided configuration JSON object (convenience overload)
/// \param config JSON configuration object
/// \param expectedSampleRate Expected sample rate to configure the model with; std::nullopt uses the file default
/// \param maxBufferSize Maximum buffer size to configure the model with; std::nullopt uses NAM_DEFAULT_MAX_BUFFER_SIZE
/// \param slimmableSize Slimmable size to configure the model with; std::nullopt uses the model default
/// \return Unique pointer to a DSP object
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config);
std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, std::optional<double> expectedSampleRate = std::nullopt,
std::optional<int> maxBufferSize = std::nullopt,
std::optional<double> slimmableSize = std::nullopt);

/// \brief Get sample rate from a .nam file
/// \param j JSON object from the .nam file
Expand Down
6 changes: 6 additions & 0 deletions NAM/wavenet/slimmable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,12 @@ void SlimmableWavenet::SetSlimmableSize(const double val)
target[i] = ratio_to_channels(val, allowed);
}

if (_current_buffer_size <= 0)
{
_rebuild_model(target);
return;
}

_stage_rebuild_model(target);
}

Expand Down
3 changes: 3 additions & 0 deletions tools/run_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ int main()
test_get_dsp::test_gets_output_level();
test_get_dsp::test_null_input_level();
test_get_dsp::test_null_output_level();
test_get_dsp::test_get_dsp_without_load_options_preserves_prewarm_only();
test_get_dsp::test_get_dsp_applies_load_options();
test_get_dsp::test_get_dsp_applies_slimmable_option_before_reset_with_defaults();
test_get_dsp::test_version_patch_one_beyond_supported();
test_get_dsp::test_version_minor_one_beyond_supported();
test_get_dsp::test_version_too_early();
Expand Down
Loading
Loading