Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions NAM/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ void ContainerModel::SetSlimmableSize(const double val)
_active_index.store(active_index, std::memory_order_release);
}

int ContainerModel::GetPrewarmSamples()
{
const size_t active_index = _active_index.load(std::memory_order_acquire);

return _submodels[active_index].model->GetPrewarmSamples();
}

// =============================================================================
// Config / factory
// =============================================================================
Expand Down
4 changes: 1 addition & 3 deletions NAM/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ class ContainerModel : public DSP, public SlimmableModel
void prewarm() override;
void Reset(const double sampleRate, const int maxBufferSize) override;
void SetSlimmableSize(const double val) override;

protected:
int PrewarmSamples() override { return 0; }
int GetPrewarmSamples() override;

private:
std::vector<Submodel> _submodels;
Expand Down
3 changes: 2 additions & 1 deletion NAM/convnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ class ConvNet : public Buffer
/// \param maxBufferSize Maximum number of frames to process in a single call
void SetMaxBufferSize(const int maxBufferSize) override;

int GetPrewarmSamples() override { return mPrewarmSamples; };

protected:
std::vector<ConvNetBlock> _blocks;
std::vector<Eigen::MatrixXf> _block_vals;
Expand All @@ -162,7 +164,6 @@ class ConvNet : public Buffer
void _rewind_buffers_() override;

int mPrewarmSamples = 0; // Pre-compute during initialization
int PrewarmSamples() override { return mPrewarmSamples; };
};

/// \brief Configuration for a ConvNet model
Expand Down
2 changes: 1 addition & 1 deletion NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void nam::DSP::prewarm()
{
SetMaxBufferSize(NAM_DEFAULT_MAX_BUFFER_SIZE);
}
const int prewarmSamples = PrewarmSamples();
const int prewarmSamples = GetPrewarmSamples();
if (prewarmSamples == 0)
return;

Expand Down
12 changes: 6 additions & 6 deletions NAM/dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ class DSP
/// \return true if output level is known, false otherwise
bool HasOutputLevel();

/// \brief Get how many samples should be processed for the model to be considered "warmed up"
///
/// Override this in subclasses to specify prewarm requirements.
/// \return Number of samples needed for prewarm
virtual int GetPrewarmSamples() { return 0; };

/// \brief General function for resetting the DSP unit
///
/// This doesn't call prewarm(). If you want to do that, then you might want to use ResetAndPrewarm().
Expand Down Expand Up @@ -178,12 +184,6 @@ class DSP
// The largest buffer I expect to be told to process:
int mMaxBufferSize = 0;

/// \brief Get how many samples should be processed for the model to be considered "warmed up"
///
/// Override this in subclasses to specify prewarm requirements.
/// \return Number of samples needed for prewarm
virtual int PrewarmSamples() { return 0; };

/// \brief Set the maximum buffer size
/// \param maxBufferSize Maximum number of frames to process in a single call
virtual void SetMaxBufferSize(const int maxBufferSize);
Expand Down
3 changes: 2 additions & 1 deletion NAM/lstm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ void nam::lstm::LSTM::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int
}
}

int nam::lstm::LSTM::PrewarmSamples()
int nam::lstm::LSTM::GetPrewarmSamples()
{
// Hacky, but a half-second seems to work for most models.
int result = (int)(0.5 * mExpectedSampleRate);
// If the expected sample rate wasn't provided, it'll be -1.
// Make sure something still happens.
Expand Down
4 changes: 1 addition & 3 deletions NAM/lstm.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ class LSTM : public DSP
/// \param num_frames Number of frames to process
void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override;

protected:
// Hacky, but a half-second seems to work for most models.
int PrewarmSamples() override;
int GetPrewarmSamples() override;

Eigen::MatrixXf _head_weight; // (out_channels x hidden_size)
Eigen::VectorXf _head_bias; // (out_channels)
Expand Down
2 changes: 1 addition & 1 deletion NAM/wavenet/a2_fast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ class A2FastModel : public DSP
~A2FastModel() override = default;

void process(NAM_SAMPLE** input, NAM_SAMPLE** output, int num_frames) override;
int GetPrewarmSamples() override { return _prewarm_samples; }

protected:
void SetMaxBufferSize(int maxBufferSize) override;
int PrewarmSamples() override { return _prewarm_samples; }

private:
struct Layer
Expand Down
2 changes: 1 addition & 1 deletion NAM/wavenet/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ nam::wavenet::WaveNet::WaveNet(const int in_channels,
this->set_weights_(weights);

// Finally, figure out how much pre-warming is needed for this model.
mPrewarmSamples = this->_condition_dsp != nullptr ? this->_condition_dsp->PrewarmSamples() : 1;
mPrewarmSamples = this->_condition_dsp != nullptr ? this->_condition_dsp->GetPrewarmSamples() : 1;
for (size_t i = 0; i < this->_layer_arrays.size(); i++)
mPrewarmSamples += this->_layer_arrays[i].get_receptive_field();
if (this->_post_stack_head != nullptr)
Expand Down
4 changes: 3 additions & 1 deletion NAM/wavenet/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class WaveNet : public DSP
/// \param weights Iterator to the weights vector. Will be advanced as weights are consumed.
void set_weights_(std::vector<float>::iterator& weights);

int GetPrewarmSamples() override { return mPrewarmSamples; };

protected:
// Element-wise arrays:
Eigen::MatrixXf _condition_input;
Expand Down Expand Up @@ -111,7 +113,7 @@ class WaveNet : public DSP
Eigen::MatrixXf _scaled_head_scratch;

int mPrewarmSamples = 0; // Pre-compute during initialization
int PrewarmSamples() override { return mPrewarmSamples; };

};

/// \brief Configuration for a WaveNet model
Expand Down
2 changes: 1 addition & 1 deletion NAM/wavenet/slimmable.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class SlimmableWavenet : public DSP, public SlimmableModel
void SetSlimmableSize(const double val) override;

protected:
int PrewarmSamples() override { return 0; }
int GetPrewarmSamples() override { return 0; }

private:
std::vector<wavenet::LayerArrayParams> _original_params;
Expand Down
Loading