diff --git a/NAM/container.cpp b/NAM/container.cpp index ee7d9f1..4c527af 100644 --- a/NAM/container.cpp +++ b/NAM/container.cpp @@ -108,6 +108,13 @@ void ContainerModel::SetSlimmableSize(const double val) _active_index.store(active_index, std::memory_order_release); } +int ContainerModel::GetPrewarmSamples() +{ + const size_t active_index = _active_index.load(std::memory_order_acquire); + + return _submodels[active_index].model->GetPrewarmSamples(); +} + // ============================================================================= // Config / factory // ============================================================================= diff --git a/NAM/container.h b/NAM/container.h index dccc914..fc998ff 100644 --- a/NAM/container.h +++ b/NAM/container.h @@ -38,9 +38,7 @@ class ContainerModel : public DSP, public SlimmableModel void prewarm() override; void Reset(const double sampleRate, const int maxBufferSize) override; void SetSlimmableSize(const double val) override; - -protected: - int PrewarmSamples() override { return 0; } + int GetPrewarmSamples() override; private: std::vector _submodels; diff --git a/NAM/convnet.h b/NAM/convnet.h index c1d7c1a..394389e 100644 --- a/NAM/convnet.h +++ b/NAM/convnet.h @@ -151,6 +151,8 @@ class ConvNet : public Buffer /// \param maxBufferSize Maximum number of frames to process in a single call void SetMaxBufferSize(const int maxBufferSize) override; + int GetPrewarmSamples() override { return mPrewarmSamples; }; + protected: std::vector _blocks; std::vector _block_vals; @@ -162,7 +164,6 @@ class ConvNet : public Buffer void _rewind_buffers_() override; int mPrewarmSamples = 0; // Pre-compute during initialization - int PrewarmSamples() override { return mPrewarmSamples; }; }; /// \brief Configuration for a ConvNet model diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index a4040c3..424a50f 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -31,7 +31,7 @@ void nam::DSP::prewarm() { SetMaxBufferSize(NAM_DEFAULT_MAX_BUFFER_SIZE); } - const int prewarmSamples = PrewarmSamples(); + const int prewarmSamples = GetPrewarmSamples(); if (prewarmSamples == 0) return; diff --git a/NAM/dsp.h b/NAM/dsp.h index c714a19..f6a00ec 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -131,6 +131,12 @@ class DSP /// \return true if output level is known, false otherwise bool HasOutputLevel(); + /// \brief Get how many samples should be processed for the model to be considered "warmed up" + /// + /// Override this in subclasses to specify prewarm requirements. + /// \return Number of samples needed for prewarm + virtual int GetPrewarmSamples() { return 0; }; + /// \brief General function for resetting the DSP unit /// /// This doesn't call prewarm(). If you want to do that, then you might want to use ResetAndPrewarm(). @@ -178,12 +184,6 @@ class DSP // The largest buffer I expect to be told to process: int mMaxBufferSize = 0; - /// \brief Get how many samples should be processed for the model to be considered "warmed up" - /// - /// Override this in subclasses to specify prewarm requirements. - /// \return Number of samples needed for prewarm - virtual int PrewarmSamples() { return 0; }; - /// \brief Set the maximum buffer size /// \param maxBufferSize Maximum number of frames to process in a single call virtual void SetMaxBufferSize(const int maxBufferSize); diff --git a/NAM/lstm.cpp b/NAM/lstm.cpp index 9169a7e..3213069 100644 --- a/NAM/lstm.cpp +++ b/NAM/lstm.cpp @@ -122,8 +122,9 @@ void nam::lstm::LSTM::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int } } -int nam::lstm::LSTM::PrewarmSamples() +int nam::lstm::LSTM::GetPrewarmSamples() { + // Hacky, but a half-second seems to work for most models. int result = (int)(0.5 * mExpectedSampleRate); // If the expected sample rate wasn't provided, it'll be -1. // Make sure something still happens. diff --git a/NAM/lstm.h b/NAM/lstm.h index 88b7527..607c7d5 100644 --- a/NAM/lstm.h +++ b/NAM/lstm.h @@ -78,9 +78,7 @@ class LSTM : public DSP /// \param num_frames Number of frames to process void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override; -protected: - // Hacky, but a half-second seems to work for most models. - int PrewarmSamples() override; + int GetPrewarmSamples() override; Eigen::MatrixXf _head_weight; // (out_channels x hidden_size) Eigen::VectorXf _head_bias; // (out_channels) diff --git a/NAM/wavenet/a2_fast.cpp b/NAM/wavenet/a2_fast.cpp index 62d0206..33af67a 100644 --- a/NAM/wavenet/a2_fast.cpp +++ b/NAM/wavenet/a2_fast.cpp @@ -66,10 +66,10 @@ class A2FastModel : public DSP ~A2FastModel() override = default; void process(NAM_SAMPLE** input, NAM_SAMPLE** output, int num_frames) override; + int GetPrewarmSamples() override { return _prewarm_samples; } protected: void SetMaxBufferSize(int maxBufferSize) override; - int PrewarmSamples() override { return _prewarm_samples; } private: struct Layer diff --git a/NAM/wavenet/model.cpp b/NAM/wavenet/model.cpp index eaf74ad..3c7fead 100644 --- a/NAM/wavenet/model.cpp +++ b/NAM/wavenet/model.cpp @@ -613,7 +613,7 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels, this->set_weights_(weights); // Finally, figure out how much pre-warming is needed for this model. - mPrewarmSamples = this->_condition_dsp != nullptr ? this->_condition_dsp->PrewarmSamples() : 1; + mPrewarmSamples = this->_condition_dsp != nullptr ? this->_condition_dsp->GetPrewarmSamples() : 1; for (size_t i = 0; i < this->_layer_arrays.size(); i++) mPrewarmSamples += this->_layer_arrays[i].get_receptive_field(); if (this->_post_stack_head != nullptr) diff --git a/NAM/wavenet/model.h b/NAM/wavenet/model.h index 968baf8..f0c41dc 100644 --- a/NAM/wavenet/model.h +++ b/NAM/wavenet/model.h @@ -66,6 +66,8 @@ class WaveNet : public DSP /// \param weights Iterator to the weights vector. Will be advanced as weights are consumed. void set_weights_(std::vector::iterator& weights); + int GetPrewarmSamples() override { return mPrewarmSamples; }; + protected: // Element-wise arrays: Eigen::MatrixXf _condition_input; @@ -111,7 +113,7 @@ class WaveNet : public DSP Eigen::MatrixXf _scaled_head_scratch; int mPrewarmSamples = 0; // Pre-compute during initialization - int PrewarmSamples() override { return mPrewarmSamples; }; + }; /// \brief Configuration for a WaveNet model diff --git a/NAM/wavenet/slimmable.h b/NAM/wavenet/slimmable.h index f0687c0..1c07e5d 100644 --- a/NAM/wavenet/slimmable.h +++ b/NAM/wavenet/slimmable.h @@ -51,7 +51,7 @@ class SlimmableWavenet : public DSP, public SlimmableModel void SetSlimmableSize(const double val) override; protected: - int PrewarmSamples() override { return 0; } + int GetPrewarmSamples() override { return 0; } private: std::vector _original_params;