diff --git a/NAM/dsp.h b/NAM/dsp.h index 7f66134..c714a19 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -39,7 +39,6 @@ namespace wavenet class WaveNet; } // namespace wavenet - /// \brief Base class for all DSP models /// /// DSP provides the common interface for all neural network-based audio processing models. diff --git a/NAM/linear.cpp b/NAM/linear.cpp index 6ffe040..186ef8a 100644 --- a/NAM/linear.cpp +++ b/NAM/linear.cpp @@ -1,27 +1,171 @@ #include "linear.h" #include +#include +#include #include #include "registry.h" +#include + +namespace +{ +constexpr int _LINEAR_AUTO_DIRECT_MAX_TAPS = 256; +constexpr int _LINEAR_FFT_SMALL_BLOCK_SIZE = 256; +constexpr int _LINEAR_FFT_MEDIUM_BLOCK_SIZE = 512; +constexpr int _LINEAR_FFT_LARGE_BLOCK_SIZE = 1024; + +int _ceil_div(const int numerator, const int denominator) +{ + return (numerator + denominator - 1) / denominator; +} + +int _choose_linear_fft_block_size(const int receptive_field) +{ + if (receptive_field <= 2048) + return _LINEAR_FFT_SMALL_BLOCK_SIZE; + if (receptive_field <= 8192) + return _LINEAR_FFT_MEDIUM_BLOCK_SIZE; + return _LINEAR_FFT_LARGE_BLOCK_SIZE; +} + +} // namespace + +struct nam::LinearFFTState +{ + using Complex = std::complex; + + struct ChannelState + { + std::vector input_time; + std::vector> input_spectra; + std::vector output_ring; + int input_pos = 0; + int spectrum_write_index = 0; + }; + + Eigen::FFT fft; + int block_size = 0; + int fft_size = 0; + int direct_taps = 0; + int num_partitions = 0; + int output_ring_size = 0; + long long sample_index = 0; + std::vector> kernel_spectra; + std::vector channels; + std::vector accumulator; + std::vector ifft_time; +}; + nam::Linear::Linear(const int in_channels, const int out_channels, const int receptive_field, const bool _bias, - const std::vector& weights, const double expected_sample_rate) + const std::vector& weights, const double expected_sample_rate, + const LinearImplementation implementation) : nam::Buffer(in_channels, out_channels, receptive_field, expected_sample_rate) +, _requested_implementation(implementation) +, _active_implementation(LinearImplementation::Direct) { if ((int)weights.size() != (receptive_field + (_bias ? 1 : 0))) throw std::runtime_error( "Params vector does not match expected size based " "on architecture parameters"); + this->_impulse_response.assign(weights.begin(), weights.begin() + receptive_field); this->_weight.resize(this->_receptive_field); // Pass in in reverse order so that dot products work out of the box. for (int i = 0; i < this->_receptive_field; i++) this->_weight(i) = weights[receptive_field - 1 - i]; this->_bias = _bias ? weights[receptive_field] : (float)0.0; + + this->_configure_implementation(); } +nam::Linear::~Linear() = default; + void nam::Linear::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) +{ + if (this->_active_implementation == LinearImplementation::FFT) + this->_process_fft(input, output, num_frames); + else + this->_process_direct(input, output, num_frames); +} + +void nam::Linear::SetMaxBufferSize(const int maxBufferSize) +{ + nam::Buffer::SetMaxBufferSize(maxBufferSize); + this->_configure_implementation(); +} + +void nam::Linear::_configure_implementation() +{ + if (this->_requested_implementation == LinearImplementation::Direct) + this->_active_implementation = LinearImplementation::Direct; + else if (this->_requested_implementation == LinearImplementation::FFT) + this->_active_implementation = LinearImplementation::FFT; + else + this->_active_implementation = + this->_receptive_field <= _LINEAR_AUTO_DIRECT_MAX_TAPS ? LinearImplementation::Direct : LinearImplementation::FFT; + + if (this->_active_implementation == LinearImplementation::FFT) + this->_configure_fft_state(); + else + this->_fft_state.reset(); +} + +void nam::Linear::_configure_fft_state() +{ + this->_fft_state = std::make_unique(); + auto& state = *this->_fft_state; + + state.block_size = _choose_linear_fft_block_size(this->_receptive_field); + state.fft_size = 2 * state.block_size; + state.direct_taps = std::min(this->_receptive_field, state.block_size); + state.num_partitions = this->_receptive_field > state.direct_taps + ? _ceil_div(this->_receptive_field - state.direct_taps, state.block_size) + : 0; + state.output_ring_size = 4 * state.block_size; + state.sample_index = 0; + + this->_fft_direct_weight.resize(state.direct_taps); + for (int i = 0; i < state.direct_taps; i++) + this->_fft_direct_weight(i) = this->_impulse_response[state.direct_taps - 1 - i]; + + state.kernel_spectra.assign(state.num_partitions, std::vector(state.fft_size)); + std::vector kernel_time(state.fft_size, 0.0f); + for (int partition = 0; partition < state.num_partitions; partition++) + { + std::fill(kernel_time.begin(), kernel_time.end(), 0.0f); + const int start = state.direct_taps + partition * state.block_size; + const int partition_size = std::min(state.block_size, this->_receptive_field - start); + for (int i = 0; i < partition_size; i++) + kernel_time[i] = this->_impulse_response[start + i]; + state.fft.fwd(state.kernel_spectra[partition].data(), kernel_time.data(), state.fft_size); + } + + const int channels_to_process = std::min(NumInputChannels(), NumOutputChannels()); + state.channels.resize(channels_to_process); + for (auto& channel : state.channels) + { + channel.input_time.assign(state.fft_size, 0.0f); + channel.input_spectra.assign( + state.num_partitions, std::vector(state.fft_size, LinearFFTState::Complex{})); + channel.output_ring.assign(state.output_ring_size, 0.0f); + channel.input_pos = 0; + channel.spectrum_write_index = 0; + } + state.accumulator.assign(state.fft_size, LinearFFTState::Complex{}); + state.ifft_time.assign(state.fft_size, 0.0f); + + if (state.num_partitions > 0) + { + std::vector warm_spectrum(state.fft_size); + std::vector warm_time(state.fft_size, 0.0f); + state.fft.fwd(warm_spectrum.data(), warm_time.data(), state.fft_size); + state.fft.inv(warm_time.data(), warm_spectrum.data(), state.fft_size); + } +} + +void nam::Linear::_process_direct(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) { this->nam::Buffer::_update_buffers_(input, num_frames); @@ -54,6 +198,111 @@ void nam::Linear::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num nam::Buffer::_advance_input_buffer_(num_frames); } +void nam::Linear::_process_fft(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) +{ + this->nam::Buffer::_update_buffers_(input, num_frames); + + const int in_channels = NumInputChannels(); + const int out_channels = NumOutputChannels(); + const int channels_to_process = std::min(in_channels, out_channels); + auto& state = *this->_fft_state; + const int direct_taps = state.direct_taps; + + for (int i = 0; i < num_frames; i++) + { + const long direct_offset = this->_input_buffer_offset - direct_taps + i + 1; + for (int ch = 0; ch < channels_to_process; ch++) + { + const int ring_index = (int)(state.sample_index % state.output_ring_size); + const float tail = state.channels[ch].output_ring[ring_index]; + state.channels[ch].output_ring[ring_index] = 0.0f; + + auto input_vec = Eigen::Map(&this->_input_buffers[ch][direct_offset], direct_taps); + output[ch][i] = this->_bias + this->_fft_direct_weight.dot(input_vec) + tail; + + if (state.num_partitions > 0) + { + auto& channel = state.channels[ch]; + channel.input_time[channel.input_pos] = (float)input[ch][i]; + channel.input_pos++; + if (channel.input_pos == state.block_size) + this->_run_fft_block(ch); + } + } + + for (int ch = channels_to_process; ch < out_channels; ch++) + output[ch][i] = (NAM_SAMPLE)0.0; + + state.sample_index++; + } + + nam::Buffer::_advance_input_buffer_(num_frames); +} + +void nam::Linear::_run_fft_block(const int channel_index) +{ + auto& state = *this->_fft_state; + auto& channel = state.channels[channel_index]; + + auto& current_spectrum = channel.input_spectra[channel.spectrum_write_index]; + state.fft.fwd(current_spectrum.data(), channel.input_time.data(), state.fft_size); + + std::fill(state.accumulator.begin(), state.accumulator.end(), LinearFFTState::Complex{}); + for (int partition = 0; partition < state.num_partitions; partition++) + { + int input_spectrum_index = channel.spectrum_write_index - partition; + if (input_spectrum_index < 0) + input_spectrum_index += state.num_partitions; + const auto& input_spectrum = channel.input_spectra[input_spectrum_index]; + const auto& kernel_spectrum = state.kernel_spectra[partition]; + for (int bin = 0; bin < state.fft_size; bin++) + state.accumulator[bin] += input_spectrum[bin] * kernel_spectrum[bin]; + } + + state.fft.inv(state.ifft_time.data(), state.accumulator.data(), state.fft_size); + + const long long block_start = state.sample_index - state.block_size + 1; + const long long output_start = block_start + state.direct_taps; + auto& output_ring = channel.output_ring; + for (int i = 0; i < state.fft_size - 1; i++) + { + const int ring_index = (int)((output_start + i) % state.output_ring_size); + output_ring[ring_index] += state.ifft_time[i]; + } + + std::fill(channel.input_time.begin(), channel.input_time.begin() + state.block_size, 0.0f); + channel.input_pos = 0; + channel.spectrum_write_index++; + if (channel.spectrum_write_index == state.num_partitions) + channel.spectrum_write_index = 0; +} + +nam::LinearImplementation nam::linear::parse_implementation(const std::string& implementation) +{ + std::string normalized = implementation; + std::transform( + normalized.begin(), normalized.end(), normalized.begin(), [](unsigned char c) { return (char)std::tolower(c); }); + + if (normalized == "auto") + return LinearImplementation::Auto; + if (normalized == "direct" || normalized == "legacy" || normalized == "old") + return LinearImplementation::Direct; + if (normalized == "fft" || normalized == "partitioned_fft" || normalized == "partitioned-fft") + return LinearImplementation::FFT; + throw std::runtime_error("Unsupported Linear implementation: " + implementation); +} + +std::string nam::linear::implementation_to_string(const LinearImplementation implementation) +{ + switch (implementation) + { + case LinearImplementation::Auto: return "auto"; + case LinearImplementation::Direct: return "direct"; + case LinearImplementation::FFT: return "fft"; + } + throw std::runtime_error("Unsupported Linear implementation enum"); +} + nam::linear::LinearConfig nam::linear::parse_config_json(const nlohmann::json& config) { LinearConfig c; @@ -62,12 +311,14 @@ nam::linear::LinearConfig nam::linear::parse_config_json(const nlohmann::json& c // Default to 1 channel in/out for backward compatibility c.in_channels = config.value("in_channels", 1); c.out_channels = config.value("out_channels", 1); + c.implementation = parse_implementation(config.value("implementation", "auto")); return c; } std::unique_ptr nam::linear::LinearConfig::create(std::vector weights, double sampleRate) { - return std::make_unique(in_channels, out_channels, receptive_field, bias, weights, sampleRate); + return std::make_unique( + in_channels, out_channels, receptive_field, bias, weights, sampleRate, implementation); } std::unique_ptr nam::linear::create_config(const nlohmann::json& config, double sampleRate) diff --git a/NAM/linear.h b/NAM/linear.h index dbe9454..d3559bc 100644 --- a/NAM/linear.h +++ b/NAM/linear.h @@ -5,6 +5,16 @@ namespace nam { +struct LinearFFTState; + +/// \brief Selects the convolution engine used by Linear models. +enum class LinearImplementation +{ + Auto, ///< Choose direct or FFT convolution from the impulse-response length. + Direct, ///< Legacy per-sample direct convolution. + FFT ///< Zero-latency partitioned FFT convolution. +}; + /// \brief Basic linear model /// /// Implements a simple linear convolution, (i.e. an impulse response). @@ -18,8 +28,12 @@ class Linear : public Buffer /// \param _bias Whether to use bias /// \param weights Model weights (impulse response coefficients) /// \param expected_sample_rate Expected sample rate in Hz (-1.0 if unknown) + /// \param implementation Convolution implementation to use Linear(const int in_channels, const int out_channels, const int receptive_field, const bool _bias, - const std::vector& weights, const double expected_sample_rate = -1.0); + const std::vector& weights, const double expected_sample_rate = -1.0, + const LinearImplementation implementation = LinearImplementation::Auto); + + ~Linear() override; /// \brief Process audio frames /// \param input Input audio buffers @@ -27,9 +41,28 @@ class Linear : public Buffer /// \param num_frames Number of frames to process void process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames) override; + LinearImplementation GetRequestedImplementation() const { return _requested_implementation; } + LinearImplementation GetActiveImplementation() const { return _active_implementation; } + +protected: + void SetMaxBufferSize(const int maxBufferSize) override; + protected: Eigen::VectorXf _weight; + Eigen::VectorXf _fft_direct_weight; float _bias; + +private: + std::vector _impulse_response; + LinearImplementation _requested_implementation; + LinearImplementation _active_implementation; + std::unique_ptr _fft_state; + + void _configure_implementation(); + void _configure_fft_state(); + void _process_direct(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames); + void _process_fft(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames); + void _run_fft_block(const int channel); }; namespace linear @@ -42,10 +75,17 @@ struct LinearConfig : public ModelConfig bool bias; int in_channels; int out_channels; + LinearImplementation implementation = LinearImplementation::Auto; std::unique_ptr create(std::vector weights, double sampleRate) override; }; +/// \brief Parse a Linear implementation string. +LinearImplementation parse_implementation(const std::string& implementation); + +/// \brief String name for a Linear implementation. +std::string implementation_to_string(const LinearImplementation implementation); + /// \brief Parse Linear configuration from JSON /// \param config JSON configuration object /// \return LinearConfig diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 367ba3d..0f9d50a 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -26,6 +26,7 @@ #include "test/test_wavenet_gating_compatibility.cpp" #include "test/test_blending_detailed.cpp" #include "test/test_input_buffer_verification.cpp" +#include "test/test_linear.cpp" #include "test/test_lstm.cpp" #include "test/test_wavenet_configurable_gating.cpp" #include "test/test_noncontiguous_blocks.cpp" @@ -84,6 +85,15 @@ int main() test_dsp::test_set_input_level(); test_dsp::test_set_output_level(); + test_linear::test_direct_known_values(); + test_linear::test_fft_matches_direct_irregular_chunks(); + test_linear::test_auto_selection(); + test_linear::test_parse_implementation(); + test_linear::test_direct_process_realtime_safe(); + test_linear::test_fft_process_realtime_safe(); + test_linear::test_auto_direct_process_realtime_safe(); + test_linear::test_auto_fft_process_realtime_safe(); + test_ring_buffer::test_construct(); test_ring_buffer::test_reset(); test_ring_buffer::test_reset_with_receptive_field(); diff --git a/tools/test/test_linear.cpp b/tools/test/test_linear.cpp new file mode 100644 index 0000000..ada26e7 --- /dev/null +++ b/tools/test/test_linear.cpp @@ -0,0 +1,197 @@ +// Tests for Linear DSP models + +#include "NAM/dsp.h" + +#include +#include +#include +#include +#include + +#include "allocation_tracking.h" + +namespace test_linear +{ +namespace +{ + +std::vector process_model(nam::Linear& model, const std::vector& input, + const std::vector& chunk_sizes) +{ + std::vector output(input.size(), (NAM_SAMPLE)0.0); + NAM_SAMPLE* input_ptrs[1]; + NAM_SAMPLE* output_ptrs[1]; + + size_t offset = 0; + size_t chunk_index = 0; + while (offset < input.size()) + { + const int requested = chunk_sizes[chunk_index % chunk_sizes.size()]; + const int count = std::min(requested, (int)(input.size() - offset)); + input_ptrs[0] = const_cast(&input[offset]); + output_ptrs[0] = &output[offset]; + model.process(input_ptrs, output_ptrs, count); + offset += count; + chunk_index++; + } + return output; +} + +std::vector make_input(const int num_samples) +{ + std::vector input(num_samples); + for (int i = 0; i < num_samples; i++) + input[i] = (NAM_SAMPLE)(0.2 * std::sin(0.013 * i) + 0.05 * std::cos(0.071 * i)); + return input; +} + +std::vector make_weights(const int receptive_field, const bool bias) +{ + std::vector weights; + weights.reserve(receptive_field + (bias ? 1 : 0)); + for (int i = 0; i < receptive_field; i++) + weights.push_back((float)(std::exp(-0.001 * i) * std::sin(0.037 * (i + 1)) * 0.01)); + if (bias) + weights.push_back(0.03125f); + return weights; +} + +void assert_near(const NAM_SAMPLE actual, const NAM_SAMPLE expected, const NAM_SAMPLE tolerance) +{ + assert(std::abs(actual - expected) <= tolerance); +} + +void assert_process_realtime_safe(const int receptive_field, const nam::LinearImplementation requested_implementation, + const nam::LinearImplementation expected_active_implementation, const char* test_name) +{ + const int max_buffer_size = 512; + const auto weights = make_weights(receptive_field, true); + nam::Linear model(1, 1, receptive_field, true, weights, 48000.0, requested_implementation); + model.Reset(48000.0, max_buffer_size); + assert(model.GetActiveImplementation() == expected_active_implementation); + + std::vector input(max_buffer_size); + std::vector output(max_buffer_size); + for (int i = 0; i < max_buffer_size; i++) + input[i] = (NAM_SAMPLE)(0.1 * std::sin(0.021 * i) + 0.03 * std::cos(0.017 * i)); + + NAM_SAMPLE* input_ptrs[1] = {input.data()}; + NAM_SAMPLE* output_ptrs[1] = {output.data()}; + + model.process(input_ptrs, output_ptrs, max_buffer_size); + + const int block_sizes[] = {1, 7, 32, 64, 128, 256, 3, 511, 512}; + allocation_tracking::run_allocation_test_no_allocations( + nullptr, + [&]() { + for (int pass = 0; pass < 8; pass++) + { + for (const int block_size : block_sizes) + model.process(input_ptrs, output_ptrs, block_size); + } + }, + nullptr, test_name); + + for (int i = 0; i < max_buffer_size; i++) + assert(std::isfinite(output[i])); +} + +} // namespace + +void test_direct_known_values() +{ + const std::vector weights{0.5f, -0.25f, 0.125f}; + nam::Linear model(1, 1, 3, false, weights, 48000.0, nam::LinearImplementation::Direct); + + const std::vector input{(NAM_SAMPLE)1.0, (NAM_SAMPLE)2.0, (NAM_SAMPLE)3.0, (NAM_SAMPLE)4.0}; + const auto output = process_model(model, input, {4}); + + assert_near(output[0], 0.5, 1.0e-7); + assert_near(output[1], 0.75, 1.0e-7); + assert_near(output[2], 1.125, 1.0e-7); + assert_near(output[3], 1.5, 1.0e-7); +} + +void test_fft_matches_direct_irregular_chunks() +{ + const int receptive_field = 1536; + const bool bias = true; + const auto weights = make_weights(receptive_field, bias); + const auto input = make_input(4096); + + nam::Linear direct(1, 1, receptive_field, bias, weights, 48000.0, nam::LinearImplementation::Direct); + nam::Linear fft(1, 1, receptive_field, bias, weights, 48000.0, nam::LinearImplementation::FFT); + + const std::vector chunks{1, 17, 64, 255, 3, 512, 31}; + const auto direct_output = process_model(direct, input, chunks); + const auto fft_output = process_model(fft, input, chunks); + + NAM_SAMPLE max_abs_diff = 0.0; + for (size_t i = 0; i < input.size(); i++) + max_abs_diff = std::max(max_abs_diff, std::abs(direct_output[i] - fft_output[i])); + + assert(max_abs_diff < 5.0e-5); +} + +void test_auto_selection() +{ + const auto short_weights = make_weights(128, false); + nam::Linear short_model(1, 1, 128, false, short_weights, 48000.0); + assert(short_model.GetRequestedImplementation() == nam::LinearImplementation::Auto); + assert(short_model.GetActiveImplementation() == nam::LinearImplementation::Direct); + + const auto cutoff_weights = make_weights(256, false); + nam::Linear cutoff_model(1, 1, 256, false, cutoff_weights, 48000.0); + assert(cutoff_model.GetRequestedImplementation() == nam::LinearImplementation::Auto); + assert(cutoff_model.GetActiveImplementation() == nam::LinearImplementation::Direct); + + const auto fft_weights = make_weights(512, false); + nam::Linear fft_model(1, 1, 512, false, fft_weights, 48000.0); + assert(fft_model.GetRequestedImplementation() == nam::LinearImplementation::Auto); + assert(fft_model.GetActiveImplementation() == nam::LinearImplementation::FFT); +} + +void test_parse_implementation() +{ + assert(nam::linear::parse_implementation("auto") == nam::LinearImplementation::Auto); + assert(nam::linear::parse_implementation("legacy") == nam::LinearImplementation::Direct); + assert(nam::linear::parse_implementation("partitioned-fft") == nam::LinearImplementation::FFT); + assert(nam::linear::implementation_to_string(nam::LinearImplementation::Direct) == "direct"); + + bool threw = false; + try + { + nam::linear::parse_implementation("not-a-real-implementation"); + } + catch (const std::runtime_error&) + { + threw = true; + } + assert(threw); +} + +void test_direct_process_realtime_safe() +{ + assert_process_realtime_safe( + 512, nam::LinearImplementation::Direct, nam::LinearImplementation::Direct, "Linear direct process real-time safe"); +} + +void test_fft_process_realtime_safe() +{ + assert_process_realtime_safe( + 4096, nam::LinearImplementation::FFT, nam::LinearImplementation::FFT, "Linear FFT process real-time safe"); +} + +void test_auto_direct_process_realtime_safe() +{ + assert_process_realtime_safe(128, nam::LinearImplementation::Auto, nam::LinearImplementation::Direct, + "Linear auto direct process real-time safe"); +} + +void test_auto_fft_process_realtime_safe() +{ + assert_process_realtime_safe( + 4096, nam::LinearImplementation::Auto, nam::LinearImplementation::FFT, "Linear auto FFT process real-time safe"); +} + +} // namespace test_linear