Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion NAM/dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ namespace wavenet
class WaveNet;
} // namespace wavenet


/// \brief Base class for all DSP models
///
/// DSP provides the common interface for all neural network-based audio processing models.
Expand Down
255 changes: 253 additions & 2 deletions NAM/linear.cpp
Original file line number Diff line number Diff line change
@@ -1,27 +1,171 @@
#include "linear.h"

#include <algorithm>
#include <cctype>
#include <complex>
#include <stdexcept>

#include "registry.h"

#include <unsupported/Eigen/FFT>

namespace
{
constexpr int _LINEAR_AUTO_DIRECT_MAX_TAPS = 256;
constexpr int _LINEAR_FFT_SMALL_BLOCK_SIZE = 256;
constexpr int _LINEAR_FFT_MEDIUM_BLOCK_SIZE = 512;
constexpr int _LINEAR_FFT_LARGE_BLOCK_SIZE = 1024;

int _ceil_div(const int numerator, const int denominator)
{
return (numerator + denominator - 1) / denominator;
}

int _choose_linear_fft_block_size(const int receptive_field)
{
if (receptive_field <= 2048)
return _LINEAR_FFT_SMALL_BLOCK_SIZE;
if (receptive_field <= 8192)
return _LINEAR_FFT_MEDIUM_BLOCK_SIZE;
return _LINEAR_FFT_LARGE_BLOCK_SIZE;
}

} // namespace

struct nam::LinearFFTState
{
using Complex = std::complex<float>;

struct ChannelState
{
std::vector<float> input_time;
std::vector<std::vector<Complex>> input_spectra;
std::vector<float> output_ring;
int input_pos = 0;
int spectrum_write_index = 0;
};

Eigen::FFT<float> fft;
int block_size = 0;
int fft_size = 0;
int direct_taps = 0;
int num_partitions = 0;
int output_ring_size = 0;
long long sample_index = 0;
std::vector<std::vector<Complex>> kernel_spectra;
std::vector<ChannelState> channels;
std::vector<Complex> accumulator;
std::vector<float> ifft_time;
};

nam::Linear::Linear(const int in_channels, const int out_channels, const int receptive_field, const bool _bias,
const std::vector<float>& weights, const double expected_sample_rate)
const std::vector<float>& weights, const double expected_sample_rate,
const LinearImplementation implementation)
: nam::Buffer(in_channels, out_channels, receptive_field, expected_sample_rate)
, _requested_implementation(implementation)
, _active_implementation(LinearImplementation::Direct)
{
if ((int)weights.size() != (receptive_field + (_bias ? 1 : 0)))
throw std::runtime_error(
"Params vector does not match expected size based "
"on architecture parameters");

this->_impulse_response.assign(weights.begin(), weights.begin() + receptive_field);
this->_weight.resize(this->_receptive_field);
// Pass in in reverse order so that dot products work out of the box.
for (int i = 0; i < this->_receptive_field; i++)
this->_weight(i) = weights[receptive_field - 1 - i];
this->_bias = _bias ? weights[receptive_field] : (float)0.0;

this->_configure_implementation();
}

nam::Linear::~Linear() = default;

void nam::Linear::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames)
{
if (this->_active_implementation == LinearImplementation::FFT)
this->_process_fft(input, output, num_frames);
else
this->_process_direct(input, output, num_frames);
}

void nam::Linear::SetMaxBufferSize(const int maxBufferSize)
{
nam::Buffer::SetMaxBufferSize(maxBufferSize);
this->_configure_implementation();
}

void nam::Linear::_configure_implementation()
{
if (this->_requested_implementation == LinearImplementation::Direct)
this->_active_implementation = LinearImplementation::Direct;
else if (this->_requested_implementation == LinearImplementation::FFT)
this->_active_implementation = LinearImplementation::FFT;
else
this->_active_implementation =
this->_receptive_field <= _LINEAR_AUTO_DIRECT_MAX_TAPS ? LinearImplementation::Direct : LinearImplementation::FFT;

if (this->_active_implementation == LinearImplementation::FFT)
this->_configure_fft_state();
else
this->_fft_state.reset();
}

