From 1fe1f7781fd301089293c901234d7dadcdd957b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Fri, 6 Feb 2026 09:45:14 -0800 Subject: [PATCH 1/6] Add inline GEMM optimizations and general performance improvements Hand-optimized GEMM kernels for small matrices common in NAM models, gated by #ifdef NAM_USE_INLINE_GEMM with Eigen fallback. Includes: - Specialized Conv1D kernels: fused 4x4 and 2x2 kernel_size=3, plus fully-unrolled paths for 2x2 through 8x8 channel configurations - Conv1x1 inline specializations for all common size combinations - FiLM inline path with 4-element loop unrolling - GatingActivation/BlendingActivation inline paths - Branchless hardswish, 4-element loop unrolling for all activations - SiLU added to LUT enable/disable - Ring buffer refactored to Eigen block operations - memcpy replacements for pure copy operations in wavenet - Optimized single-channel output path in WaveNet::process - Buffer size benchmark tool (benchmodel_bufsize) Co-Authored-By: Claude Opus 4.6 --- NAM/activations.cpp | 14 +- NAM/activations.h | 114 +++++++-- NAM/conv1d.cpp | 472 ++++++++++++++++++++++++++++++++++- NAM/dsp.cpp | 241 ++++++++++++++++++ NAM/film.h | 58 +++++ NAM/gating_activations.h | 80 +++++- NAM/ring_buffer.cpp | 20 +- NAM/wavenet.cpp | 147 ++++++++++- NAM/wavenet.h | 12 +- tools/CMakeLists.txt | 1 + tools/benchmodel_bufsize.cpp | 96 +++++++ 11 files changed, 1194 insertions(+), 61 deletions(-) create mode 100644 tools/benchmodel_bufsize.cpp diff --git a/NAM/activations.cpp b/NAM/activations.cpp index a476520c..8d37a19a 100644 --- a/NAM/activations.cpp +++ b/NAM/activations.cpp @@ -37,6 +37,7 @@ std::unordered_map nam::activati nam::activations::Activation::Ptr tanh_bak = nullptr; nam::activations::Activation::Ptr sigmoid_bak = nullptr; +nam::activations::Activation::Ptr silu_bak = nullptr; nam::activations::Activation::Ptr nam::activations::Activation::get_activation(const std::string name) { @@ -197,9 +198,14 @@ void nam::activations::Activation::enable_lut(std::string function_name, float m fn = sigmoid; sigmoid_bak = _activations["Sigmoid"]; } + else if (function_name == "SiLU") + { + fn = swish; + silu_bak = _activations["SiLU"]; + } else { - throw std::runtime_error("Tried to enable LUT for a function other than Tanh or Sigmoid"); + throw std::runtime_error("Tried to enable LUT for a function other than Tanh, Sigmoid, or SiLU"); } _activations[function_name] = std::make_shared(min, max, n_points, fn); } @@ -214,8 +220,12 @@ void nam::activations::Activation::disable_lut(std::string function_name) { _activations["Sigmoid"] = sigmoid_bak; } + else if (function_name == "SiLU") + { + _activations["SiLU"] = silu_bak; + } else { - throw std::runtime_error("Tried to disable LUT for a function other than Tanh or Sigmoid"); + throw std::runtime_error("Tried to disable LUT for a function other than Tanh, Sigmoid, or SiLU"); } } diff --git a/NAM/activations.h b/NAM/activations.h index 68d5025c..9b0c47ca 100644 --- a/NAM/activations.h +++ b/NAM/activations.h @@ -117,18 +117,12 @@ inline float swish(float x) inline float hardswish(float x) { - if (x <= -3.0) - { - return 0; - } - else if (x >= 3.0) - { - return x; - } - else - { - return x * (x + 3.0) / 6.0; - } + // Branchless implementation using clamp + // hardswish(x) = x * relu6(x + 3) / 6 + // = x * clamp(x + 3, 0, 6) / 6 + const float t = x + 3.0f; + const float clamped = t < 0.0f ? 0.0f : (t > 6.0f ? 6.0f : t); + return x * clamped * (1.0f / 6.0f); } inline float softsign(float x) @@ -241,9 +235,23 @@ class ActivationReLU : public Activation public: void apply(float* data, long size) override { - for (long pos = 0; pos < size; pos++) + // Optimized ReLU with loop unrolling + long pos = 0; + // Process 4 elements at a time + for (; pos + 3 < size; pos += 4) + { + // Branchless ReLU using conditional + const float v0 = data[pos], v1 = data[pos + 1]; + const float v2 = data[pos + 2], v3 = data[pos + 3]; + data[pos] = v0 > 0.0f ? v0 : 0.0f; + data[pos + 1] = v1 > 0.0f ? v1 : 0.0f; + data[pos + 2] = v2 > 0.0f ? v2 : 0.0f; + data[pos + 3] = v3 > 0.0f ? v3 : 0.0f; + } + // Handle remainder + for (; pos < size; pos++) { - data[pos] = relu(data[pos]); + data[pos] = data[pos] > 0.0f ? data[pos] : 0.0f; } } }; @@ -308,7 +316,20 @@ class ActivationSigmoid : public Activation public: void apply(float* data, long size) override { - for (long pos = 0; pos < size; pos++) + long pos = 0; + // Process 4 elements at a time + for (; pos + 3 < size; pos += 4) + { + const float x0 = data[pos], x1 = data[pos + 1]; + const float x2 = data[pos + 2], x3 = data[pos + 3]; + + data[pos] = 1.0f / (1.0f + expf(-x0)); + data[pos + 1] = 1.0f / (1.0f + expf(-x1)); + data[pos + 2] = 1.0f / (1.0f + expf(-x2)); + data[pos + 3] = 1.0f / (1.0f + expf(-x3)); + } + // Handle remainder + for (; pos < size; pos++) { data[pos] = sigmoid(data[pos]); } @@ -320,7 +341,25 @@ class ActivationSwish : public Activation public: void apply(float* data, long size) override { - for (long pos = 0; pos < size; pos++) + long pos = 0; + // Process 4 elements at a time: swish(x) = x * sigmoid(x) = x / (1 + exp(-x)) + for (; pos + 3 < size; pos += 4) + { + const float x0 = data[pos], x1 = data[pos + 1]; + const float x2 = data[pos + 2], x3 = data[pos + 3]; + + const float s0 = 1.0f / (1.0f + expf(-x0)); + const float s1 = 1.0f / (1.0f + expf(-x1)); + const float s2 = 1.0f / (1.0f + expf(-x2)); + const float s3 = 1.0f / (1.0f + expf(-x3)); + + data[pos] = x0 * s0; + data[pos + 1] = x1 * s1; + data[pos + 2] = x2 * s2; + data[pos + 3] = x3 * s3; + } + // Handle remainder + for (; pos < size; pos++) { data[pos] = swish(data[pos]); } @@ -332,7 +371,29 @@ class ActivationHardSwish : public Activation public: void apply(float* data, long size) override { - for (long pos = 0; pos < size; pos++) + const float inv6 = 1.0f / 6.0f; + long pos = 0; + // Process 4 elements at a time + for (; pos + 3 < size; pos += 4) + { + const float x0 = data[pos], x1 = data[pos + 1]; + const float x2 = data[pos + 2], x3 = data[pos + 3]; + + const float t0 = x0 + 3.0f, t1 = x1 + 3.0f; + const float t2 = x2 + 3.0f, t3 = x3 + 3.0f; + + const float c0 = t0 < 0.0f ? 0.0f : (t0 > 6.0f ? 6.0f : t0); + const float c1 = t1 < 0.0f ? 0.0f : (t1 > 6.0f ? 6.0f : t1); + const float c2 = t2 < 0.0f ? 0.0f : (t2 > 6.0f ? 6.0f : t2); + const float c3 = t3 < 0.0f ? 0.0f : (t3 > 6.0f ? 6.0f : t3); + + data[pos] = x0 * c0 * inv6; + data[pos + 1] = x1 * c1 * inv6; + data[pos + 2] = x2 * c2 * inv6; + data[pos + 3] = x3 * c3 * inv6; + } + // Handle remainder + for (; pos < size; pos++) { data[pos] = hardswish(data[pos]); } @@ -344,7 +405,20 @@ class ActivationSoftsign : public Activation public: void apply(float* data, long size) override { - for (long pos = 0; pos < size; pos++) + long pos = 0; + // Process 4 elements at a time + for (; pos + 3 < size; pos += 4) + { + const float x0 = data[pos], x1 = data[pos + 1]; + const float x2 = data[pos + 2], x3 = data[pos + 3]; + + data[pos] = x0 / (1.0f + fabsf(x0)); + data[pos + 1] = x1 / (1.0f + fabsf(x1)); + data[pos + 2] = x2 / (1.0f + fabsf(x2)); + data[pos + 3] = x3 / (1.0f + fabsf(x3)); + } + // Handle remainder + for (; pos < size; pos++) { data[pos] = softsign(data[pos]); } @@ -373,8 +447,8 @@ class FastLUTActivation : public Activation // Fast lookup with linear interpolation inline float lookup(float x) const { - // Clamp input to range - x = std::clamp(x, min_x_, max_x_); + // Clamp input to range (inline to avoid header dependency) + x = x < min_x_ ? min_x_ : (x > max_x_ ? max_x_ : x); // Calculate float index float f_idx = (x - min_x_) * inv_step_; diff --git a/NAM/conv1d.cpp b/NAM/conv1d.cpp index 9bbbc020..e14c11ca 100644 --- a/NAM/conv1d.cpp +++ b/NAM/conv1d.cpp @@ -1,4 +1,5 @@ #include "conv1d.h" +#include #include namespace nam @@ -146,8 +147,8 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames) // Write input to ring buffer _input_buffer.Write(input, num_frames); - // Zero output before processing - _output.leftCols(num_frames).setZero(); + // Note: setZero is deferred - only called for paths that need it (those using +=) + // Fused kernel paths use direct assignment (=) and skip setZero // Process from ring buffer with dilation lookback // After Write(), data is at positions [_write_pos, _write_pos+num_frames-1] @@ -156,10 +157,78 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames) if (this->_is_depthwise) { + // Depthwise convolution uses += accumulation, so needs setZero + _output.leftCols(num_frames).setZero(); + // Depthwise convolution: use efficient element-wise multiplication // Each channel is processed independently with a single weight per kernel tap. // output[c, t] = sum_k(weight[k, c] * input[c, t - k*dilation]) const size_t kernel_size = this->_depthwise_weight.size(); +#ifdef NAM_USE_INLINE_GEMM + const int channels = this->_channels; + float* __restrict__ output_ptr = _output.data(); + + for (size_t k = 0; k < kernel_size; k++) + { + const long offset = this->_dilation * (k + 1 - (long)kernel_size); + const long lookback = -offset; + auto input_block = _input_buffer.Read(num_frames, lookback); + const float* __restrict__ input_ptr = input_block.data(); + const float* __restrict__ weight_ptr = this->_depthwise_weight[k].data(); + + // Specialized paths for common channel counts + if (channels == 4) + { + const float w0 = weight_ptr[0], w1 = weight_ptr[1]; + const float w2 = weight_ptr[2], w3 = weight_ptr[3]; + for (int f = 0; f < num_frames; f++) + { + const int off = f * 4; + output_ptr[off] += w0 * input_ptr[off]; + output_ptr[off + 1] += w1 * input_ptr[off + 1]; + output_ptr[off + 2] += w2 * input_ptr[off + 2]; + output_ptr[off + 3] += w3 * input_ptr[off + 3]; + } + } + else if (channels == 8) + { + const float w0 = weight_ptr[0], w1 = weight_ptr[1], w2 = weight_ptr[2], w3 = weight_ptr[3]; + const float w4 = weight_ptr[4], w5 = weight_ptr[5], w6 = weight_ptr[6], w7 = weight_ptr[7]; + for (int f = 0; f < num_frames; f++) + { + const int off = f * 8; + output_ptr[off] += w0 * input_ptr[off]; + output_ptr[off + 1] += w1 * input_ptr[off + 1]; + output_ptr[off + 2] += w2 * input_ptr[off + 2]; + output_ptr[off + 3] += w3 * input_ptr[off + 3]; + output_ptr[off + 4] += w4 * input_ptr[off + 4]; + output_ptr[off + 5] += w5 * input_ptr[off + 5]; + output_ptr[off + 6] += w6 * input_ptr[off + 6]; + output_ptr[off + 7] += w7 * input_ptr[off + 7]; + } + } + else + { + // General depthwise path with loop unrolling + for (int f = 0; f < num_frames; f++) + { + const int off = f * channels; + int c = 0; + for (; c + 3 < channels; c += 4) + { + output_ptr[off + c] += weight_ptr[c] * input_ptr[off + c]; + output_ptr[off + c + 1] += weight_ptr[c + 1] * input_ptr[off + c + 1]; + output_ptr[off + c + 2] += weight_ptr[c + 2] * input_ptr[off + c + 2]; + output_ptr[off + c + 3] += weight_ptr[c + 3] * input_ptr[off + c + 3]; + } + for (; c < channels; c++) + { + output_ptr[off + c] += weight_ptr[c] * input_ptr[off + c]; + } + } + } + } +#else for (size_t k = 0; k < kernel_size; k++) { const long offset = this->_dilation * (k + 1 - (long)kernel_size); @@ -169,9 +238,326 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames) _output.leftCols(num_frames).noalias() += this->_depthwise_weight[k].asDiagonal() * input_block.leftCols(num_frames); } +#endif } else { +#ifdef NAM_USE_INLINE_GEMM + // Hand-optimized inline GEMM for small matrices + // For output(out_ch, frames) += weight(out_ch, in_ch) * input(in_ch, frames) + // + // Column-major layout means: + // output(o, f) is at output_ptr[f * out_ch + o] + // weight(o, i) is at weight_ptr[i * out_ch + o] + // input(i, f) is at input_ptr[f * in_ch + i] + const int out_ch = (int)get_out_channels(); + const int in_ch = (int)get_in_channels(); + const size_t kernel_size = this->_weight.size(); + const size_t weight_matrix_size = out_ch * in_ch; + + // Fused kernel optimization for kernel_size=3 + // Instead of 3 separate passes over output, fuse into single pass + if (kernel_size == 3 && out_ch == 4 && in_ch == 4) + { + // Fused 4x4 kernel_size=3: read all 3 input blocks and compute in one pass + const long dil = this->_dilation; + auto in0 = _input_buffer.Read(num_frames, 2 * dil); // oldest (k=0) + auto in1 = _input_buffer.Read(num_frames, dil); // middle (k=1) + auto in2 = _input_buffer.Read(num_frames, 0); // newest (k=2) + + const float* __restrict__ in0_ptr = in0.data(); + const float* __restrict__ in1_ptr = in1.data(); + const float* __restrict__ in2_ptr = in2.data(); + float* __restrict__ output_ptr = _output.data(); + + // Get weight pointers for all 3 taps + const size_t wsize = 16; // 4x4 + const float* __restrict__ w0 = this->_weight[0].data(); + const float* __restrict__ w1 = this->_weight[1].data(); + const float* __restrict__ w2 = this->_weight[2].data(); + + // Cache all weights in registers (48 floats for 3 x 4x4 matrices) + const float w0_00 = w0[0], w0_10 = w0[1], w0_20 = w0[2], w0_30 = w0[3]; + const float w0_01 = w0[4], w0_11 = w0[5], w0_21 = w0[6], w0_31 = w0[7]; + const float w0_02 = w0[8], w0_12 = w0[9], w0_22 = w0[10], w0_32 = w0[11]; + const float w0_03 = w0[12], w0_13 = w0[13], w0_23 = w0[14], w0_33 = w0[15]; + + const float w1_00 = w1[0], w1_10 = w1[1], w1_20 = w1[2], w1_30 = w1[3]; + const float w1_01 = w1[4], w1_11 = w1[5], w1_21 = w1[6], w1_31 = w1[7]; + const float w1_02 = w1[8], w1_12 = w1[9], w1_22 = w1[10], w1_32 = w1[11]; + const float w1_03 = w1[12], w1_13 = w1[13], w1_23 = w1[14], w1_33 = w1[15]; + + const float w2_00 = w2[0], w2_10 = w2[1], w2_20 = w2[2], w2_30 = w2[3]; + const float w2_01 = w2[4], w2_11 = w2[5], w2_21 = w2[6], w2_31 = w2[7]; + const float w2_02 = w2[8], w2_12 = w2[9], w2_22 = w2[10], w2_32 = w2[11]; + const float w2_03 = w2[12], w2_13 = w2[13], w2_23 = w2[14], w2_33 = w2[15]; + + for (int f = 0; f < num_frames; f++) + { + const int off = f * 4; + // Load inputs from all 3 taps + const float i0_0 = in0_ptr[off], i0_1 = in0_ptr[off + 1], i0_2 = in0_ptr[off + 2], i0_3 = in0_ptr[off + 3]; + const float i1_0 = in1_ptr[off], i1_1 = in1_ptr[off + 1], i1_2 = in1_ptr[off + 2], i1_3 = in1_ptr[off + 3]; + const float i2_0 = in2_ptr[off], i2_1 = in2_ptr[off + 1], i2_2 = in2_ptr[off + 2], i2_3 = in2_ptr[off + 3]; + + // Compute output = W0*in0 + W1*in1 + W2*in2 (fused, output was zeroed) + output_ptr[off] = (w0_00 * i0_0 + w0_01 * i0_1 + w0_02 * i0_2 + w0_03 * i0_3) + + (w1_00 * i1_0 + w1_01 * i1_1 + w1_02 * i1_2 + w1_03 * i1_3) + + (w2_00 * i2_0 + w2_01 * i2_1 + w2_02 * i2_2 + w2_03 * i2_3); + output_ptr[off + 1] = (w0_10 * i0_0 + w0_11 * i0_1 + w0_12 * i0_2 + w0_13 * i0_3) + + (w1_10 * i1_0 + w1_11 * i1_1 + w1_12 * i1_2 + w1_13 * i1_3) + + (w2_10 * i2_0 + w2_11 * i2_1 + w2_12 * i2_2 + w2_13 * i2_3); + output_ptr[off + 2] = (w0_20 * i0_0 + w0_21 * i0_1 + w0_22 * i0_2 + w0_23 * i0_3) + + (w1_20 * i1_0 + w1_21 * i1_1 + w1_22 * i1_2 + w1_23 * i1_3) + + (w2_20 * i2_0 + w2_21 * i2_1 + w2_22 * i2_2 + w2_23 * i2_3); + output_ptr[off + 3] = (w0_30 * i0_0 + w0_31 * i0_1 + w0_32 * i0_2 + w0_33 * i0_3) + + (w1_30 * i1_0 + w1_31 * i1_1 + w1_32 * i1_2 + w1_33 * i1_3) + + (w2_30 * i2_0 + w2_31 * i2_1 + w2_32 * i2_2 + w2_33 * i2_3); + } + } + else if (kernel_size == 3 && out_ch == 2 && in_ch == 2) + { + // Fused 2x2 kernel_size=3: read all 3 input blocks and compute in one pass + const long dil = this->_dilation; + auto in0 = _input_buffer.Read(num_frames, 2 * dil); + auto in1 = _input_buffer.Read(num_frames, dil); + auto in2 = _input_buffer.Read(num_frames, 0); + + const float* __restrict__ in0_ptr = in0.data(); + const float* __restrict__ in1_ptr = in1.data(); + const float* __restrict__ in2_ptr = in2.data(); + float* __restrict__ output_ptr = _output.data(); + + const float* __restrict__ w0 = this->_weight[0].data(); + const float* __restrict__ w1 = this->_weight[1].data(); + const float* __restrict__ w2 = this->_weight[2].data(); + + // Cache weights (12 floats total) + const float w0_00 = w0[0], w0_10 = w0[1], w0_01 = w0[2], w0_11 = w0[3]; + const float w1_00 = w1[0], w1_10 = w1[1], w1_01 = w1[2], w1_11 = w1[3]; + const float w2_00 = w2[0], w2_10 = w2[1], w2_01 = w2[2], w2_11 = w2[3]; + + for (int f = 0; f < num_frames; f++) + { + const int off = f * 2; + const float i0_0 = in0_ptr[off], i0_1 = in0_ptr[off + 1]; + const float i1_0 = in1_ptr[off], i1_1 = in1_ptr[off + 1]; + const float i2_0 = in2_ptr[off], i2_1 = in2_ptr[off + 1]; + + output_ptr[off] = (w0_00 * i0_0 + w0_01 * i0_1) + (w1_00 * i1_0 + w1_01 * i1_1) + (w2_00 * i2_0 + w2_01 * i2_1); + output_ptr[off + 1] = + (w0_10 * i0_0 + w0_11 * i0_1) + (w1_10 * i1_0 + w1_11 * i1_1) + (w2_10 * i2_0 + w2_11 * i2_1); + } + } + else + { + // General inline GEMM path uses += accumulation, so needs setZero + _output.leftCols(num_frames).setZero(); + + // General inline GEMM path for other configurations + for (size_t k = 0; k < kernel_size; k++) + { + const long offset = this->_dilation * (k + 1 - (long)kernel_size); + const long lookback = -offset; + auto input_block = _input_buffer.Read(num_frames, lookback); + + const float* __restrict__ input_ptr = input_block.data(); + const float* __restrict__ weight_ptr = this->_weight[k].data(); + float* __restrict__ output_ptr = _output.data(); + + // Specialized fully-unrolled paths for common small channel counts + // These avoid all loop overhead for the tiny matrices in NAM models + if (out_ch == 2 && in_ch == 2) + { + // 2x2 fully unrolled + const float w00 = weight_ptr[0], w10 = weight_ptr[1]; + const float w01 = weight_ptr[2], w11 = weight_ptr[3]; + for (int f = 0; f < num_frames; f++) + { + const float i0 = input_ptr[f * 2]; + const float i1 = input_ptr[f * 2 + 1]; + output_ptr[f * 2] += w00 * i0 + w01 * i1; + output_ptr[f * 2 + 1] += w10 * i0 + w11 * i1; + } + } + else if (out_ch == 2 && in_ch == 4) + { + // 2x4 fully unrolled + const float w00 = weight_ptr[0], w10 = weight_ptr[1]; + const float w01 = weight_ptr[2], w11 = weight_ptr[3]; + const float w02 = weight_ptr[4], w12 = weight_ptr[5]; + const float w03 = weight_ptr[6], w13 = weight_ptr[7]; + for (int f = 0; f < num_frames; f++) + { + const float i0 = input_ptr[f * 4]; + const float i1 = input_ptr[f * 4 + 1]; + const float i2 = input_ptr[f * 4 + 2]; + const float i3 = input_ptr[f * 4 + 3]; + output_ptr[f * 2] += w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3; + output_ptr[f * 2 + 1] += w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3; + } + } + else if (out_ch == 4 && in_ch == 1) + { + // 4x1 fully unrolled + const float w0 = weight_ptr[0], w1 = weight_ptr[1]; + const float w2 = weight_ptr[2], w3 = weight_ptr[3]; + for (int f = 0; f < num_frames; f++) + { + const float in_val = input_ptr[f]; + output_ptr[f * 4] += w0 * in_val; + output_ptr[f * 4 + 1] += w1 * in_val; + output_ptr[f * 4 + 2] += w2 * in_val; + output_ptr[f * 4 + 3] += w3 * in_val; + } + } + else if (out_ch == 4 && in_ch == 4) + { + // 4x4 fully unrolled - cache weights in registers + const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3]; + const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7]; + const float w02 = weight_ptr[8], w12 = weight_ptr[9], w22 = weight_ptr[10], w32 = weight_ptr[11]; + const float w03 = weight_ptr[12], w13 = weight_ptr[13], w23 = weight_ptr[14], w33 = weight_ptr[15]; + for (int f = 0; f < num_frames; f++) + { + const int in_off = f * 4; + const int out_off = f * 4; + const float i0 = input_ptr[in_off]; + const float i1 = input_ptr[in_off + 1]; + const float i2 = input_ptr[in_off + 2]; + const float i3 = input_ptr[in_off + 3]; + output_ptr[out_off] += w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3; + output_ptr[out_off + 1] += w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3; + output_ptr[out_off + 2] += w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3; + output_ptr[out_off + 3] += w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3; + } + } + else if (out_ch == 3 && in_ch == 1) + { + // 3x1 fully unrolled + const float w0 = weight_ptr[0], w1 = weight_ptr[1], w2 = weight_ptr[2]; + for (int f = 0; f < num_frames; f++) + { + const float in_val = input_ptr[f]; + output_ptr[f * 3] += w0 * in_val; + output_ptr[f * 3 + 1] += w1 * in_val; + output_ptr[f * 3 + 2] += w2 * in_val; + } + } + else if (out_ch == 3 && in_ch == 3) + { + // 3x3 fully unrolled + const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2]; + const float w01 = weight_ptr[3], w11 = weight_ptr[4], w21 = weight_ptr[5]; + const float w02 = weight_ptr[6], w12 = weight_ptr[7], w22 = weight_ptr[8]; + for (int f = 0; f < num_frames; f++) + { + const int off = f * 3; + const float i0 = input_ptr[off]; + const float i1 = input_ptr[off + 1]; + const float i2 = input_ptr[off + 2]; + output_ptr[off] += w00 * i0 + w01 * i1 + w02 * i2; + output_ptr[off + 1] += w10 * i0 + w11 * i1 + w12 * i2; + output_ptr[off + 2] += w20 * i0 + w21 * i1 + w22 * i2; + } + } + else if (out_ch == 4 && in_ch == 3) + { + // 4x3 fully unrolled + const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3]; + const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7]; + const float w02 = weight_ptr[8], w12 = weight_ptr[9], w22 = weight_ptr[10], w32 = weight_ptr[11]; + for (int f = 0; f < num_frames; f++) + { + const float i0 = input_ptr[f * 3]; + const float i1 = input_ptr[f * 3 + 1]; + const float i2 = input_ptr[f * 3 + 2]; + output_ptr[f * 4] += w00 * i0 + w01 * i1 + w02 * i2; + output_ptr[f * 4 + 1] += w10 * i0 + w11 * i1 + w12 * i2; + output_ptr[f * 4 + 2] += w20 * i0 + w21 * i1 + w22 * i2; + output_ptr[f * 4 + 3] += w30 * i0 + w31 * i1 + w32 * i2; + } + } + else if (out_ch == 3 && in_ch == 4) + { + // 3x4 fully unrolled + const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2]; + const float w01 = weight_ptr[3], w11 = weight_ptr[4], w21 = weight_ptr[5]; + const float w02 = weight_ptr[6], w12 = weight_ptr[7], w22 = weight_ptr[8]; + const float w03 = weight_ptr[9], w13 = weight_ptr[10], w23 = weight_ptr[11]; + for (int f = 0; f < num_frames; f++) + { + const float i0 = input_ptr[f * 4]; + const float i1 = input_ptr[f * 4 + 1]; + const float i2 = input_ptr[f * 4 + 2]; + const float i3 = input_ptr[f * 4 + 3]; + output_ptr[f * 3] += w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3; + output_ptr[f * 3 + 1] += w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3; + output_ptr[f * 3 + 2] += w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3; + } + } + else if (out_ch == 6 && in_ch == 1) + { + // 6x1 fully unrolled + const float w0 = weight_ptr[0], w1 = weight_ptr[1], w2 = weight_ptr[2]; + const float w3 = weight_ptr[3], w4 = weight_ptr[4], w5 = weight_ptr[5]; + for (int f = 0; f < num_frames; f++) + { + const float in_val = input_ptr[f]; + const int off = f * 6; + output_ptr[off] += w0 * in_val; + output_ptr[off + 1] += w1 * in_val; + output_ptr[off + 2] += w2 * in_val; + output_ptr[off + 3] += w3 * in_val; + output_ptr[off + 4] += w4 * in_val; + output_ptr[off + 5] += w5 * in_val; + } + } + else if (out_ch == 6 && in_ch == 6) + { + // 6x6 - unroll weights, loop over frames + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = input_ptr + f * 6; + float* __restrict__ out_col = output_ptr + f * 6; + const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2]; + const float i3 = in_col[3], i4 = in_col[4], i5 = in_col[5]; + for (int o = 0; o < 6; o++) + { + out_col[o] += weight_ptr[o] * i0 + weight_ptr[6 + o] * i1 + weight_ptr[12 + o] * i2 + + weight_ptr[18 + o] * i3 + weight_ptr[24 + o] * i4 + weight_ptr[30 + o] * i5; + } + } + } + else if (out_ch == 8 && in_ch == 8) + { + // 8x8 - unroll weights, loop over frames + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = input_ptr + f * 8; + float* __restrict__ out_col = output_ptr + f * 8; + const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2], i3 = in_col[3]; + const float i4 = in_col[4], i5 = in_col[5], i6 = in_col[6], i7 = in_col[7]; + for (int o = 0; o < 8; o++) + { + out_col[o] += weight_ptr[o] * i0 + weight_ptr[8 + o] * i1 + weight_ptr[16 + o] * i2 + + weight_ptr[24 + o] * i3 + weight_ptr[32 + o] * i4 + weight_ptr[40 + o] * i5 + + weight_ptr[48 + o] * i6 + weight_ptr[56 + o] * i7; + } + } + } + else + { + // Fall back to Eigen for larger matrices where it's more efficient + _output.leftCols(num_frames).noalias() += this->_weight[k] * input_block; + } + } + } // end else (general GEMM path) +#else + // Eigen fallback uses += accumulation, so needs setZero + _output.leftCols(num_frames).setZero(); + + // Eigen fallback for non-ARM platforms // Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal), // so we can use a single GEMM for all cases. A more advanced implementation could store // compact per-group weight matrices and loop over groups, but at typical model sizes @@ -184,12 +570,94 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames) auto input_block = _input_buffer.Read(num_frames, lookback); _output.leftCols(num_frames).noalias() += this->_weight[k] * input_block; } +#endif } // Add bias if present if (this->_bias.size() > 0) { +#ifdef NAM_USE_INLINE_GEMM + // Inline bias addition for small channel counts + const int out_ch = (int)get_out_channels(); + float* __restrict__ output_ptr = _output.data(); + const float* __restrict__ bias_ptr = this->_bias.data(); + + if (out_ch == 2) + { + const float b0 = bias_ptr[0], b1 = bias_ptr[1]; + for (int f = 0; f < num_frames; f++) + { + output_ptr[f * 2] += b0; + output_ptr[f * 2 + 1] += b1; + } + } + else if (out_ch == 3) + { + const float b0 = bias_ptr[0], b1 = bias_ptr[1], b2 = bias_ptr[2]; + for (int f = 0; f < num_frames; f++) + { + output_ptr[f * 3] += b0; + output_ptr[f * 3 + 1] += b1; + output_ptr[f * 3 + 2] += b2; + } + } + else if (out_ch == 4) + { + const float b0 = bias_ptr[0], b1 = bias_ptr[1]; + const float b2 = bias_ptr[2], b3 = bias_ptr[3]; + for (int f = 0; f < num_frames; f++) + { + output_ptr[f * 4] += b0; + output_ptr[f * 4 + 1] += b1; + output_ptr[f * 4 + 2] += b2; + output_ptr[f * 4 + 3] += b3; + } + } + else if (out_ch == 6) + { + const float b0 = bias_ptr[0], b1 = bias_ptr[1], b2 = bias_ptr[2]; + const float b3 = bias_ptr[3], b4 = bias_ptr[4], b5 = bias_ptr[5]; + for (int f = 0; f < num_frames; f++) + { + const int off = f * 6; + output_ptr[off] += b0; + output_ptr[off + 1] += b1; + output_ptr[off + 2] += b2; + output_ptr[off + 3] += b3; + output_ptr[off + 4] += b4; + output_ptr[off + 5] += b5; + } + } + else if (out_ch == 8) + { + const float b0 = bias_ptr[0], b1 = bias_ptr[1], b2 = bias_ptr[2], b3 = bias_ptr[3]; + const float b4 = bias_ptr[4], b5 = bias_ptr[5], b6 = bias_ptr[6], b7 = bias_ptr[7]; + for (int f = 0; f < num_frames; f++) + { + const int off = f * 8; + output_ptr[off] += b0; + output_ptr[off + 1] += b1; + output_ptr[off + 2] += b2; + output_ptr[off + 3] += b3; + output_ptr[off + 4] += b4; + output_ptr[off + 5] += b5; + output_ptr[off + 6] += b6; + output_ptr[off + 7] += b7; + } + } + else + { + for (int f = 0; f < num_frames; f++) + { + for (int o = 0; o < out_ch; o++) + { + output_ptr[f * out_ch + o] += bias_ptr[o]; + } + } + } +#else _output.leftCols(num_frames).colwise() += this->_bias; +#endif } // Advance ring buffer write pointer after processing diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index 05dab09d..e42f8e72 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -1,5 +1,6 @@ #include // std::max_element #include // pow, tanh, expf +#include #include #include #include @@ -453,10 +454,250 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons } else { +#ifdef NAM_USE_INLINE_GEMM + // Hand-optimized GEMM for small matrices (1x1 convolution) + // output(out_ch, frames) = weight(out_ch, in_ch) * input(in_ch, frames) + const int out_ch = (int)get_out_channels(); + const int in_ch = (int)get_in_channels(); + const float* __restrict__ input_ptr = input.data(); + const float* __restrict__ weight_ptr = this->_weight.data(); + float* __restrict__ output_ptr = _output.data(); + + // Specialized paths for common small sizes + if (out_ch == 2 && in_ch == 1) + { + const float w0 = weight_ptr[0], w1 = weight_ptr[1]; + for (int f = 0; f < num_frames; f++) + { + const float in_val = input_ptr[f]; + output_ptr[f * 2] = w0 * in_val; + output_ptr[f * 2 + 1] = w1 * in_val; + } + } + else if (out_ch == 4 && in_ch == 1) + { + const float w0 = weight_ptr[0], w1 = weight_ptr[1]; + const float w2 = weight_ptr[2], w3 = weight_ptr[3]; + for (int f = 0; f < num_frames; f++) + { + const float in_val = input_ptr[f]; + output_ptr[f * 4] = w0 * in_val; + output_ptr[f * 4 + 1] = w1 * in_val; + output_ptr[f * 4 + 2] = w2 * in_val; + output_ptr[f * 4 + 3] = w3 * in_val; + } + } + else if (out_ch == 1 && in_ch == 2) + { + const float w0 = weight_ptr[0], w1 = weight_ptr[1]; + for (int f = 0; f < num_frames; f++) + { + output_ptr[f] = w0 * input_ptr[f * 2] + w1 * input_ptr[f * 2 + 1]; + } + } + else if (out_ch == 2 && in_ch == 2) + { + // 2x2 fully unrolled + const float w00 = weight_ptr[0], w10 = weight_ptr[1]; + const float w01 = weight_ptr[2], w11 = weight_ptr[3]; + for (int f = 0; f < num_frames; f++) + { + const int off = f * 2; + const float i0 = input_ptr[off]; + const float i1 = input_ptr[off + 1]; + output_ptr[off] = w00 * i0 + w01 * i1; + output_ptr[off + 1] = w10 * i0 + w11 * i1; + } + } + else if (out_ch == 2 && in_ch == 4) + { + const float w00 = weight_ptr[0], w10 = weight_ptr[1]; + const float w01 = weight_ptr[2], w11 = weight_ptr[3]; + const float w02 = weight_ptr[4], w12 = weight_ptr[5]; + const float w03 = weight_ptr[6], w13 = weight_ptr[7]; + for (int f = 0; f < num_frames; f++) + { + const float i0 = input_ptr[f * 4]; + const float i1 = input_ptr[f * 4 + 1]; + const float i2 = input_ptr[f * 4 + 2]; + const float i3 = input_ptr[f * 4 + 3]; + output_ptr[f * 2] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3; + output_ptr[f * 2 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3; + } + } + else if (out_ch == 1 && in_ch == 4) + { + const float w0 = weight_ptr[0], w1 = weight_ptr[1]; + const float w2 = weight_ptr[2], w3 = weight_ptr[3]; + for (int f = 0; f < num_frames; f++) + { + const int off = f * 4; + output_ptr[f] = w0 * input_ptr[off] + w1 * input_ptr[off + 1] + + w2 * input_ptr[off + 2] + w3 * input_ptr[off + 3]; + } + } + else if (out_ch == 4 && in_ch == 2) + { + const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3]; + const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7]; + for (int f = 0; f < num_frames; f++) + { + const float i0 = input_ptr[f * 2]; + const float i1 = input_ptr[f * 2 + 1]; + output_ptr[f * 4] = w00 * i0 + w01 * i1; + output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1; + output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1; + output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1; + } + } + else if (out_ch == 3 && in_ch == 3) + { + const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2]; + const float w01 = weight_ptr[3], w11 = weight_ptr[4], w21 = weight_ptr[5]; + const float w02 = weight_ptr[6], w12 = weight_ptr[7], w22 = weight_ptr[8]; + for (int f = 0; f < num_frames; f++) + { + const int off = f * 3; + const float i0 = input_ptr[off]; + const float i1 = input_ptr[off + 1]; + const float i2 = input_ptr[off + 2]; + output_ptr[off] = w00 * i0 + w01 * i1 + w02 * i2; + output_ptr[off + 1] = w10 * i0 + w11 * i1 + w12 * i2; + output_ptr[off + 2] = w20 * i0 + w21 * i1 + w22 * i2; + } + } + else if (out_ch == 4 && in_ch == 4) + { + const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3]; + const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7]; + const float w02 = weight_ptr[8], w12 = weight_ptr[9], w22 = weight_ptr[10], w32 = weight_ptr[11]; + const float w03 = weight_ptr[12], w13 = weight_ptr[13], w23 = weight_ptr[14], w33 = weight_ptr[15]; + for (int f = 0; f < num_frames; f++) + { + const int off = f * 4; + const float i0 = input_ptr[off]; + const float i1 = input_ptr[off + 1]; + const float i2 = input_ptr[off + 2]; + const float i3 = input_ptr[off + 3]; + output_ptr[off] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3; + output_ptr[off + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3; + output_ptr[off + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3; + output_ptr[off + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3; + } + } + else if (out_ch == 6 && in_ch == 6) + { + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = input_ptr + f * 6; + float* __restrict__ out_col = output_ptr + f * 6; + const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2]; + const float i3 = in_col[3], i4 = in_col[4], i5 = in_col[5]; + for (int o = 0; o < 6; o++) + { + out_col[o] = weight_ptr[o] * i0 + weight_ptr[6 + o] * i1 + weight_ptr[12 + o] * i2 + + weight_ptr[18 + o] * i3 + weight_ptr[24 + o] * i4 + weight_ptr[30 + o] * i5; + } + } + } + else if (out_ch == 8 && in_ch == 8) + { + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = input_ptr + f * 8; + float* __restrict__ out_col = output_ptr + f * 8; + const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2], i3 = in_col[3]; + const float i4 = in_col[4], i5 = in_col[5], i6 = in_col[6], i7 = in_col[7]; + for (int o = 0; o < 8; o++) + { + out_col[o] = weight_ptr[o] * i0 + weight_ptr[8 + o] * i1 + weight_ptr[16 + o] * i2 + weight_ptr[24 + o] * i3 + + weight_ptr[32 + o] * i4 + weight_ptr[40 + o] * i5 + weight_ptr[48 + o] * i6 + weight_ptr[56 + o] * i7; + } + } + } + else if (out_ch == 4 && in_ch == 8) + { + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = input_ptr + f * 8; + float* __restrict__ out_col = output_ptr + f * 4; + const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2], i3 = in_col[3]; + const float i4 = in_col[4], i5 = in_col[5], i6 = in_col[6], i7 = in_col[7]; + for (int o = 0; o < 4; o++) + { + out_col[o] = weight_ptr[o] * i0 + weight_ptr[4 + o] * i1 + weight_ptr[8 + o] * i2 + weight_ptr[12 + o] * i3 + + weight_ptr[16 + o] * i4 + weight_ptr[20 + o] * i5 + weight_ptr[24 + o] * i6 + weight_ptr[28 + o] * i7; + } + } + } + else if (out_ch == 8 && in_ch == 4) + { + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = input_ptr + f * 4; + float* __restrict__ out_col = output_ptr + f * 8; + const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2], i3 = in_col[3]; + for (int o = 0; o < 8; o++) + { + out_col[o] = weight_ptr[o] * i0 + weight_ptr[8 + o] * i1 + weight_ptr[16 + o] * i2 + weight_ptr[24 + o] * i3; + } + } + } + else + { + // Fall back to Eigen for larger matrices where it's more efficient + _output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames); + } +#else // Single GEMM for all cases - block-diagonal zero structure handles grouping _output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames); +#endif } if (this->_do_bias) + { +#ifdef NAM_USE_INLINE_GEMM + const int out_ch = (int)get_out_channels(); + float* __restrict__ output_ptr = _output.data(); + const float* __restrict__ bias_ptr = this->_bias.data(); + + // Specialized paths for common small channel counts + if (out_ch == 2) + { + const float b0 = bias_ptr[0], b1 = bias_ptr[1]; + for (int f = 0; f < num_frames; f++) + { + const int off = f * 2; + output_ptr[off] += b0; + output_ptr[off + 1] += b1; + } + } + else if (out_ch == 4) + { + const float b0 = bias_ptr[0], b1 = bias_ptr[1]; + const float b2 = bias_ptr[2], b3 = bias_ptr[3]; + for (int f = 0; f < num_frames; f++) + { + const int off = f * 4; + output_ptr[off] += b0; + output_ptr[off + 1] += b1; + output_ptr[off + 2] += b2; + output_ptr[off + 3] += b3; + } + } + else + { + for (int f = 0; f < num_frames; f++) + { + float* __restrict__ out_col = output_ptr + f * out_ch; + for (int o = 0; o < out_ch; o++) + { + out_col[o] += bias_ptr[o]; + } + } + } +#else _output.leftCols(num_frames).colwise() += this->_bias; +#endif + } } diff --git a/NAM/film.h b/NAM/film.h index f0f86fb4..d1cc646d 100644 --- a/NAM/film.h +++ b/NAM/film.h @@ -84,6 +84,63 @@ class FiLM _cond_to_scale_shift.process_(condition, num_frames); const auto& scale_shift = _cond_to_scale_shift.GetOutput(); +#ifdef NAM_USE_INLINE_GEMM + // Optimized inline FiLM operation + const int input_dim = (int)get_input_dim(); + const float* __restrict__ input_ptr = input.data(); + const float* __restrict__ scale_shift_ptr = scale_shift.data(); + float* __restrict__ output_ptr = _output.data(); + const int scale_shift_rows = (int)scale_shift.rows(); + const int input_rows = (int)input.rows(); + + if (_do_shift) + { + // scale = top input_dim rows, shift = bottom input_dim rows + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = input_ptr + f * input_rows; + const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows; + const float* __restrict__ shift_col = scale_col + input_dim; + float* __restrict__ out_col = output_ptr + f * input_dim; + + int i = 0; + for (; i + 3 < input_dim; i += 4) + { + out_col[i] = in_col[i] * scale_col[i] + shift_col[i]; + out_col[i + 1] = in_col[i + 1] * scale_col[i + 1] + shift_col[i + 1]; + out_col[i + 2] = in_col[i + 2] * scale_col[i + 2] + shift_col[i + 2]; + out_col[i + 3] = in_col[i + 3] * scale_col[i + 3] + shift_col[i + 3]; + } + for (; i < input_dim; i++) + { + out_col[i] = in_col[i] * scale_col[i] + shift_col[i]; + } + } + } + else + { + // scale only + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = input_ptr + f * input_rows; + const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows; + float* __restrict__ out_col = output_ptr + f * input_dim; + + int i = 0; + for (; i + 3 < input_dim; i += 4) + { + out_col[i] = in_col[i] * scale_col[i]; + out_col[i + 1] = in_col[i + 1] * scale_col[i + 1]; + out_col[i + 2] = in_col[i + 2] * scale_col[i + 2]; + out_col[i + 3] = in_col[i + 3] * scale_col[i + 3]; + } + for (; i < input_dim; i++) + { + out_col[i] = in_col[i] * scale_col[i]; + } + } + } +#else const auto scale = scale_shift.topRows(get_input_dim()).leftCols(num_frames); if (_do_shift) { @@ -95,6 +152,7 @@ class FiLM { _output.leftCols(num_frames).array() = input.leftCols(num_frames).array() * scale.array(); } +#endif } /// \brief Process input with conditioning (in-place) diff --git a/NAM/gating_activations.h b/NAM/gating_activations.h index 335a9841..adc770f6 100644 --- a/NAM/gating_activations.h +++ b/NAM/gating_activations.h @@ -59,14 +59,43 @@ class GatingActivation void apply(const Eigen::MatrixBase& input, Eigen::MatrixBase& output) { // Validate input dimensions (assert for real-time performance) - const int total_channels = 2 * num_channels; - assert(input.rows() == total_channels); + assert(input.rows() == 2 * num_channels); assert(output.rows() == num_channels); assert(output.cols() == input.cols()); - // Process column-by-column to ensure memory contiguity (important for column-major matrices) - // Uses pre-allocated buffers to avoid allocations in the loop (real-time safe) const int num_samples = input.cols(); + const int input_rows = input.rows(); // 2 * num_channels + +#ifdef NAM_USE_INLINE_GEMM + // Optimized path: direct memory access with activation applied per-element + // Note: output may be a block expression with outer stride != num_channels + const float* __restrict__ input_ptr = input.derived().data(); + float* __restrict__ output_ptr = output.derived().data(); + const int output_stride = (int)output.outerStride(); // Column stride for output + + for (int f = 0; f < num_samples; f++) + { + const float* __restrict__ in_col = input_ptr + f * input_rows; + float* __restrict__ out_col = output_ptr + f * output_stride; + + // Copy input and gating channels to buffers, apply activations, multiply + for (int c = 0; c < num_channels; c++) + { + input_buffer(c, 0) = in_col[c]; + gating_buffer(c, 0) = in_col[c + num_channels]; + } + + input_activation->apply(input_buffer); + gating_activation->apply(gating_buffer); + + // Element-wise multiply and store + for (int c = 0; c < num_channels; c++) + { + out_col[c] = input_buffer(c, 0) * gating_buffer(c, 0); + } + } +#else + // Original Eigen path for (int i = 0; i < num_samples; i++) { // Copy to pre-allocated buffers and apply activations in-place @@ -77,9 +106,9 @@ class GatingActivation gating_activation->apply(gating_buffer); // Element-wise multiplication and store result - // For wavenet compatibility, we assume one-to-one mapping output.block(0, i, num_channels, 1) = input_buffer.array() * gating_buffer.array(); } +#endif } /** @@ -135,14 +164,46 @@ class BlendingActivation void apply(const Eigen::MatrixBase& input, Eigen::MatrixBase& output) { // Validate input dimensions (assert for real-time performance) - const int total_channels = num_channels * 2; // 2*channels in, channels out - assert(input.rows() == total_channels); + assert(input.rows() == num_channels * 2); assert(output.rows() == num_channels); assert(output.cols() == input.cols()); - // Process column-by-column to ensure memory contiguity - // Uses pre-allocated buffers to avoid allocations in the loop (real-time safe) const int num_samples = input.cols(); + const int input_rows = input.rows(); // 2 * num_channels + +#ifdef NAM_USE_INLINE_GEMM + // Optimized path: direct memory access + // Note: output may be a block expression with outer stride != num_channels + const float* __restrict__ input_ptr = input.derived().data(); + float* __restrict__ output_ptr = output.derived().data(); + const int output_stride = (int)output.outerStride(); // Column stride for output + + for (int f = 0; f < num_samples; f++) + { + const float* __restrict__ in_col = input_ptr + f * input_rows; + float* __restrict__ out_col = output_ptr + f * output_stride; + + // Copy channels to buffers + for (int c = 0; c < num_channels; c++) + { + pre_activation_buffer(c, 0) = in_col[c]; + input_buffer(c, 0) = in_col[c]; + blend_buffer(c, 0) = in_col[c + num_channels]; + } + + // Apply activations + input_activation->apply(input_buffer); + blending_activation->apply(blend_buffer); + + // Weighted blending: alpha * activated + (1 - alpha) * pre_activation + for (int c = 0; c < num_channels; c++) + { + const float alpha = blend_buffer(c, 0); + out_col[c] = alpha * input_buffer(c, 0) + (1.0f - alpha) * pre_activation_buffer(c, 0); + } + } +#else + // Original Eigen path for (int i = 0; i < num_samples; i++) { // Store pre-activation input values in buffer @@ -160,6 +221,7 @@ class BlendingActivation output.block(0, i, num_channels, 1) = blend_buffer.array() * input_buffer.array() + (1.0f - blend_buffer.array()) * pre_activation_buffer.array(); } +#endif } /** diff --git a/NAM/ring_buffer.cpp b/NAM/ring_buffer.cpp index 64518b4f..8f0919a8 100644 --- a/NAM/ring_buffer.cpp +++ b/NAM/ring_buffer.cpp @@ -35,22 +35,10 @@ void RingBuffer::Write(const Eigen::MatrixXf& input, const int num_frames) if (NeedsRewind(num_frames)) Rewind(); - // Write the input data at the write position - // NOTE: This function assumes that `input` is a full, pre-allocated MatrixXf - // covering the entire valid buffer range. Callers should not pass Block - // expressions across the API boundary; instead, pass the full buffer and - // slice inside the callee. This avoids Eigen evaluating Blocks into - // temporaries (which would allocate) when binding to MatrixXf. - const int channels = _storage.rows(); - const int copy_cols = num_frames; - - for (int col = 0; col < copy_cols; ++col) - { - for (int row = 0; row < channels; ++row) - { - _storage(row, _write_pos + col) = input(row, col); - } - } + // Write the input data at the write position using Eigen block operations + // This is more efficient than element-by-element copy as it allows + // the compiler to vectorize the operation. + _storage.middleCols(_write_pos, num_frames).noalias() = input.leftCols(num_frames); } Eigen::Block RingBuffer::Read(const int num_frames, const long lookback) diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index 7d9b5d03..1ae204ca 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -123,8 +124,34 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma Eigen::MatrixXf& input_mixin_output = this->_input_mixin.GetOutput(); this->_input_mixin_post_film->Process_(input_mixin_output, condition, num_frames); } +#ifdef NAM_USE_INLINE_GEMM + // Optimized matrix addition for small channel counts + { + const int channels = (int)_conv.get_out_channels(); + const float* __restrict__ conv_ptr = _conv.GetOutput().data(); + const float* __restrict__ mixin_ptr = _input_mixin.GetOutput().data(); + float* __restrict__ z_ptr = this->_z.data(); + const int total = channels * num_frames; + + // Unrolled addition + int i = 0; + for (; i + 3 < total; i += 4) + { + z_ptr[i] = conv_ptr[i] + mixin_ptr[i]; + z_ptr[i + 1] = conv_ptr[i + 1] + mixin_ptr[i + 1]; + z_ptr[i + 2] = conv_ptr[i + 2] + mixin_ptr[i + 2]; + z_ptr[i + 3] = conv_ptr[i + 3] + mixin_ptr[i + 3]; + } + for (; i < total; i++) + { + z_ptr[i] = conv_ptr[i] + mixin_ptr[i]; + } + } +#else this->_z.leftCols(num_frames).noalias() = _conv.GetOutput().leftCols(num_frames) + _input_mixin.GetOutput().leftCols(num_frames); +#endif + if (this->_activation_pre_film) { this->_activation_pre_film->Process_(this->_z, condition, num_frames); @@ -207,28 +234,89 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma Eigen::MatrixXf& head1x1_output = this->_head1x1->GetOutput(); this->_head1x1_post_film->Process_(head1x1_output, condition, num_frames); } +#ifdef NAM_USE_INLINE_GEMM + { + // Pure copy - use memcpy + const int total = (int)this->_head1x1->get_out_channels() * num_frames; + std::memcpy(this->_output_head.data(), this->_head1x1->GetOutput().data(), total * sizeof(float)); + } +#else this->_output_head.leftCols(num_frames).noalias() = this->_head1x1->GetOutput().leftCols(num_frames); +#endif } else // No head 1x1 { // (No FiLM) // Store output to head (skip connection: activated conv output) +#ifdef NAM_USE_INLINE_GEMM + if (this->_gating_mode == GatingMode::NONE) + { + // _z has bottleneck rows, data is contiguous - use memcpy + const int total = (int)bottleneck * num_frames; + std::memcpy(this->_output_head.data(), this->_z.data(), total * sizeof(float)); + } + else + { + // _z has 2*bottleneck rows but we only want top bottleneck rows + // Column-major: need to copy column by column with stride + const int out_rows = (int)bottleneck; + const int z_rows = (int)this->_z.rows(); // 2*bottleneck for gated + const float* __restrict__ src = this->_z.data(); + float* __restrict__ dst = this->_output_head.data(); + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ src_col = src + f * z_rows; + float* __restrict__ dst_col = dst + f * out_rows; + for (int r = 0; r < out_rows; r++) + dst_col[r] = src_col[r]; + } + } +#else if (this->_gating_mode == GatingMode::NONE) this->_output_head.leftCols(num_frames).noalias() = this->_z.leftCols(num_frames); else this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(bottleneck).leftCols(num_frames); +#endif } // Store output to next layer (residual connection: input + layer1x1 output, or just input if layer1x1 inactive) if (this->_layer1x1) { +#ifdef NAM_USE_INLINE_GEMM + { + const int channels = (int)this->get_channels(); + const int total = channels * num_frames; + const float* __restrict__ in_ptr = input.data(); + const float* __restrict__ layer_ptr = this->_layer1x1->GetOutput().data(); + float* __restrict__ dst = this->_output_next_layer.data(); + int i = 0; + for (; i + 3 < total; i += 4) + { + dst[i] = in_ptr[i] + layer_ptr[i]; + dst[i + 1] = in_ptr[i + 1] + layer_ptr[i + 1]; + dst[i + 2] = in_ptr[i + 2] + layer_ptr[i + 2]; + dst[i + 3] = in_ptr[i + 3] + layer_ptr[i + 3]; + } + for (; i < total; i++) + dst[i] = in_ptr[i] + layer_ptr[i]; + } +#else this->_output_next_layer.leftCols(num_frames).noalias() = input.leftCols(num_frames) + this->_layer1x1->GetOutput().leftCols(num_frames); +#endif } else { // If layer1x1 is inactive, residual connection is just the input (identity) +#ifdef NAM_USE_INLINE_GEMM + { + // Pure copy - use memcpy + const int total = (int)this->get_channels() * num_frames; + std::memcpy(this->_output_next_layer.data(), input.data(), total * sizeof(float)); + } +#else this->_output_next_layer.leftCols(num_frames).noalias() = input.leftCols(num_frames); +#endif } } @@ -290,8 +378,15 @@ void nam::wavenet::_LayerArray::Process(const Eigen::MatrixXf& layer_inputs, con void nam::wavenet::_LayerArray::Process(const Eigen::MatrixXf& layer_inputs, const Eigen::MatrixXf& condition, const Eigen::MatrixXf& head_inputs, const int num_frames) { - // Copy head inputs from previous layer array + // Copy head inputs from previous layer array - use memcpy for pure copy +#ifdef NAM_USE_INLINE_GEMM + { + const int total = (int)this->_head_output_size * num_frames; + std::memcpy(this->_head_inputs.data(), head_inputs.data(), total * sizeof(float)); + } +#else this->_head_inputs.leftCols(num_frames).noalias() = head_inputs.leftCols(num_frames); +#endif ProcessInner(layer_inputs, condition, num_frames); } @@ -320,13 +415,40 @@ void nam::wavenet::_LayerArray::ProcessInner(const Eigen::MatrixXf& layer_inputs } // Accumulate head output from this layer +#ifdef NAM_USE_INLINE_GEMM + { + const int channels = (int)this->_head_output_size; + const int total = channels * num_frames; + const float* __restrict__ src = this->_layers[i].GetOutputHead().data(); + float* __restrict__ dst = this->_head_inputs.data(); + int j = 0; + for (; j + 3 < total; j += 4) + { + dst[j] += src[j]; + dst[j + 1] += src[j + 1]; + dst[j + 2] += src[j + 2]; + dst[j + 3] += src[j + 3]; + } + for (; j < total; j++) + dst[j] += src[j]; + } +#else this->_head_inputs.leftCols(num_frames).noalias() += this->_layers[i].GetOutputHead().leftCols(num_frames); +#endif } - // Store output from last layer + // Store output from last layer - use memcpy for pure copy const size_t last_layer = this->_layers.size() - 1; +#ifdef NAM_USE_INLINE_GEMM + { + const int total = (int)this->_get_channels() * num_frames; + std::memcpy(this->_layer_outputs.data(), this->_layers[last_layer].GetOutputNextLayer().data(), + total * sizeof(float)); + } +#else this->_layer_outputs.leftCols(num_frames).noalias() = this->_layers[last_layer].GetOutputNextLayer().leftCols(num_frames); +#endif // Process head rechannel _head_rechannel.process_(this->_head_inputs, num_frames); @@ -558,12 +680,27 @@ void nam::wavenet::WaveNet::process(NAM_SAMPLE** input, NAM_SAMPLE** output, con auto& final_head_outputs = this->_layer_arrays.back().GetHeadOutputs(); assert(final_head_outputs.rows() == out_channels); - for (int ch = 0; ch < out_channels; ch++) + // Optimized output copy with head_scale multiplication + if (out_channels == 1) { + // Single channel: data is contiguous + const float scale = this->_head_scale; + const float* __restrict__ src = final_head_outputs.data(); + NAM_SAMPLE* __restrict__ dst = output[0]; for (int s = 0; s < num_frames; s++) { - const float out = this->_head_scale * final_head_outputs(ch, s); - output[ch][s] = out; + dst[s] = scale * src[s]; + } + } + else + { + // Multi-channel: rows are not contiguous in column-major + for (int ch = 0; ch < out_channels; ch++) + { + for (int s = 0; s < num_frames; s++) + { + output[ch][s] = this->_head_scale * final_head_outputs(ch, s); + } } } } diff --git a/NAM/wavenet.h b/NAM/wavenet.h index 63e13781..22fc06b4 100644 --- a/NAM/wavenet.h +++ b/NAM/wavenet.h @@ -385,11 +385,9 @@ class _Layer std::unique_ptr _layer1x1; // The post-activation 1x1 convolution outputting to the head, optional std::unique_ptr _head1x1; - // The internal state + Eigen::MatrixXf _z; - // Output to next layer (residual connection: input + layer1x1 output, or just input if layer1x1 inactive) Eigen::MatrixXf _output_next_layer; - // Output to head (skip connection: activated conv output) Eigen::MatrixXf _output_head; activations::Activation::Ptr _activation; @@ -606,12 +604,12 @@ class _LayerArray // The layer objects std::vector<_Layer> _layers; - // Output from last layer (for next layer array) + Eigen::MatrixXf _layer_outputs; - // Accumulated head inputs from all layers - // Size is _head_output_size (= head1x1.out_channels if head1x1 active, else bottleneck) Eigen::MatrixXf _head_inputs; + // Accumulated head inputs from all layers + // Size is _head_output_size (= head1x1.out_channels if head1x1 active, else bottleneck) // Rechannel for the head (_head_output_size -> head_size) Conv1x1 _head_rechannel; @@ -670,9 +668,9 @@ class WaveNet : public DSP void set_weights_(std::vector::iterator& weights); protected: - // Element-wise arrays: Eigen::MatrixXf _condition_input; Eigen::MatrixXf _condition_output; + std::unique_ptr _condition_dsp; // Temporary buffers for condition DSP processing (to avoid allocations in _process_condition) std::vector> _condition_dsp_input_buffers; diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 8118e085..cf5b64f4 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -12,6 +12,7 @@ include_directories(tools ${NAM_DEPS_PATH}/nlohmann) add_executable(loadmodel loadmodel.cpp ${NAM_SOURCES}) add_executable(benchmodel benchmodel.cpp ${NAM_SOURCES}) +add_executable(benchmodel_bufsize benchmodel_bufsize.cpp ${NAM_SOURCES}) add_executable(run_tests run_tests.cpp test/allocation_tracking.cpp ${NAM_SOURCES}) # Compile run_tests without optimizations to ensure allocation tracking works correctly # Also ensure assertions are enabled (NDEBUG is not defined) so tests actually run diff --git a/tools/benchmodel_bufsize.cpp b/tools/benchmodel_bufsize.cpp new file mode 100644 index 00000000..ee27a474 --- /dev/null +++ b/tools/benchmodel_bufsize.cpp @@ -0,0 +1,96 @@ +#include +#include +#include +#include + +#include "NAM/dsp.h" +#include "NAM/get_dsp.h" + +using std::chrono::duration; +using std::chrono::duration_cast; +using std::chrono::high_resolution_clock; +using std::chrono::microseconds; + +int main(int argc, char* argv[]) +{ + if (argc < 3) + { + std::cerr << "Usage: benchmodel_bufsize [num_iterations]\n"; + exit(1); + } + + const char* modelPath = argv[1]; + const int bufferSize = std::atoi(argv[2]); + const int numIterations = (argc > 3) ? std::atoi(argv[3]) : 5; + + if (bufferSize <= 0 || bufferSize > 4096) + { + std::cerr << "Buffer size must be between 1 and 4096\n"; + exit(1); + } + + // Turn on fast tanh approximation + nam::activations::Activation::enable_fast_tanh(); + + std::unique_ptr model; + model = nam::get_dsp(std::filesystem::path(modelPath)); + + if (model == nullptr) + { + std::cerr << "Failed to load model\n"; + exit(1); + } + + model->Reset(model->GetExpectedSampleRate(), bufferSize); + + // Process 2 seconds of audio + const size_t totalSamples = 48000 * 2; + const size_t numBuffers = totalSamples / bufferSize; + + // Allocate multi-channel buffers + const int in_channels = model->NumInputChannels(); + const int out_channels = model->NumOutputChannels(); + + std::vector> inputBuffers(in_channels); + std::vector> outputBuffers(out_channels); + std::vector inputPtrs(in_channels); + std::vector outputPtrs(out_channels); + + for (int ch = 0; ch < in_channels; ch++) + { + inputBuffers[ch].resize(bufferSize, 0.0); + inputPtrs[ch] = inputBuffers[ch].data(); + } + for (int ch = 0; ch < out_channels; ch++) + { + outputBuffers[ch].resize(bufferSize, 0.0); + outputPtrs[ch] = outputBuffers[ch].data(); + } + + // Warm-up run + for (size_t i = 0; i < numBuffers; i++) + { + model->process(inputPtrs.data(), outputPtrs.data(), bufferSize); + } + + // Timed runs + double totalMicroseconds = 0.0; + for (int iter = 0; iter < numIterations; iter++) + { + auto t1 = high_resolution_clock::now(); + for (size_t i = 0; i < numBuffers; i++) + { + model->process(inputPtrs.data(), outputPtrs.data(), bufferSize); + } + auto t2 = high_resolution_clock::now(); + duration us_double = t2 - t1; + totalMicroseconds += us_double.count(); + } + + double avgMicroseconds = totalMicroseconds / numIterations; + + // Output format: buffer_size,avg_microseconds + std::cout << bufferSize << "," << avgMicroseconds << std::endl; + + return 0; +} From 6c0942b0fe0ce5c3ef2bb4162d6e1bb86094cfaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Thu, 12 Feb 2026 15:57:46 -0800 Subject: [PATCH 2/6] Fixed minor bug when building without inline GEMM --- NAM/gating_activations.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/NAM/gating_activations.h b/NAM/gating_activations.h index adc770f6..c0f4ace7 100644 --- a/NAM/gating_activations.h +++ b/NAM/gating_activations.h @@ -64,11 +64,11 @@ class GatingActivation assert(output.cols() == input.cols()); const int num_samples = input.cols(); - const int input_rows = input.rows(); // 2 * num_channels #ifdef NAM_USE_INLINE_GEMM // Optimized path: direct memory access with activation applied per-element // Note: output may be a block expression with outer stride != num_channels + const int input_rows = input.rows(); // 2 * num_channels const float* __restrict__ input_ptr = input.derived().data(); float* __restrict__ output_ptr = output.derived().data(); const int output_stride = (int)output.outerStride(); // Column stride for output @@ -169,11 +169,11 @@ class BlendingActivation assert(output.cols() == input.cols()); const int num_samples = input.cols(); - const int input_rows = input.rows(); // 2 * num_channels #ifdef NAM_USE_INLINE_GEMM // Optimized path: direct memory access // Note: output may be a block expression with outer stride != num_channels + const int input_rows = input.rows(); // 2 * num_channels const float* __restrict__ input_ptr = input.derived().data(); float* __restrict__ output_ptr = output.derived().data(); const int output_stride = (int)output.outerStride(); // Column stride for output From 5d9ed6cb1cefaf6cd9fd56d20fcc38c4972b6a89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Thu, 12 Feb 2026 16:00:39 -0800 Subject: [PATCH 3/6] Added NAM_INLINE_GEMM build/test task to CI --- .github/workflows/build.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 62f2a7e2..22874fe0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -41,3 +41,26 @@ jobs: ./build/tools/run_tests ./build/tools/benchmodel ./example_models/wavenet.nam ./build/tools/benchmodel ./example_models/lstm.nam + + build-ubuntu-inline-gemm: + name: Build Ubuntu (Inline GEMM) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3.3.0 + with: + submodules: recursive + + - name: Build Tools + working-directory: ${{github.workspace}}/build + env: + CXX: clang++ + run: | + cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS="-DNAM_USE_INLINE_GEMM" + cmake --build . -j4 + + - name: Run tests + working-directory: ${{github.workspace}} + run: | + ./build/tools/run_tests + ./build/tools/benchmodel ./example_models/wavenet.nam + ./build/tools/benchmodel ./example_models/lstm.nam From 7844a41f54bef225bc6c96753c96e9981f6c082e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Mon, 16 Feb 2026 10:38:14 -0800 Subject: [PATCH 4/6] Remove redundant manual loop unrolling from activations and element-wise ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ARM assembly analysis (-O2 -DNDEBUG) confirmed: - GCC auto-unrolls simple activation loops; manual 4-wide gives no benefit - expf() serializes sigmoid/SiLU; unrolling can't help - Eigen element-wise ops (.leftCols + .leftCols) produce identical codegen to raw float* loops when assertions are disabled Simplify 5 activation classes to use inline helpers (relu, sigmoid, etc.) and revert 3 wavenet element-wise operations back to Eigen expressions. Inline GEMM (Conv1x1/Conv1D), depthwise unrolling, FiLM unrolling, bias broadcast, and memcpy optimizations are retained — those show measurable wins on both desktop and Cortex-M7. Also restored comments that were accidentally removed from wavenet.h. --- NAM/activations.cpp | 1 + NAM/activations.h | 102 +++----------------------------------------- NAM/wavenet.cpp | 64 --------------------------- NAM/wavenet.h | 12 +++--- 4 files changed, 14 insertions(+), 165 deletions(-) diff --git a/NAM/activations.cpp b/NAM/activations.cpp index 8d37a19a..3e0bc944 100644 --- a/NAM/activations.cpp +++ b/NAM/activations.cpp @@ -35,6 +35,7 @@ std::unordered_map nam::activati {"PReLU", make_singleton_ptr(_PRELU)}, {"Softsign", make_singleton_ptr(_SOFTSIGN)}}; +// Variables to hold previous instances of activations when we replace them with fast or LUT implementations nam::activations::Activation::Ptr tanh_bak = nullptr; nam::activations::Activation::Ptr sigmoid_bak = nullptr; nam::activations::Activation::Ptr silu_bak = nullptr; diff --git a/NAM/activations.h b/NAM/activations.h index 9b0c47ca..ae638878 100644 --- a/NAM/activations.h +++ b/NAM/activations.h @@ -235,24 +235,8 @@ class ActivationReLU : public Activation public: void apply(float* data, long size) override { - // Optimized ReLU with loop unrolling - long pos = 0; - // Process 4 elements at a time - for (; pos + 3 < size; pos += 4) - { - // Branchless ReLU using conditional - const float v0 = data[pos], v1 = data[pos + 1]; - const float v2 = data[pos + 2], v3 = data[pos + 3]; - data[pos] = v0 > 0.0f ? v0 : 0.0f; - data[pos + 1] = v1 > 0.0f ? v1 : 0.0f; - data[pos + 2] = v2 > 0.0f ? v2 : 0.0f; - data[pos + 3] = v3 > 0.0f ? v3 : 0.0f; - } - // Handle remainder - for (; pos < size; pos++) - { - data[pos] = data[pos] > 0.0f ? data[pos] : 0.0f; - } + for (long pos = 0; pos < size; pos++) + data[pos] = relu(data[pos]); } }; @@ -316,23 +300,8 @@ class ActivationSigmoid : public Activation public: void apply(float* data, long size) override { - long pos = 0; - // Process 4 elements at a time - for (; pos + 3 < size; pos += 4) - { - const float x0 = data[pos], x1 = data[pos + 1]; - const float x2 = data[pos + 2], x3 = data[pos + 3]; - - data[pos] = 1.0f / (1.0f + expf(-x0)); - data[pos + 1] = 1.0f / (1.0f + expf(-x1)); - data[pos + 2] = 1.0f / (1.0f + expf(-x2)); - data[pos + 3] = 1.0f / (1.0f + expf(-x3)); - } - // Handle remainder - for (; pos < size; pos++) - { + for (long pos = 0; pos < size; pos++) data[pos] = sigmoid(data[pos]); - } } }; @@ -341,28 +310,8 @@ class ActivationSwish : public Activation public: void apply(float* data, long size) override { - long pos = 0; - // Process 4 elements at a time: swish(x) = x * sigmoid(x) = x / (1 + exp(-x)) - for (; pos + 3 < size; pos += 4) - { - const float x0 = data[pos], x1 = data[pos + 1]; - const float x2 = data[pos + 2], x3 = data[pos + 3]; - - const float s0 = 1.0f / (1.0f + expf(-x0)); - const float s1 = 1.0f / (1.0f + expf(-x1)); - const float s2 = 1.0f / (1.0f + expf(-x2)); - const float s3 = 1.0f / (1.0f + expf(-x3)); - - data[pos] = x0 * s0; - data[pos + 1] = x1 * s1; - data[pos + 2] = x2 * s2; - data[pos + 3] = x3 * s3; - } - // Handle remainder - for (; pos < size; pos++) - { + for (long pos = 0; pos < size; pos++) data[pos] = swish(data[pos]); - } } }; @@ -371,32 +320,8 @@ class ActivationHardSwish : public Activation public: void apply(float* data, long size) override { - const float inv6 = 1.0f / 6.0f; - long pos = 0; - // Process 4 elements at a time - for (; pos + 3 < size; pos += 4) - { - const float x0 = data[pos], x1 = data[pos + 1]; - const float x2 = data[pos + 2], x3 = data[pos + 3]; - - const float t0 = x0 + 3.0f, t1 = x1 + 3.0f; - const float t2 = x2 + 3.0f, t3 = x3 + 3.0f; - - const float c0 = t0 < 0.0f ? 0.0f : (t0 > 6.0f ? 6.0f : t0); - const float c1 = t1 < 0.0f ? 0.0f : (t1 > 6.0f ? 6.0f : t1); - const float c2 = t2 < 0.0f ? 0.0f : (t2 > 6.0f ? 6.0f : t2); - const float c3 = t3 < 0.0f ? 0.0f : (t3 > 6.0f ? 6.0f : t3); - - data[pos] = x0 * c0 * inv6; - data[pos + 1] = x1 * c1 * inv6; - data[pos + 2] = x2 * c2 * inv6; - data[pos + 3] = x3 * c3 * inv6; - } - // Handle remainder - for (; pos < size; pos++) - { + for (long pos = 0; pos < size; pos++) data[pos] = hardswish(data[pos]); - } } }; @@ -405,23 +330,8 @@ class ActivationSoftsign : public Activation public: void apply(float* data, long size) override { - long pos = 0; - // Process 4 elements at a time - for (; pos + 3 < size; pos += 4) - { - const float x0 = data[pos], x1 = data[pos + 1]; - const float x2 = data[pos + 2], x3 = data[pos + 3]; - - data[pos] = x0 / (1.0f + fabsf(x0)); - data[pos + 1] = x1 / (1.0f + fabsf(x1)); - data[pos + 2] = x2 / (1.0f + fabsf(x2)); - data[pos + 3] = x3 / (1.0f + fabsf(x3)); - } - // Handle remainder - for (; pos < size; pos++) - { + for (long pos = 0; pos < size; pos++) data[pos] = softsign(data[pos]); - } } }; diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index 1ae204ca..6811a07e 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -124,33 +124,8 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma Eigen::MatrixXf& input_mixin_output = this->_input_mixin.GetOutput(); this->_input_mixin_post_film->Process_(input_mixin_output, condition, num_frames); } -#ifdef NAM_USE_INLINE_GEMM - // Optimized matrix addition for small channel counts - { - const int channels = (int)_conv.get_out_channels(); - const float* __restrict__ conv_ptr = _conv.GetOutput().data(); - const float* __restrict__ mixin_ptr = _input_mixin.GetOutput().data(); - float* __restrict__ z_ptr = this->_z.data(); - const int total = channels * num_frames; - - // Unrolled addition - int i = 0; - for (; i + 3 < total; i += 4) - { - z_ptr[i] = conv_ptr[i] + mixin_ptr[i]; - z_ptr[i + 1] = conv_ptr[i + 1] + mixin_ptr[i + 1]; - z_ptr[i + 2] = conv_ptr[i + 2] + mixin_ptr[i + 2]; - z_ptr[i + 3] = conv_ptr[i + 3] + mixin_ptr[i + 3]; - } - for (; i < total; i++) - { - z_ptr[i] = conv_ptr[i] + mixin_ptr[i]; - } - } -#else this->_z.leftCols(num_frames).noalias() = _conv.GetOutput().leftCols(num_frames) + _input_mixin.GetOutput().leftCols(num_frames); -#endif if (this->_activation_pre_film) { @@ -282,28 +257,8 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::Ma // Store output to next layer (residual connection: input + layer1x1 output, or just input if layer1x1 inactive) if (this->_layer1x1) { -#ifdef NAM_USE_INLINE_GEMM - { - const int channels = (int)this->get_channels(); - const int total = channels * num_frames; - const float* __restrict__ in_ptr = input.data(); - const float* __restrict__ layer_ptr = this->_layer1x1->GetOutput().data(); - float* __restrict__ dst = this->_output_next_layer.data(); - int i = 0; - for (; i + 3 < total; i += 4) - { - dst[i] = in_ptr[i] + layer_ptr[i]; - dst[i + 1] = in_ptr[i + 1] + layer_ptr[i + 1]; - dst[i + 2] = in_ptr[i + 2] + layer_ptr[i + 2]; - dst[i + 3] = in_ptr[i + 3] + layer_ptr[i + 3]; - } - for (; i < total; i++) - dst[i] = in_ptr[i] + layer_ptr[i]; - } -#else this->_output_next_layer.leftCols(num_frames).noalias() = input.leftCols(num_frames) + this->_layer1x1->GetOutput().leftCols(num_frames); -#endif } else { @@ -415,26 +370,7 @@ void nam::wavenet::_LayerArray::ProcessInner(const Eigen::MatrixXf& layer_inputs } // Accumulate head output from this layer -#ifdef NAM_USE_INLINE_GEMM - { - const int channels = (int)this->_head_output_size; - const int total = channels * num_frames; - const float* __restrict__ src = this->_layers[i].GetOutputHead().data(); - float* __restrict__ dst = this->_head_inputs.data(); - int j = 0; - for (; j + 3 < total; j += 4) - { - dst[j] += src[j]; - dst[j + 1] += src[j + 1]; - dst[j + 2] += src[j + 2]; - dst[j + 3] += src[j + 3]; - } - for (; j < total; j++) - dst[j] += src[j]; - } -#else this->_head_inputs.leftCols(num_frames).noalias() += this->_layers[i].GetOutputHead().leftCols(num_frames); -#endif } // Store output from last layer - use memcpy for pure copy diff --git a/NAM/wavenet.h b/NAM/wavenet.h index 22fc06b4..63e13781 100644 --- a/NAM/wavenet.h +++ b/NAM/wavenet.h @@ -385,9 +385,11 @@ class _Layer std::unique_ptr _layer1x1; // The post-activation 1x1 convolution outputting to the head, optional std::unique_ptr _head1x1; - + // The internal state Eigen::MatrixXf _z; + // Output to next layer (residual connection: input + layer1x1 output, or just input if layer1x1 inactive) Eigen::MatrixXf _output_next_layer; + // Output to head (skip connection: activated conv output) Eigen::MatrixXf _output_head; activations::Activation::Ptr _activation; @@ -604,12 +606,12 @@ class _LayerArray // The layer objects std::vector<_Layer> _layers; - + // Output from last layer (for next layer array) Eigen::MatrixXf _layer_outputs; - Eigen::MatrixXf _head_inputs; - // Accumulated head inputs from all layers // Size is _head_output_size (= head1x1.out_channels if head1x1 active, else bottleneck) + Eigen::MatrixXf _head_inputs; + // Rechannel for the head (_head_output_size -> head_size) Conv1x1 _head_rechannel; @@ -668,9 +670,9 @@ class WaveNet : public DSP void set_weights_(std::vector::iterator& weights); protected: + // Element-wise arrays: Eigen::MatrixXf _condition_input; Eigen::MatrixXf _condition_output; - std::unique_ptr _condition_dsp; // Temporary buffers for condition DSP processing (to avoid allocations in _process_condition) std::vector> _condition_dsp_input_buffers; From 75db6b9766acc924245019f4601cafa8b3741260 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Mon, 16 Feb 2026 11:05:38 -0800 Subject: [PATCH 5/6] Added option to disable fasttanh to benchmodel and benchmodel_bufsize --- tools/benchmodel.cpp | 120 +++++++++++++++++++---------------- tools/benchmodel_bufsize.cpp | 26 ++++++-- 2 files changed, 87 insertions(+), 59 deletions(-) diff --git a/tools/benchmodel.cpp b/tools/benchmodel.cpp index 39c14b0e..5e8c45d5 100644 --- a/tools/benchmodel.cpp +++ b/tools/benchmodel.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include "NAM/dsp.h" @@ -17,74 +18,85 @@ double outputBuffer[AUDIO_BUFFER_SIZE]; int main(int argc, char* argv[]) { - if (argc > 1) + if (argc < 2) { - const char* modelPath = argv[1]; - - std::cout << "Loading model " << modelPath << "\n"; + std::cerr << "Usage: benchmodel [--no-fast-tanh]\n"; + exit(1); + } - // Turn on fast tanh approximation - nam::activations::Activation::enable_fast_tanh(); + const char* modelPath = argv[1]; - std::unique_ptr model; + // Check for --no-fast-tanh flag + bool useFastTanh = true; + for (int i = 2; i < argc; i++) + { + if (std::strcmp(argv[i], "--no-fast-tanh") == 0) + useFastTanh = false; + } - model.reset(); - model = nam::get_dsp(std::filesystem::path(modelPath)); + if (useFastTanh) + { + nam::activations::Activation::enable_fast_tanh(); + std::cout << "Fast tanh: enabled\n"; + } + else + { + nam::activations::Activation::disable_fast_tanh(); + std::cout << "Fast tanh: disabled\n"; + } - if (model == nullptr) - { - std::cerr << "Failed to load model\n"; + std::cout << "Loading model " << modelPath << "\n"; - exit(1); - } + std::unique_ptr model; + model = nam::get_dsp(std::filesystem::path(modelPath)); - size_t bufferSize = AUDIO_BUFFER_SIZE; - model->Reset(model->GetExpectedSampleRate(), bufferSize); - size_t numBuffers = (48000 / bufferSize) * 2; + if (model == nullptr) + { + std::cerr << "Failed to load model\n"; + exit(1); + } - // Allocate multi-channel buffers - const int in_channels = model->NumInputChannels(); - const int out_channels = model->NumOutputChannels(); + size_t bufferSize = AUDIO_BUFFER_SIZE; + model->Reset(model->GetExpectedSampleRate(), bufferSize); + size_t numBuffers = (48000 / bufferSize) * 2; - std::vector> inputBuffers(in_channels); - std::vector> outputBuffers(out_channels); - std::vector inputPtrs(in_channels); - std::vector outputPtrs(out_channels); + // Allocate multi-channel buffers + const int in_channels = model->NumInputChannels(); + const int out_channels = model->NumOutputChannels(); - for (int ch = 0; ch < in_channels; ch++) - { - inputBuffers[ch].resize(AUDIO_BUFFER_SIZE, 0.0); - inputPtrs[ch] = inputBuffers[ch].data(); - } - for (int ch = 0; ch < out_channels; ch++) - { - outputBuffers[ch].resize(AUDIO_BUFFER_SIZE, 0.0); - outputPtrs[ch] = outputBuffers[ch].data(); - } + std::vector> inputBuffers(in_channels); + std::vector> outputBuffers(out_channels); + std::vector inputPtrs(in_channels); + std::vector outputPtrs(out_channels); - std::cout << "Running benchmark\n"; - auto t1 = high_resolution_clock::now(); - for (size_t i = 0; i < numBuffers; i++) - { - model->process(inputPtrs.data(), outputPtrs.data(), AUDIO_BUFFER_SIZE); - } - auto t2 = high_resolution_clock::now(); - std::cout << "Finished\n"; + for (int ch = 0; ch < in_channels; ch++) + { + inputBuffers[ch].resize(AUDIO_BUFFER_SIZE, 0.0); + inputPtrs[ch] = inputBuffers[ch].data(); + } + for (int ch = 0; ch < out_channels; ch++) + { + outputBuffers[ch].resize(AUDIO_BUFFER_SIZE, 0.0); + outputPtrs[ch] = outputBuffers[ch].data(); + } + std::cout << "Running benchmark\n"; + auto t1 = high_resolution_clock::now(); + for (size_t i = 0; i < numBuffers; i++) + { + model->process(inputPtrs.data(), outputPtrs.data(), AUDIO_BUFFER_SIZE); + } + auto t2 = high_resolution_clock::now(); + std::cout << "Finished\n"; - /* Getting number of milliseconds as an integer. */ - auto ms_int = duration_cast(t2 - t1); + /* Getting number of milliseconds as an integer. */ + auto ms_int = duration_cast(t2 - t1); - /* Getting number of milliseconds as a double. */ - duration ms_double = t2 - t1; + /* Getting number of milliseconds as a double. */ + duration ms_double = t2 - t1; - std::cout << ms_int.count() << "ms\n"; - std::cout << ms_double.count() << "ms\n"; - } - else - { - std::cerr << "Usage: benchmodel \n"; - } + std::cout << ms_int.count() << "ms\n"; + std::cout << ms_double.count() << "ms\n"; - exit(0); + return 0; } diff --git a/tools/benchmodel_bufsize.cpp b/tools/benchmodel_bufsize.cpp index ee27a474..355b6056 100644 --- a/tools/benchmodel_bufsize.cpp +++ b/tools/benchmodel_bufsize.cpp @@ -1,7 +1,8 @@ #include #include -#include #include +#include +#include #include "NAM/dsp.h" #include "NAM/get_dsp.h" @@ -11,17 +12,30 @@ using std::chrono::duration_cast; using std::chrono::high_resolution_clock; using std::chrono::microseconds; +/* A version of benchmodel that accepts an arbitrary buffer size between 1 and 4096. + * Useful for testing the effect of smaller/larger buffers on performance. + */ + int main(int argc, char* argv[]) { if (argc < 3) { - std::cerr << "Usage: benchmodel_bufsize [num_iterations]\n"; + std::cerr << "Usage: benchmodel_bufsize [num_iterations] [--no-fast-tanh]\n"; exit(1); } const char* modelPath = argv[1]; const int bufferSize = std::atoi(argv[2]); - const int numIterations = (argc > 3) ? std::atoi(argv[3]) : 5; + int numIterations = 5; + bool useFastTanh = true; + + for (int i = 3; i < argc; i++) + { + if (std::strcmp(argv[i], "--no-fast-tanh") == 0) + useFastTanh = false; + else + numIterations = std::atoi(argv[i]); + } if (bufferSize <= 0 || bufferSize > 4096) { @@ -29,8 +43,10 @@ int main(int argc, char* argv[]) exit(1); } - // Turn on fast tanh approximation - nam::activations::Activation::enable_fast_tanh(); + if (useFastTanh) + nam::activations::Activation::enable_fast_tanh(); + else + nam::activations::Activation::disable_fast_tanh(); std::unique_ptr model; model = nam::get_dsp(std::filesystem::path(modelPath)); From f3234921b2e94d225c9d8b9bac2cdb073ca25d19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Felipe=20Santos?= Date: Tue, 17 Feb 2026 18:24:24 -0800 Subject: [PATCH 6/6] Fixed bugs and added tests to ensure operations work on non-contiguous matrices --- NAM/activations.h | 11 +- NAM/dsp.cpp | 80 ++-- NAM/film.h | 8 +- NAM/gating_activations.h | 14 +- tools/run_tests.cpp | 15 + tools/test/test_noncontiguous_blocks.cpp | 529 +++++++++++++++++++++++ 6 files changed, 610 insertions(+), 47 deletions(-) create mode 100644 tools/test/test_noncontiguous_blocks.cpp diff --git a/NAM/activations.h b/NAM/activations.h index ae638878..70d8893d 100644 --- a/NAM/activations.h +++ b/NAM/activations.h @@ -139,9 +139,18 @@ class Activation Activation() = default; virtual ~Activation() = default; virtual void apply(Eigen::MatrixXf& matrix) { apply(matrix.data(), matrix.rows() * matrix.cols()); } - virtual void apply(Eigen::Block block) { apply(block.data(), block.rows() * block.cols()); } + virtual void apply(Eigen::Block block) + { + // Block must be contiguous in memory (outerStride == rows) for flat data() access. + // Non-contiguous blocks (e.g. topRows() of a wider matrix) would read/write wrong elements. + assert(block.outerStride() == block.rows()); + apply(block.data(), block.rows() * block.cols()); + } virtual void apply(Eigen::Block block) { + // Inner-panel blocks (e.g. leftCols()) are always contiguous for column-major matrices, + // but assert anyway for safety. + assert(block.outerStride() == block.rows()); apply(block.data(), block.rows() * block.cols()); } virtual void apply(float* data, long size) {} diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index e42f8e72..4ba90c6b 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -462,6 +462,9 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons const float* __restrict__ input_ptr = input.data(); const float* __restrict__ weight_ptr = this->_weight.data(); float* __restrict__ output_ptr = _output.data(); + // Use outerStride() instead of in_ch to correctly handle non-contiguous + // block expressions (e.g. topRows()) where outerStride > rows + const int in_stride = (int)input.outerStride(); // Specialized paths for common small sizes if (out_ch == 2 && in_ch == 1) @@ -469,7 +472,7 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons const float w0 = weight_ptr[0], w1 = weight_ptr[1]; for (int f = 0; f < num_frames; f++) { - const float in_val = input_ptr[f]; + const float in_val = input_ptr[f * in_stride]; output_ptr[f * 2] = w0 * in_val; output_ptr[f * 2 + 1] = w1 * in_val; } @@ -480,7 +483,7 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons const float w2 = weight_ptr[2], w3 = weight_ptr[3]; for (int f = 0; f < num_frames; f++) { - const float in_val = input_ptr[f]; + const float in_val = input_ptr[f * in_stride]; output_ptr[f * 4] = w0 * in_val; output_ptr[f * 4 + 1] = w1 * in_val; output_ptr[f * 4 + 2] = w2 * in_val; @@ -492,7 +495,8 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons const float w0 = weight_ptr[0], w1 = weight_ptr[1]; for (int f = 0; f < num_frames; f++) { - output_ptr[f] = w0 * input_ptr[f * 2] + w1 * input_ptr[f * 2 + 1]; + const float* __restrict__ in_col = input_ptr + f * in_stride; + output_ptr[f] = w0 * in_col[0] + w1 * in_col[1]; } } else if (out_ch == 2 && in_ch == 2) @@ -502,11 +506,11 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons const float w01 = weight_ptr[2], w11 = weight_ptr[3]; for (int f = 0; f < num_frames; f++) { - const int off = f * 2; - const float i0 = input_ptr[off]; - const float i1 = input_ptr[off + 1]; - output_ptr[off] = w00 * i0 + w01 * i1; - output_ptr[off + 1] = w10 * i0 + w11 * i1; + const float* __restrict__ in_col = input_ptr + f * in_stride; + const float i0 = in_col[0]; + const float i1 = in_col[1]; + output_ptr[f * 2] = w00 * i0 + w01 * i1; + output_ptr[f * 2 + 1] = w10 * i0 + w11 * i1; } } else if (out_ch == 2 && in_ch == 4) @@ -517,10 +521,11 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons const float w03 = weight_ptr[6], w13 = weight_ptr[7]; for (int f = 0; f < num_frames; f++) { - const float i0 = input_ptr[f * 4]; - const float i1 = input_ptr[f * 4 + 1]; - const float i2 = input_ptr[f * 4 + 2]; - const float i3 = input_ptr[f * 4 + 3]; + const float* __restrict__ in_col = input_ptr + f * in_stride; + const float i0 = in_col[0]; + const float i1 = in_col[1]; + const float i2 = in_col[2]; + const float i3 = in_col[3]; output_ptr[f * 2] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3; output_ptr[f * 2 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3; } @@ -531,9 +536,9 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons const float w2 = weight_ptr[2], w3 = weight_ptr[3]; for (int f = 0; f < num_frames; f++) { - const int off = f * 4; - output_ptr[f] = w0 * input_ptr[off] + w1 * input_ptr[off + 1] - + w2 * input_ptr[off + 2] + w3 * input_ptr[off + 3]; + const float* __restrict__ in_col = input_ptr + f * in_stride; + output_ptr[f] = w0 * in_col[0] + w1 * in_col[1] + + w2 * in_col[2] + w3 * in_col[3]; } } else if (out_ch == 4 && in_ch == 2) @@ -542,8 +547,9 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7]; for (int f = 0; f < num_frames; f++) { - const float i0 = input_ptr[f * 2]; - const float i1 = input_ptr[f * 2 + 1]; + const float* __restrict__ in_col = input_ptr + f * in_stride; + const float i0 = in_col[0]; + const float i1 = in_col[1]; output_ptr[f * 4] = w00 * i0 + w01 * i1; output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1; output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1; @@ -557,13 +563,13 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons const float w02 = weight_ptr[6], w12 = weight_ptr[7], w22 = weight_ptr[8]; for (int f = 0; f < num_frames; f++) { - const int off = f * 3; - const float i0 = input_ptr[off]; - const float i1 = input_ptr[off + 1]; - const float i2 = input_ptr[off + 2]; - output_ptr[off] = w00 * i0 + w01 * i1 + w02 * i2; - output_ptr[off + 1] = w10 * i0 + w11 * i1 + w12 * i2; - output_ptr[off + 2] = w20 * i0 + w21 * i1 + w22 * i2; + const float* __restrict__ in_col = input_ptr + f * in_stride; + const float i0 = in_col[0]; + const float i1 = in_col[1]; + const float i2 = in_col[2]; + output_ptr[f * 3] = w00 * i0 + w01 * i1 + w02 * i2; + output_ptr[f * 3 + 1] = w10 * i0 + w11 * i1 + w12 * i2; + output_ptr[f * 3 + 2] = w20 * i0 + w21 * i1 + w22 * i2; } } else if (out_ch == 4 && in_ch == 4) @@ -574,22 +580,22 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons const float w03 = weight_ptr[12], w13 = weight_ptr[13], w23 = weight_ptr[14], w33 = weight_ptr[15]; for (int f = 0; f < num_frames; f++) { - const int off = f * 4; - const float i0 = input_ptr[off]; - const float i1 = input_ptr[off + 1]; - const float i2 = input_ptr[off + 2]; - const float i3 = input_ptr[off + 3]; - output_ptr[off] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3; - output_ptr[off + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3; - output_ptr[off + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3; - output_ptr[off + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3; + const float* __restrict__ in_col = input_ptr + f * in_stride; + const float i0 = in_col[0]; + const float i1 = in_col[1]; + const float i2 = in_col[2]; + const float i3 = in_col[3]; + output_ptr[f * 4] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3; + output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3; + output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3; + output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3; } } else if (out_ch == 6 && in_ch == 6) { for (int f = 0; f < num_frames; f++) { - const float* __restrict__ in_col = input_ptr + f * 6; + const float* __restrict__ in_col = input_ptr + f * in_stride; float* __restrict__ out_col = output_ptr + f * 6; const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2]; const float i3 = in_col[3], i4 = in_col[4], i5 = in_col[5]; @@ -604,7 +610,7 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons { for (int f = 0; f < num_frames; f++) { - const float* __restrict__ in_col = input_ptr + f * 8; + const float* __restrict__ in_col = input_ptr + f * in_stride; float* __restrict__ out_col = output_ptr + f * 8; const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2], i3 = in_col[3]; const float i4 = in_col[4], i5 = in_col[5], i6 = in_col[6], i7 = in_col[7]; @@ -619,7 +625,7 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons { for (int f = 0; f < num_frames; f++) { - const float* __restrict__ in_col = input_ptr + f * 8; + const float* __restrict__ in_col = input_ptr + f * in_stride; float* __restrict__ out_col = output_ptr + f * 4; const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2], i3 = in_col[3]; const float i4 = in_col[4], i5 = in_col[5], i6 = in_col[6], i7 = in_col[7]; @@ -634,7 +640,7 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons { for (int f = 0; f < num_frames; f++) { - const float* __restrict__ in_col = input_ptr + f * 4; + const float* __restrict__ in_col = input_ptr + f * in_stride; float* __restrict__ out_col = output_ptr + f * 8; const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2], i3 = in_col[3]; for (int o = 0; o < 8; o++) diff --git a/NAM/film.h b/NAM/film.h index d1cc646d..559de131 100644 --- a/NAM/film.h +++ b/NAM/film.h @@ -91,14 +91,16 @@ class FiLM const float* __restrict__ scale_shift_ptr = scale_shift.data(); float* __restrict__ output_ptr = _output.data(); const int scale_shift_rows = (int)scale_shift.rows(); - const int input_rows = (int)input.rows(); + // Use outerStride() instead of rows() to correctly handle non-contiguous + // block expressions (e.g. topRows()) where outerStride > rows + const int input_stride = (int)input.outerStride(); if (_do_shift) { // scale = top input_dim rows, shift = bottom input_dim rows for (int f = 0; f < num_frames; f++) { - const float* __restrict__ in_col = input_ptr + f * input_rows; + const float* __restrict__ in_col = input_ptr + f * input_stride; const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows; const float* __restrict__ shift_col = scale_col + input_dim; float* __restrict__ out_col = output_ptr + f * input_dim; @@ -122,7 +124,7 @@ class FiLM // scale only for (int f = 0; f < num_frames; f++) { - const float* __restrict__ in_col = input_ptr + f * input_rows; + const float* __restrict__ in_col = input_ptr + f * input_stride; const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows; float* __restrict__ out_col = output_ptr + f * input_dim; diff --git a/NAM/gating_activations.h b/NAM/gating_activations.h index c0f4ace7..0d52298c 100644 --- a/NAM/gating_activations.h +++ b/NAM/gating_activations.h @@ -67,15 +67,16 @@ class GatingActivation #ifdef NAM_USE_INLINE_GEMM // Optimized path: direct memory access with activation applied per-element - // Note: output may be a block expression with outer stride != num_channels - const int input_rows = input.rows(); // 2 * num_channels + // Use outerStride() instead of rows() to correctly handle non-contiguous + // block expressions (e.g. topRows()) where outerStride > rows + const int input_stride = (int)input.outerStride(); const float* __restrict__ input_ptr = input.derived().data(); float* __restrict__ output_ptr = output.derived().data(); const int output_stride = (int)output.outerStride(); // Column stride for output for (int f = 0; f < num_samples; f++) { - const float* __restrict__ in_col = input_ptr + f * input_rows; + const float* __restrict__ in_col = input_ptr + f * input_stride; float* __restrict__ out_col = output_ptr + f * output_stride; // Copy input and gating channels to buffers, apply activations, multiply @@ -172,15 +173,16 @@ class BlendingActivation #ifdef NAM_USE_INLINE_GEMM // Optimized path: direct memory access - // Note: output may be a block expression with outer stride != num_channels - const int input_rows = input.rows(); // 2 * num_channels + // Use outerStride() instead of rows() to correctly handle non-contiguous + // block expressions (e.g. topRows()) where outerStride > rows + const int input_stride = (int)input.outerStride(); const float* __restrict__ input_ptr = input.derived().data(); float* __restrict__ output_ptr = output.derived().data(); const int output_stride = (int)output.outerStride(); // Column stride for output for (int f = 0; f < num_samples; f++) { - const float* __restrict__ in_col = input_ptr + f * input_rows; + const float* __restrict__ in_col = input_ptr + f * input_stride; float* __restrict__ out_col = output_ptr + f * output_stride; // Copy channels to buffers diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 9b3bdec7..30d9c5bb 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -25,6 +25,7 @@ #include "test/test_input_buffer_verification.cpp" #include "test/test_lstm.cpp" #include "test/test_wavenet_configurable_gating.cpp" +#include "test/test_noncontiguous_blocks.cpp" int main() { @@ -235,6 +236,20 @@ int main() // Configurable gating/blending tests run_configurable_gating_tests(); + // Non-contiguous block correctness tests (outerStride != rows) + test_noncontiguous_blocks::test_conv1x1_process_toprows(); + test_noncontiguous_blocks::test_conv1x1_process_toprows_with_bias(); + test_noncontiguous_blocks::test_conv1x1_process_toprows_2x2(); + test_noncontiguous_blocks::test_conv1x1_process_toprows_4x4(); + test_noncontiguous_blocks::test_conv1x1_toprows_matches_contiguous(); + test_noncontiguous_blocks::test_film_process_toprows_with_shift(); + test_noncontiguous_blocks::test_film_process_toprows_scale_only(); + test_noncontiguous_blocks::test_film_toprows_matches_contiguous(); + test_noncontiguous_blocks::test_film_process_inplace_toprows(); + test_noncontiguous_blocks::test_gating_output_toprows(); + test_noncontiguous_blocks::test_gating_toprows_matches_contiguous(); + test_noncontiguous_blocks::test_blending_output_toprows(); + test_get_dsp::test_gets_input_level(); test_get_dsp::test_gets_output_level(); test_get_dsp::test_null_input_level(); diff --git a/tools/test/test_noncontiguous_blocks.cpp b/tools/test/test_noncontiguous_blocks.cpp new file mode 100644 index 00000000..7044fd10 --- /dev/null +++ b/tools/test/test_noncontiguous_blocks.cpp @@ -0,0 +1,529 @@ +// Tests for correct handling of non-contiguous Eigen block expressions +// (e.g. topRows()) in inline GEMM paths. +// +// When a matrix expression like matrix.topRows(n) is passed to a function +// that accesses raw pointers via .data(), the outerStride() (distance between +// columns) may be larger than rows(). Code that assumes stride == rows will +// read/write wrong memory locations. + +#include +#include +#include +#include +#include + +#include "NAM/activations.h" +#include "NAM/dsp.h" +#include "NAM/film.h" +#include "NAM/gating_activations.h" + +namespace test_noncontiguous_blocks +{ + +// Helper: create identity Conv1x1 weights (identity matrix + zero bias) +static std::vector make_identity_weights(int channels, bool bias) +{ + std::vector weights(channels * channels + (bias ? channels : 0), 0.0f); + // Column-major identity matrix + for (int i = 0; i < channels; i++) + weights[i * channels + i] = 1.0f; + return weights; +} + +// Helper: create scaled Conv1x1 weights (diagonal scale matrix + zero bias) +static std::vector make_scale_weights(int in_ch, int out_ch, float scale, bool bias) +{ + std::vector weights(out_ch * in_ch + (bias ? out_ch : 0), 0.0f); + // Set all weight elements to scale (for simple testing) + for (int o = 0; o < out_ch; o++) + for (int i = 0; i < in_ch; i++) + weights[i * out_ch + o] = (i == o) ? scale : 0.0f; + return weights; +} + +// ============================================================ +// Conv1x1::process_() with non-contiguous input (topRows) +// ============================================================ + +void test_conv1x1_process_toprows() +{ + // Create a matrix with 8 rows, but only pass topRows(4) to Conv1x1 + // This simulates the gated activation path in wavenet.cpp + const int total_rows = 8; + const int bottleneck = 4; + const int num_frames = 3; + + // Create Conv1x1: 4 in -> 4 out, identity weights + nam::Conv1x1 conv(bottleneck, bottleneck, /*bias=*/false); + auto weights = make_identity_weights(bottleneck, false); + auto it = weights.begin(); + conv.set_weights_(it); + conv.SetMaxBufferSize(64); + + // Create full matrix with known values + Eigen::MatrixXf full_matrix(total_rows, num_frames); + for (int c = 0; c < total_rows; c++) + for (int f = 0; f < num_frames; f++) + full_matrix(c, f) = (float)(c * 10 + f); + + // Process only topRows(bottleneck) - this has outerStride = 8, rows = 4 + auto top_block = full_matrix.topRows(bottleneck); + assert(top_block.outerStride() == total_rows); // Verify non-contiguous + assert(top_block.rows() == bottleneck); + + conv.process_(top_block, num_frames); + const auto& output = conv.GetOutput(); + + // With identity weights, output should equal the top rows exactly + for (int c = 0; c < bottleneck; c++) + { + for (int f = 0; f < num_frames; f++) + { + const float expected = full_matrix(c, f); + const float actual = output(c, f); + assert(std::abs(actual - expected) < 1e-6f); + } + } +} + +void test_conv1x1_process_toprows_with_bias() +{ + const int total_rows = 6; + const int bottleneck = 3; + const int num_frames = 4; + + nam::Conv1x1 conv(bottleneck, bottleneck, /*bias=*/true); + auto weights = make_identity_weights(bottleneck, false); + // Add bias values + weights.push_back(10.0f); + weights.push_back(20.0f); + weights.push_back(30.0f); + auto it = weights.begin(); + conv.set_weights_(it); + conv.SetMaxBufferSize(64); + + Eigen::MatrixXf full_matrix(total_rows, num_frames); + for (int c = 0; c < total_rows; c++) + for (int f = 0; f < num_frames; f++) + full_matrix(c, f) = (float)(c + 1); + + conv.process_(full_matrix.topRows(bottleneck), num_frames); + const auto& output = conv.GetOutput(); + + const float biases[] = {10.0f, 20.0f, 30.0f}; + for (int c = 0; c < bottleneck; c++) + { + for (int f = 0; f < num_frames; f++) + { + const float expected = full_matrix(c, f) + biases[c]; + assert(std::abs(output(c, f) - expected) < 1e-6f); + } + } +} + +void test_conv1x1_process_toprows_2x2() +{ + // Test specific 2x2 specialized path with non-contiguous input + const int total_rows = 4; // doubled for gating + const int bottleneck = 2; + const int num_frames = 3; + + nam::Conv1x1 conv(bottleneck, bottleneck, /*bias=*/false); + // Weight: [[2, 0], [0, 3]] (column-major: [2, 0, 0, 3]) + std::vector weights = {2.0f, 0.0f, 0.0f, 3.0f}; + auto it = weights.begin(); + conv.set_weights_(it); + conv.SetMaxBufferSize(64); + + Eigen::MatrixXf full_matrix(total_rows, num_frames); + full_matrix << 1.0f, 2.0f, 3.0f, // row 0 (top, used) + 4.0f, 5.0f, 6.0f, // row 1 (top, used) + 99.0f, 99.0f, 99.0f, // row 2 (bottom, NOT used) + 99.0f, 99.0f, 99.0f; // row 3 (bottom, NOT used) + + conv.process_(full_matrix.topRows(bottleneck), num_frames); + const auto& output = conv.GetOutput(); + + // output = [[2,0],[0,3]] * topRows(2) + // Frame 0: [2*1, 3*4] = [2, 12] + // Frame 1: [2*2, 3*5] = [4, 15] + // Frame 2: [2*3, 3*6] = [6, 18] + assert(std::abs(output(0, 0) - 2.0f) < 1e-6f); + assert(std::abs(output(1, 0) - 12.0f) < 1e-6f); + assert(std::abs(output(0, 1) - 4.0f) < 1e-6f); + assert(std::abs(output(1, 1) - 15.0f) < 1e-6f); + assert(std::abs(output(0, 2) - 6.0f) < 1e-6f); + assert(std::abs(output(1, 2) - 18.0f) < 1e-6f); +} + +void test_conv1x1_process_toprows_4x4() +{ + // Test specific 4x4 specialized path with non-contiguous input + const int total_rows = 8; + const int bottleneck = 4; + const int num_frames = 2; + + nam::Conv1x1 conv(bottleneck, bottleneck, /*bias=*/false); + auto weights = make_identity_weights(bottleneck, false); + auto it = weights.begin(); + conv.set_weights_(it); + conv.SetMaxBufferSize(64); + + Eigen::MatrixXf full_matrix(total_rows, num_frames); + // Top 4 rows: known values; bottom 4 rows: poison values + for (int c = 0; c < bottleneck; c++) + for (int f = 0; f < num_frames; f++) + full_matrix(c, f) = (float)(c * 10 + f + 1); + for (int c = bottleneck; c < total_rows; c++) + for (int f = 0; f < num_frames; f++) + full_matrix(c, f) = -999.0f; // poison + + conv.process_(full_matrix.topRows(bottleneck), num_frames); + const auto& output = conv.GetOutput(); + + for (int c = 0; c < bottleneck; c++) + for (int f = 0; f < num_frames; f++) + assert(std::abs(output(c, f) - full_matrix(c, f)) < 1e-6f); +} + +// Verify Conv1x1 with topRows matches result from contiguous copy +void test_conv1x1_toprows_matches_contiguous() +{ + const int total_rows = 8; + const int bottleneck = 4; + const int num_frames = 5; + + nam::Conv1x1 conv(bottleneck, 2, /*bias=*/true); + // Random-ish weights for 4->2 with bias + std::vector weights = { + 1.0f, 0.5f, -1.0f, 0.5f, 2.0f, -0.5f, 0.0f, 1.5f, // 2x4 weights (column-major) + 3.0f, -2.0f // biases + }; + auto it = weights.begin(); + conv.set_weights_(it); + conv.SetMaxBufferSize(64); + + Eigen::MatrixXf full_matrix(total_rows, num_frames); + full_matrix.setRandom(); + + // Reference: copy topRows to contiguous matrix, process + Eigen::MatrixXf contiguous_input = full_matrix.topRows(bottleneck).eval(); + conv.process_(contiguous_input, num_frames); + Eigen::MatrixXf expected = conv.GetOutput().leftCols(num_frames).eval(); + + // Test: process non-contiguous topRows directly + conv.process_(full_matrix.topRows(bottleneck), num_frames); + const auto& actual = conv.GetOutput(); + + for (int c = 0; c < 2; c++) + for (int f = 0; f < num_frames; f++) + assert(std::abs(actual(c, f) - expected(c, f)) < 1e-5f); +} + +// ============================================================ +// FiLM::Process() with non-contiguous input (topRows) +// ============================================================ + +void test_film_process_toprows_with_shift() +{ + const int condition_dim = 2; + const int input_dim = 3; + const int total_rows = 6; // 2x input_dim, simulating gated _z matrix + + nam::FiLM film(condition_dim, input_dim, /*shift=*/true); + film.SetMaxBufferSize(64); + + // Configure Conv1x1 with zero weights, fixed biases for scale/shift + std::vector weights((2 * input_dim) * condition_dim + (2 * input_dim), 0.0f); + const int bias_offset = (2 * input_dim) * condition_dim; + weights[bias_offset + 0] = 2.0f; // scale[0] + weights[bias_offset + 1] = -1.0f; // scale[1] + weights[bias_offset + 2] = 0.5f; // scale[2] + weights[bias_offset + 3] = 10.0f; // shift[0] + weights[bias_offset + 4] = -5.0f; // shift[1] + weights[bias_offset + 5] = 3.0f; // shift[2] + auto it = weights.begin(); + film.set_weights_(it); + + const int num_frames = 4; + + // Create a wider matrix and pass topRows as input + Eigen::MatrixXf full_matrix(total_rows, num_frames); + for (int c = 0; c < total_rows; c++) + for (int f = 0; f < num_frames; f++) + full_matrix(c, f) = (float)(c + 1) * (f + 1); + + Eigen::MatrixXf condition(condition_dim, num_frames); + condition.setRandom(); + + // Process with non-contiguous topRows(input_dim) + auto top_block = full_matrix.topRows(input_dim); + assert(top_block.outerStride() == total_rows); // Verify non-contiguous + + film.Process(top_block, condition, num_frames); + const auto& output = film.GetOutput(); + + const float scales[] = {2.0f, -1.0f, 0.5f}; + const float shifts[] = {10.0f, -5.0f, 3.0f}; + for (int c = 0; c < input_dim; c++) + { + for (int f = 0; f < num_frames; f++) + { + const float expected = full_matrix(c, f) * scales[c] + shifts[c]; + assert(std::abs(output(c, f) - expected) < 1e-5f); + } + } +} + +void test_film_process_toprows_scale_only() +{ + const int condition_dim = 2; + const int input_dim = 4; + const int total_rows = 8; + + nam::FiLM film(condition_dim, input_dim, /*shift=*/false); + film.SetMaxBufferSize(64); + + std::vector weights(input_dim * condition_dim + input_dim, 0.0f); + const int bias_offset = input_dim * condition_dim; + weights[bias_offset + 0] = 2.0f; + weights[bias_offset + 1] = 3.0f; + weights[bias_offset + 2] = -1.0f; + weights[bias_offset + 3] = 0.5f; + auto it = weights.begin(); + film.set_weights_(it); + + const int num_frames = 3; + Eigen::MatrixXf full_matrix(total_rows, num_frames); + full_matrix.setRandom(); + + Eigen::MatrixXf condition(condition_dim, num_frames); + condition.setRandom(); + + film.Process(full_matrix.topRows(input_dim), condition, num_frames); + const auto& output = film.GetOutput(); + + const float scales[] = {2.0f, 3.0f, -1.0f, 0.5f}; + for (int c = 0; c < input_dim; c++) + for (int f = 0; f < num_frames; f++) + { + const float expected = full_matrix(c, f) * scales[c]; + assert(std::abs(output(c, f) - expected) < 1e-5f); + } +} + +// Verify FiLM with topRows matches result from contiguous copy +void test_film_toprows_matches_contiguous() +{ + const int condition_dim = 2; + const int input_dim = 3; + const int total_rows = 6; + const int num_frames = 4; + + nam::FiLM film(condition_dim, input_dim, /*shift=*/true); + film.SetMaxBufferSize(64); + + std::vector weights((2 * input_dim) * condition_dim + (2 * input_dim), 0.0f); + const int bias_offset = (2 * input_dim) * condition_dim; + weights[bias_offset + 0] = 2.0f; + weights[bias_offset + 1] = -1.0f; + weights[bias_offset + 2] = 0.5f; + weights[bias_offset + 3] = 10.0f; + weights[bias_offset + 4] = -5.0f; + weights[bias_offset + 5] = 3.0f; + auto it = weights.begin(); + film.set_weights_(it); + + Eigen::MatrixXf full_matrix(total_rows, num_frames); + full_matrix.setRandom(); + Eigen::MatrixXf condition(condition_dim, num_frames); + condition.setRandom(); + + // Reference: contiguous copy + Eigen::MatrixXf contiguous = full_matrix.topRows(input_dim).eval(); + film.Process(contiguous, condition, num_frames); + Eigen::MatrixXf expected = film.GetOutput().leftCols(num_frames).eval(); + + // Test: non-contiguous topRows + film.Process(full_matrix.topRows(input_dim), condition, num_frames); + const auto& actual = film.GetOutput(); + + for (int c = 0; c < input_dim; c++) + for (int f = 0; f < num_frames; f++) + assert(std::abs(actual(c, f) - expected(c, f)) < 1e-6f); +} + +// ============================================================ +// FiLM::Process_() (in-place) with non-contiguous input (topRows) +// ============================================================ + +void test_film_process_inplace_toprows() +{ + const int condition_dim = 2; + const int input_dim = 3; + const int total_rows = 6; + const int num_frames = 4; + + nam::FiLM film(condition_dim, input_dim, /*shift=*/true); + film.SetMaxBufferSize(64); + + std::vector weights((2 * input_dim) * condition_dim + (2 * input_dim), 0.0f); + const int bias_offset = (2 * input_dim) * condition_dim; + weights[bias_offset + 0] = 2.0f; + weights[bias_offset + 1] = -1.0f; + weights[bias_offset + 2] = 0.5f; + weights[bias_offset + 3] = 10.0f; + weights[bias_offset + 4] = -5.0f; + weights[bias_offset + 5] = 3.0f; + auto it = weights.begin(); + film.set_weights_(it); + + Eigen::MatrixXf full_matrix(total_rows, num_frames); + for (int c = 0; c < total_rows; c++) + for (int f = 0; f < num_frames; f++) + full_matrix(c, f) = (float)(c + 1) * (f + 1); + Eigen::MatrixXf original = full_matrix; + + Eigen::MatrixXf condition(condition_dim, num_frames); + condition.setRandom(); + + // In-place process on topRows block + film.Process_(full_matrix.topRows(input_dim), condition, num_frames); + + const float scales[] = {2.0f, -1.0f, 0.5f}; + const float shifts[] = {10.0f, -5.0f, 3.0f}; + for (int c = 0; c < input_dim; c++) + for (int f = 0; f < num_frames; f++) + { + const float expected = original(c, f) * scales[c] + shifts[c]; + assert(std::abs(full_matrix(c, f) - expected) < 1e-5f); + } + + // Bottom rows should be unchanged + for (int c = input_dim; c < total_rows; c++) + for (int f = 0; f < num_frames; f++) + assert(std::abs(full_matrix(c, f) - original(c, f)) < 1e-6f); +} + +// ============================================================ +// GatingActivation with non-contiguous output (topRows) +// ============================================================ + +// Helper to create a non-owning shared_ptr for stack-allocated activations in tests +template +static nam::activations::Activation::Ptr make_test_ptr(T& activation) +{ + return nam::activations::Activation::Ptr(&activation, [](nam::activations::Activation*) {}); +} + +void test_gating_output_toprows() +{ + // Simulate wavenet.cpp pattern: + // input_block = _z.leftCols(num_frames) (contiguous, 2*bottleneck rows) + // output_block = _z.topRows(bottleneck).leftCols(num_frames) (non-contiguous output) + const int bottleneck = 3; + const int total_rows = 2 * bottleneck; + const int num_frames = 4; + + nam::activations::ActivationIdentity identity_act; + nam::activations::ActivationSigmoid sigmoid_act; + nam::gating_activations::GatingActivation gating(make_test_ptr(identity_act), make_test_ptr(sigmoid_act), bottleneck); + + // Input: contiguous (2*bottleneck x num_frames) + Eigen::MatrixXf input(total_rows, num_frames); + for (int c = 0; c < total_rows; c++) + for (int f = 0; f < num_frames; f++) + input(c, f) = (float)(c + 1) * 0.5f; + + // Output goes into topRows of a larger matrix + Eigen::MatrixXf output_matrix(total_rows, num_frames); + output_matrix.setConstant(-999.0f); // poison + auto output_block = output_matrix.topRows(bottleneck).leftCols(num_frames); + + gating.apply(input, output_block); + + // Verify output values are correct (not reading poison from wrong stride) + for (int f = 0; f < num_frames; f++) + { + for (int c = 0; c < bottleneck; c++) + { + const float input_val = input(c, f); // identity activation + const float gate_val = 1.0f / (1.0f + expf(-input(c + bottleneck, f))); // sigmoid + const float expected = input_val * gate_val; + assert(std::abs(output_matrix(c, f) - expected) < 1e-5f); + } + } + + // Bottom rows should still be poison (untouched) + for (int c = bottleneck; c < total_rows; c++) + for (int f = 0; f < num_frames; f++) + assert(std::abs(output_matrix(c, f) - (-999.0f)) < 1e-6f); +} + +void test_gating_toprows_matches_contiguous() +{ + const int bottleneck = 2; + const int total_rows = 2 * bottleneck; + const int num_frames = 3; + + nam::activations::ActivationReLU relu_act; + nam::activations::ActivationSigmoid sigmoid_act; + nam::gating_activations::GatingActivation gating(make_test_ptr(relu_act), make_test_ptr(sigmoid_act), bottleneck); + + Eigen::MatrixXf input(total_rows, num_frames); + input.setRandom(); + + // Reference: contiguous output + Eigen::MatrixXf contiguous_output(bottleneck, num_frames); + gating.apply(input, contiguous_output); + + // Test: non-contiguous output (topRows of larger matrix) + Eigen::MatrixXf full_output(total_rows, num_frames); + full_output.setZero(); + auto output_block = full_output.topRows(bottleneck).leftCols(num_frames); + gating.apply(input, output_block); + + for (int c = 0; c < bottleneck; c++) + for (int f = 0; f < num_frames; f++) + assert(std::abs(full_output(c, f) - contiguous_output(c, f)) < 1e-6f); +} + +// ============================================================ +// BlendingActivation with non-contiguous output (topRows) +// ============================================================ + +void test_blending_output_toprows() +{ + const int bottleneck = 2; + const int total_rows = 2 * bottleneck; + const int num_frames = 3; + + nam::activations::ActivationIdentity identity_act; + nam::activations::ActivationSigmoid sigmoid_act; + nam::gating_activations::BlendingActivation blending( + make_test_ptr(identity_act), make_test_ptr(sigmoid_act), bottleneck); + + Eigen::MatrixXf input(total_rows, num_frames); + input.setRandom(); + + // Reference: contiguous output + Eigen::MatrixXf contiguous_output(bottleneck, num_frames); + blending.apply(input, contiguous_output); + + // Test: non-contiguous output (topRows of larger matrix) + Eigen::MatrixXf full_output(total_rows, num_frames); + full_output.setConstant(-999.0f); + auto output_block = full_output.topRows(bottleneck).leftCols(num_frames); + blending.apply(input, output_block); + + for (int c = 0; c < bottleneck; c++) + for (int f = 0; f < num_frames; f++) + assert(std::abs(full_output(c, f) - contiguous_output(c, f)) < 1e-5f); + + // Bottom rows should be untouched + for (int c = bottleneck; c < total_rows; c++) + for (int f = 0; f < num_frames; f++) + assert(std::abs(full_output(c, f) - (-999.0f)) < 1e-6f); +} + +} // namespace test_noncontiguous_blocks