void nam::Linear::_configure_fft_state()
{
this->_fft_state = std::make_unique<LinearFFTState>();
auto& state = *this->_fft_state;

state.block_size = _choose_linear_fft_block_size(this->_receptive_field);
state.fft_size = 2 * state.block_size;
state.direct_taps = std::min(this->_receptive_field, state.block_size);
state.num_partitions = this->_receptive_field > state.direct_taps
? _ceil_div(this->_receptive_field - state.direct_taps, state.block_size)
: 0;
state.output_ring_size = 4 * state.block_size;
state.sample_index = 0;

this->_fft_direct_weight.resize(state.direct_taps);
for (int i = 0; i < state.direct_taps; i++)
this->_fft_direct_weight(i) = this->_impulse_response[state.direct_taps - 1 - i];

state.kernel_spectra.assign(state.num_partitions, std::vector<LinearFFTState::Complex>(state.fft_size));
std::vector<float> kernel_time(state.fft_size, 0.0f);
for (int partition = 0; partition < state.num_partitions; partition++)
{
std::fill(kernel_time.begin(), kernel_time.end(), 0.0f);
const int start = state.direct_taps + partition * state.block_size;
const int partition_size = std::min(state.block_size, this->_receptive_field - start);
for (int i = 0; i < partition_size; i++)
kernel_time[i] = this->_impulse_response[start + i];
state.fft.fwd(state.kernel_spectra[partition].data(), kernel_time.data(), state.fft_size);
}

const int channels_to_process = std::min(NumInputChannels(), NumOutputChannels());
state.channels.resize(channels_to_process);
for (auto& channel : state.channels)
{
channel.input_time.assign(state.fft_size, 0.0f);
channel.input_spectra.assign(
state.num_partitions, std::vector<LinearFFTState::Complex>(state.fft_size, LinearFFTState::Complex{}));
channel.output_ring.assign(state.output_ring_size, 0.0f);
channel.input_pos = 0;
channel.spectrum_write_index = 0;
}
state.accumulator.assign(state.fft_size, LinearFFTState::Complex{});
state.ifft_time.assign(state.fft_size, 0.0f);

if (state.num_partitions > 0)
{
std::vector<LinearFFTState::Complex> warm_spectrum(state.fft_size);
std::vector<float> warm_time(state.fft_size, 0.0f);
state.fft.fwd(warm_spectrum.data(), warm_time.data(), state.fft_size);
state.fft.inv(warm_time.data(), warm_spectrum.data(), state.fft_size);
}
}

void nam::Linear::_process_direct(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames)
{
this->nam::Buffer::_update_buffers_(input, num_frames);

Expand Down Expand Up @@ -54,6 +198,111 @@ void nam::Linear::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num
nam::Buffer::_advance_input_buffer_(num_frames);
}

void nam::Linear::_process_fft(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames)
{
this->nam::Buffer::_update_buffers_(input, num_frames);

const int in_channels = NumInputChannels();
const int out_channels = NumOutputChannels();
const int channels_to_process = std::min(in_channels, out_channels);
auto& state = *this->_fft_state;
const int direct_taps = state.direct_taps;

for (int i = 0; i < num_frames; i++)
{
const long direct_offset = this->_input_buffer_offset - direct_taps + i + 1;
for (int ch = 0; ch < channels_to_process; ch++)
{
const int ring_index = (int)(state.sample_index % state.output_ring_size);
const float tail = state.channels[ch].output_ring[ring_index];
state.channels[ch].output_ring[ring_index] = 0.0f;

auto input_vec = Eigen::Map<const Eigen::VectorXf>(&this->_input_buffers[ch][direct_offset], direct_taps);
output[ch][i] = this->_bias + this->_fft_direct_weight.dot(input_vec) + tail;

if (state.num_partitions > 0)
{
auto& channel = state.channels[ch];
channel.input_time[channel.input_pos] = (float)input[ch][i];
channel.input_pos++;
if (channel.input_pos == state.block_size)
this->_run_fft_block(ch);
}
}

for (int ch = channels_to_process; ch < out_channels; ch++)
output[ch][i] = (NAM_SAMPLE)0.0;

state.sample_index++;
}

nam::Buffer::_advance_input_buffer_(num_frames);
}

void nam::Linear::_run_fft_block(const int channel_index)
{
auto& state = *this->_fft_state;
auto& channel = state.channels[channel_index];

auto& current_spectrum = channel.input_spectra[channel.spectrum_write_index];
state.fft.fwd(current_spectrum.data(), channel.input_time.data(), state.fft_size);

std::fill(state.accumulator.begin(), state.accumulator.end(), LinearFFTState::Complex{});
for (int partition = 0; partition < state.num_partitions; partition++)
{
int input_spectrum_index = channel.spectrum_write_index - partition;
if (input_spectrum_index < 0)
input_spectrum_index += state.num_partitions;
const auto& input_spectrum = channel.input_spectra[input_spectrum_index];
const auto& kernel_spectrum = state.kernel_spectra[partition];
for (int bin = 0; bin < state.fft_size; bin++)
state.accumulator[bin] += input_spectrum[bin] * kernel_spectrum[bin];
}

state.fft.inv(state.ifft_time.data(), state.accumulator.data(), state.fft_size);

const long long block_start = state.sample_index - state.block_size + 1;
const long long output_start = block_start + state.direct_taps;
auto& output_ring = channel.output_ring;
for (int i = 0; i < state.fft_size - 1; i++)
{
const int ring_index = (int)((output_start + i) % state.output_ring_size);
output_ring[ring_index] += state.ifft_time[i];
}

std::fill(channel.input_time.begin(), channel.input_time.begin() + state.block_size, 0.0f);
channel.input_pos = 0;
channel.spectrum_write_index++;
if (channel.spectrum_write_index == state.num_partitions)
channel.spectrum_write_index = 0;
}

nam::LinearImplementation nam::linear::parse_implementation(const std::string& implementation)
{
std::string normalized = implementation;
std::transform(
normalized.begin(), normalized.end(), normalized.begin(), [](unsigned char c) { return (char)std::tolower(c); });

if (normalized == "auto")
return LinearImplementation::Auto;
if (normalized == "direct" || normalized == "legacy" || normalized == "old")
return LinearImplementation::Direct;
if (normalized == "fft" || normalized == "partitioned_fft" || normalized == "partitioned-fft")
return LinearImplementation::FFT;
throw std::runtime_error("Unsupported Linear implementation: " + implementation);
}

std::string nam::linear::implementation_to_string(const LinearImplementation implementation)
{
switch (implementation)
{
case LinearImplementation::Auto: return "auto";
case LinearImplementation::Direct: return "direct";
case LinearImplementation::FFT: return "fft";
}
throw std::runtime_error("Unsupported Linear implementation enum");
}

nam::linear::LinearConfig nam::linear::parse_config_json(const nlohmann::json& config)
{
LinearConfig c;
Expand All @@ -62,12 +311,14 @@ nam::linear::LinearConfig nam::linear::parse_config_json(const nlohmann::json& c
// Default to 1 channel in/out for backward compatibility
c.in_channels = config.value("in_channels", 1);
c.out_channels = config.value("out_channels", 1);
c.implementation = parse_implementation(config.value("implementation", "auto"));
return c;
}

std::unique_ptr<nam::DSP> nam::linear::LinearConfig::create(std::vector<float> weights, double sampleRate)
{
return std::make_unique<nam::Linear>(in_channels, out_channels, receptive_field, bias, weights, sampleRate);
return std::make_unique<nam::Linear>(
in_channels, out_channels, receptive_field, bias, weights, sampleRate, implementation);
}

std::unique_ptr<nam::ModelConfig> nam::linear::create_config(const nlohmann::json& config, double sampleRate)
Expand Down
42 changes: 41 additions & 1 deletion NAM/linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
namespace nam
{

struct LinearFFTState;

/// \brief Selects the convolution engine used by Linear models.
enum class LinearImplementation
{
Auto, ///< Choose direct or FFT convolution from the impulse-response length.
Direct, ///< Legacy per-sample direct convolution.
FFT ///< Zero-latency partitioned FFT convolution.
};

/// \brief Basic linear model
///
/// Implements a simple linear convolution, (i.e. an impulse response).
Expand All @@ -18,18 +28,41 @@ class Linear : public Buffer
/// \param _bias Whether to use bias
/// \param weights Model weights (impulse response coefficients)
/// \param expected_sample_rate Expected sample rate in Hz (-1.0 if unknown)
/// \param implementation Convolution implementation to use
Linear(const int in_channels, const int out_channels, const int receptive_field, const bool _bias,
const std::vector<float>& weights, const double expected_sample_rate = -1.0);
const std::vector<float>& weights, const double expected_sample_rate = -1.0,
const LinearImplementation implementation = LinearImplementation::Auto);

~Linear() override;

/// \brief Process audio frames
/// \param input Input audio buffers
/// \param output Output audio buffers
/// \param num_frames Number of frames to process
void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override;

LinearImplementation GetRequestedImplementation() const { return _requested_implementation; }
LinearImplementation GetActiveImplementation() const { return _active_implementation; }

protected:
void SetMaxBufferSize(const int maxBufferSize) override;

protected:
Eigen::VectorXf _weight;
Eigen::VectorXf _fft_direct_weight;
float _bias;

private:
std::vector<float> _impulse_response;
LinearImplementation _requested_implementation;
LinearImplementation _active_implementation;
std::unique_ptr<LinearFFTState> _fft_state;

void _configure_implementation();
void _configure_fft_state();
void _process_direct(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames);
void _process_fft(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames);
void _run_fft_block(const int channel);
};

namespace linear
Expand All @@ -42,10 +75,17 @@ struct LinearConfig : public ModelConfig
bool bias;
int in_channels;
int out_channels;
LinearImplementation implementation = LinearImplementation::Auto;

std::unique_ptr<DSP> create(std::vector<float> weights, double sampleRate) override;
};

/// \brief Parse a Linear implementation string.
LinearImplementation parse_implementation(const std::string& implementation);

/// \brief String name for a Linear implementation.
std::string implementation_to_string(const LinearImplementation implementation);

/// \brief Parse Linear configuration from JSON
/// \param config JSON configuration object
/// \return LinearConfig
Expand Down
Loading
Loading