diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 62f2a7e..22874fe 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -41,3 +41,26 @@ jobs:
         ./build/tools/run_tests
         ./build/tools/benchmodel ./example_models/wavenet.nam
         ./build/tools/benchmodel ./example_models/lstm.nam
+
+  build-ubuntu-inline-gemm:
+    name: Build Ubuntu (Inline GEMM)
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3.3.0
+        with:
+          submodules: recursive
+
+      - name: Build Tools
+        working-directory: ${{github.workspace}}/build
+        env:
+          CXX: clang++
+        run: |
+          cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS="-DNAM_USE_INLINE_GEMM"
+          cmake --build . -j4
+
+      - name: Run tests
+        working-directory: ${{github.workspace}}
+        run: |
+          ./build/tools/run_tests
+          ./build/tools/benchmodel ./example_models/wavenet.nam
+          ./build/tools/benchmodel ./example_models/lstm.nam
diff --git a/NAM/activations.cpp b/NAM/activations.cpp
index a476520..3e0bc94 100644
--- a/NAM/activations.cpp
+++ b/NAM/activations.cpp
@@ -35,8 +35,10 @@ std::unordered_map<std::string, nam::activations::Activation::Ptr> nam::activations::Activation::_activations = {
   {"PReLU", make_singleton_ptr(_PRELU)},
   {"Softsign", make_singleton_ptr(_SOFTSIGN)}};
 
+// Variables to hold previous instances of activations when we replace them with fast or LUT implementations
 nam::activations::Activation::Ptr tanh_bak = nullptr;
 nam::activations::Activation::Ptr sigmoid_bak = nullptr;
+nam::activations::Activation::Ptr silu_bak = nullptr;
 
 nam::activations::Activation::Ptr nam::activations::Activation::get_activation(const std::string name)
 {
@@ -197,9 +199,14 @@ void nam::activations::Activation::enable_lut(std::string function_name, float min, float max, int n_points)
     fn = sigmoid;
     sigmoid_bak = _activations["Sigmoid"];
   }
+  else if (function_name == "SiLU")
+  {
+    fn = swish;
+    silu_bak = _activations["SiLU"];
+  }
   else
   {
-    throw std::runtime_error("Tried to enable LUT for a function other than Tanh or Sigmoid");
+    throw std::runtime_error("Tried to enable LUT for a function other than Tanh, Sigmoid, or SiLU");
   }
   _activations[function_name] = std::make_shared<FastLUTActivation>(min, max, n_points, fn);
 }
@@ -214,8 +221,12 @@ void nam::activations::Activation::disable_lut(std::string function_name)
   {
     _activations["Sigmoid"] = sigmoid_bak;
   }
+  else if (function_name == "SiLU")
+  {
+    _activations["SiLU"] = silu_bak;
+  }
   else
   {
-    throw std::runtime_error("Tried to disable LUT for a function other than Tanh or Sigmoid");
+    throw std::runtime_error("Tried to disable LUT for a function other than Tanh, Sigmoid, or SiLU");
   }
 }
diff --git a/NAM/activations.h b/NAM/activations.h
index a05c456..1ba789c 100644
--- a/NAM/activations.h
+++ b/NAM/activations.h
@@ -119,18 +119,12 @@ inline float swish(float x)
 
 inline float hardswish(float x)
 {
-  if (x <= -3.0)
-  {
-    return 0;
-  }
-  else if (x >= 3.0)
-  {
-    return x;
-  }
-  else
-  {
-    return x * (x + 3.0) / 6.0;
-  }
+  // Branchless implementation using clamp
+  // hardswish(x) = x * relu6(x + 3) / 6
+  //             = x * clamp(x + 3, 0, 6) / 6
+  const float t = x + 3.0f;
+  const float clamped = t < 0.0f ? 0.0f : (t > 6.0f ? 6.0f : t);
+  return x * clamped * (1.0f / 6.0f);
 }
 
 inline float softsign(float x)
@@ -147,9 +141,18 @@ class Activation
   Activation() = default;
   virtual ~Activation() = default;
   virtual void apply(Eigen::MatrixXf& matrix) { apply(matrix.data(), matrix.rows() * matrix.cols()); }
-  virtual void apply(Eigen::Block<Eigen::MatrixXf> block) { apply(block.data(), block.rows() * block.cols()); }
+  virtual void apply(Eigen::Block<Eigen::MatrixXf> block)
+  {
+    // Block must be contiguous in memory (outerStride == rows) for flat data() access.
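+    // Illustrative example (column-major Eigen::MatrixXf m(8, 4)):
+    //   m.leftCols(2): outerStride() == 8 == rows() -> contiguous, flat access is safe
+    //   m.topRows(4):  outerStride() == 8 >  rows() -> not contiguous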
+    // Non-contiguous blocks (e.g. topRows() of a wider matrix) would read/write wrong elements.
+    assert(block.outerStride() == block.rows());
+    apply(block.data(), block.rows() * block.cols());
+  }
   virtual void apply(Eigen::Block<Eigen::MatrixXf, -1, -1, true> block)
   {
+    // Inner-panel blocks (e.g. leftCols()) are always contiguous for column-major matrices,
+    // but assert anyway for safety.
+    assert(block.outerStride() == block.rows());
     apply(block.data(), block.rows() * block.cols());
   }
   virtual void apply(float* data, long size) = 0;
@@ -244,9 +247,7 @@ class ActivationReLU : public Activation
   void apply(float* data, long size) override
   {
     for (long pos = 0; pos < size; pos++)
-    {
       data[pos] = relu(data[pos]);
-    }
   }
 };
 
@@ -336,9 +337,7 @@ class ActivationSigmoid : public Activation
   void apply(float* data, long size) override
   {
     for (long pos = 0; pos < size; pos++)
-    {
       data[pos] = sigmoid(data[pos]);
-    }
   }
 };
 
@@ -348,9 +347,7 @@ class ActivationSwish : public Activation
   void apply(float* data, long size) override
   {
     for (long pos = 0; pos < size; pos++)
-    {
       data[pos] = swish(data[pos]);
-    }
   }
 };
 
@@ -360,9 +357,7 @@ class ActivationHardSwish : public Activation
   void apply(float* data, long size) override
   {
     for (long pos = 0; pos < size; pos++)
-    {
       data[pos] = hardswish(data[pos]);
-    }
   }
 };
 
@@ -372,9 +367,7 @@ class ActivationSoftsign : public Activation
   void apply(float* data, long size) override
   {
     for (long pos = 0; pos < size; pos++)
-    {
       data[pos] = softsign(data[pos]);
-    }
   }
 };
 
@@ -400,8 +393,8 @@ class FastLUTActivation : public Activation
   // Fast lookup with linear interpolation
   inline float lookup(float x) const
   {
-    // Clamp input to range
-    x = std::clamp(x, min_x_, max_x_);
+    // Clamp input to range (inline to avoid header dependency)
+    x = x < min_x_ ? min_x_ : (x > max_x_ ? max_x_ : x);
 
     // Calculate float index
     float f_idx = (x - min_x_) * inv_step_;
diff --git a/NAM/conv1d.cpp b/NAM/conv1d.cpp
index 9bbbc02..e14c11c 100644
--- a/NAM/conv1d.cpp
+++ b/NAM/conv1d.cpp
@@ -1,4 +1,5 @@
 #include "conv1d.h"
+#include
 #include
 
 namespace nam
@@ -146,8 +147,8 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
   // Write input to ring buffer
   _input_buffer.Write(input, num_frames);
 
-  // Zero output before processing
-  _output.leftCols(num_frames).setZero();
+  // Note: setZero is deferred - only called for paths that need it (those using +=)
+  // Fused kernel paths use direct assignment (=) and skip setZero
 
   // Process from ring buffer with dilation lookback
   // After Write(), data is at positions [_write_pos, _write_pos+num_frames-1]
@@ -156,10 +157,78 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
 
   if (this->_is_depthwise)
   {
+    // Depthwise convolution uses += accumulation, so needs setZero
+    _output.leftCols(num_frames).setZero();
+
     // Depthwise convolution: use efficient element-wise multiplication
     // Each channel is processed independently with a single weight per kernel tap.
     // output[c, t] = sum_k(weight[k, c] * input[c, t - k*dilation])
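+    // For example, with kernel_size = 3 and dilation = 2, taps k = 0, 1, 2 are read at
+    // lookbacks of 4, 2, and 0 samples respectively (offset = dilation * (k + 1 - kernel_size)).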
     const size_t kernel_size = this->_depthwise_weight.size();
+#ifdef NAM_USE_INLINE_GEMM
+    const int channels = this->_channels;
+    float* __restrict__ output_ptr = _output.data();
+
+    for (size_t k = 0; k < kernel_size; k++)
+    {
+      const long offset = this->_dilation * (k + 1 - (long)kernel_size);
+      const long lookback = -offset;
+      auto input_block = _input_buffer.Read(num_frames, lookback);
+      const float* __restrict__ input_ptr = input_block.data();
+      const float* __restrict__ weight_ptr = this->_depthwise_weight[k].data();
+
+      // Specialized paths for common channel counts
+      if (channels == 4)
+      {
+        const float w0 = weight_ptr[0], w1 = weight_ptr[1];
+        const float w2 = weight_ptr[2], w3 = weight_ptr[3];
+        for (int f = 0; f < num_frames; f++)
+        {
+          const int off = f * 4;
+          output_ptr[off] += w0 * input_ptr[off];
+          output_ptr[off + 1] += w1 * input_ptr[off + 1];
+          output_ptr[off + 2] += w2 * input_ptr[off + 2];
+          output_ptr[off + 3] += w3 * input_ptr[off + 3];
+        }
+      }
+      else if (channels == 8)
+      {
+        const float w0 = weight_ptr[0], w1 = weight_ptr[1], w2 = weight_ptr[2], w3 = weight_ptr[3];
+        const float w4 = weight_ptr[4], w5 = weight_ptr[5], w6 = weight_ptr[6], w7 = weight_ptr[7];
+        for (int f = 0; f < num_frames; f++)
+        {
+          const int off = f * 8;
+          output_ptr[off] += w0 * input_ptr[off];
+          output_ptr[off + 1] += w1 * input_ptr[off + 1];
+          output_ptr[off + 2] += w2 * input_ptr[off + 2];
+          output_ptr[off + 3] += w3 * input_ptr[off + 3];
+          output_ptr[off + 4] += w4 * input_ptr[off + 4];
+          output_ptr[off + 5] += w5 * input_ptr[off + 5];
+          output_ptr[off + 6] += w6 * input_ptr[off + 6];
+          output_ptr[off + 7] += w7 * input_ptr[off + 7];
+        }
+      }
+      else
+      {
+        // General depthwise path with loop unrolling
+        for (int f = 0; f < num_frames; f++)
+        {
+          const int off = f * channels;
+          int c = 0;
+          for (; c + 3 < channels; c += 4)
+          {
+            output_ptr[off + c] += weight_ptr[c] * input_ptr[off + c];
+            output_ptr[off + c + 1] += weight_ptr[c + 1] * input_ptr[off + c + 1];
+            output_ptr[off + c + 2] += weight_ptr[c + 2] * input_ptr[off + c + 2];
+            output_ptr[off + c + 3] += weight_ptr[c + 3] * input_ptr[off + c + 3];
+          }
+          for (; c < channels; c++)
+          {
+            output_ptr[off + c] += weight_ptr[c] * input_ptr[off + c];
+          }
+        }
+      }
+    }
+#else
     for (size_t k = 0; k < kernel_size; k++)
     {
       const long offset = this->_dilation * (k + 1 - (long)kernel_size);
@@ -169,9 +238,326 @@
       auto input_block = _input_buffer.Read(num_frames, lookback);
 
       _output.leftCols(num_frames).noalias() += this->_depthwise_weight[k].asDiagonal() * input_block.leftCols(num_frames);
     }
+#endif
   }
   else
   {
+#ifdef NAM_USE_INLINE_GEMM
+    // Hand-optimized inline GEMM for small matrices
+    // For output(out_ch, frames) += weight(out_ch, in_ch) * input(in_ch, frames)
+    //
+    // Column-major layout means:
+    //   output(o, f) is at output_ptr[f * out_ch + o]
+    //   weight(o, i) is at weight_ptr[i * out_ch + o]
+    //   input(i, f)  is at input_ptr[f * in_ch + i]
+    const int out_ch = (int)get_out_channels();
+    const int in_ch = (int)get_in_channels();
+    const size_t kernel_size = this->_weight.size();
+
+    // Fused kernel optimization for kernel_size=3
+    // Instead of 3 separate passes over output, fuse into single pass
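+    // Sketch of the fusion: for each frame f (with d = dilation),
+    //   out(:, f) = W0 * x(:, f - 2d) + W1 * x(:, f - d) + W2 * x(:, f)
+    // is computed in one pass, so each output column is written exactly once.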
+    if (kernel_size == 3 && out_ch == 4 && in_ch == 4)
+    {
+      // Fused 4x4 kernel_size=3: read all 3 input blocks and compute in one pass
+      const long dil = this->_dilation;
+      auto in0 = _input_buffer.Read(num_frames, 2 * dil); // oldest (k=0)
+      auto in1 = _input_buffer.Read(num_frames, dil); // middle (k=1)
+      auto in2 = _input_buffer.Read(num_frames, 0); // newest (k=2)
+
+      const float* __restrict__ in0_ptr = in0.data();
+      const float* __restrict__ in1_ptr = in1.data();
+      const float* __restrict__ in2_ptr = in2.data();
+      float* __restrict__ output_ptr = _output.data();
+
+      // Get weight pointers for all 3 taps (each weight matrix is 4x4 = 16 floats)
+      const float* __restrict__ w0 = this->_weight[0].data();
+      const float* __restrict__ w1 = this->_weight[1].data();
+      const float* __restrict__ w2 = this->_weight[2].data();
+
+      // Cache all weights in registers (48 floats for 3 x 4x4 matrices)
+      const float w0_00 = w0[0], w0_10 = w0[1], w0_20 = w0[2], w0_30 = w0[3];
+      const float w0_01 = w0[4], w0_11 = w0[5], w0_21 = w0[6], w0_31 = w0[7];
+      const float w0_02 = w0[8], w0_12 = w0[9], w0_22 = w0[10], w0_32 = w0[11];
+      const float w0_03 = w0[12], w0_13 = w0[13], w0_23 = w0[14], w0_33 = w0[15];
+
+      const float w1_00 = w1[0], w1_10 = w1[1], w1_20 = w1[2], w1_30 = w1[3];
+      const float w1_01 = w1[4], w1_11 = w1[5], w1_21 = w1[6], w1_31 = w1[7];
+      const float w1_02 = w1[8], w1_12 = w1[9], w1_22 = w1[10], w1_32 = w1[11];
+      const float w1_03 = w1[12], w1_13 = w1[13], w1_23 = w1[14], w1_33 = w1[15];
+
+      const float w2_00 = w2[0], w2_10 = w2[1], w2_20 = w2[2], w2_30 = w2[3];
+      const float w2_01 = w2[4], w2_11 = w2[5], w2_21 = w2[6], w2_31 = w2[7];
+      const float w2_02 = w2[8], w2_12 = w2[9], w2_22 = w2[10], w2_32 = w2[11];
+      const float w2_03 = w2[12], w2_13 = w2[13], w2_23 = w2[14], w2_33 = w2[15];
+
+      for (int f = 0; f < num_frames; f++)
+      {
+        const int off = f * 4;
+        // Load inputs from all 3 taps
+        const float i0_0 = in0_ptr[off], i0_1 = in0_ptr[off + 1], i0_2 = in0_ptr[off + 2], i0_3 = in0_ptr[off + 3];
+        const float i1_0 = in1_ptr[off], i1_1 = in1_ptr[off + 1], i1_2 = in1_ptr[off + 2], i1_3 = in1_ptr[off + 3];
+        const float i2_0 = in2_ptr[off], i2_1 = in2_ptr[off + 1], i2_2 = in2_ptr[off + 2], i2_3 = in2_ptr[off + 3];
+
+        // Compute output = W0*in0 + W1*in1 + W2*in2 (fused direct assignment; no prior setZero needed)
+        output_ptr[off] = (w0_00 * i0_0 + w0_01 * i0_1 + w0_02 * i0_2 + w0_03 * i0_3)
+                          + (w1_00 * i1_0 + w1_01 * i1_1 + w1_02 * i1_2 + w1_03 * i1_3)
+                          + (w2_00 * i2_0 + w2_01 * i2_1 + w2_02 * i2_2 + w2_03 * i2_3);
+        output_ptr[off + 1] = (w0_10 * i0_0 + w0_11 * i0_1 + w0_12 * i0_2 + w0_13 * i0_3)
+                              + (w1_10 * i1_0 + w1_11 * i1_1 + w1_12 * i1_2 + w1_13 * i1_3)
+                              + (w2_10 * i2_0 + w2_11 * i2_1 + w2_12 * i2_2 + w2_13 * i2_3);
+        output_ptr[off + 2] = (w0_20 * i0_0 + w0_21 * i0_1 + w0_22 * i0_2 + w0_23 * i0_3)
+                              + (w1_20 * i1_0 + w1_21 * i1_1 + w1_22 * i1_2 + w1_23 * i1_3)
+                              + (w2_20 * i2_0 + w2_21 * i2_1 + w2_22 * i2_2 + w2_23 * i2_3);
+        output_ptr[off + 3] = (w0_30 * i0_0 + w0_31 * i0_1 + w0_32 * i0_2 + w0_33 * i0_3)
+                              + (w1_30 * i1_0 + w1_31 * i1_1 + w1_32 * i1_2 + w1_33 * i1_3)
+                              + (w2_30 * i2_0 + w2_31 * i2_1 + w2_32 * i2_2 + w2_33 * i2_3);
+      }
+    }
+    else if (kernel_size == 3 && out_ch == 2 && in_ch == 2)
+    {
+      // Fused 2x2 kernel_size=3: read all 3 input blocks and compute in one pass
+      const long dil = this->_dilation;
+      auto in0 = _input_buffer.Read(num_frames, 2 * dil);
+      auto in1 = _input_buffer.Read(num_frames, dil);
+      auto in2 = _input_buffer.Read(num_frames, 0);
+
+      const float* __restrict__ in0_ptr = in0.data();
+      const float* __restrict__ in1_ptr = in1.data();
+      const float* __restrict__ in2_ptr = in2.data();
+      float* __restrict__ output_ptr = _output.data();
+
+      const float* __restrict__ w0 = this->_weight[0].data();
+      const float* __restrict__ w1 = this->_weight[1].data();
+      const float* __restrict__ w2 = this->_weight[2].data();
+
+      // Cache weights (12 floats total)
+      const float w0_00 = w0[0], w0_10 = w0[1], w0_01 = w0[2], w0_11 = w0[3];
+      const float w1_00 = w1[0], w1_10 = w1[1], w1_01 = w1[2], w1_11 = w1[3];
+      const float w2_00 = w2[0], w2_10 = w2[1], w2_01 = w2[2], w2_11 = w2[3];
+
+      for (int f = 0; f < num_frames; f++)
+      {
+        const int off = f * 2;
+        const float i0_0 = in0_ptr[off], i0_1 = in0_ptr[off + 1];
+        const float i1_0 = in1_ptr[off], i1_1 = in1_ptr[off + 1];
+        const float i2_0 = in2_ptr[off], i2_1 = in2_ptr[off + 1];
+
+        output_ptr[off] = (w0_00 * i0_0 + w0_01 * i0_1) + (w1_00 * i1_0 + w1_01 * i1_1) + (w2_00 * i2_0 + w2_01 * i2_1);
+        output_ptr[off + 1] =
+          (w0_10 * i0_0 + w0_11 * i0_1) + (w1_10 * i1_0 + w1_11 * i1_1) + (w2_10 * i2_0 + w2_11 * i2_1);
+      }
+    }
+    else
+    {
+      // General inline GEMM path uses += accumulation, so needs setZero
+      _output.leftCols(num_frames).setZero();
+
+      // General inline GEMM path for other configurations
+      for (size_t k = 0; k < kernel_size; k++)
+      {
+        const long offset = this->_dilation * (k + 1 - (long)kernel_size);
+        const long lookback = -offset;
+        auto input_block = _input_buffer.Read(num_frames, lookback);
+
+        const float* __restrict__ input_ptr = input_block.data();
+        const float* __restrict__ weight_ptr = this->_weight[k].data();
+        float* __restrict__ output_ptr = _output.data();
+
+        // Specialized fully-unrolled paths for common small channel counts
+        // These avoid all loop overhead for the tiny matrices in NAM models
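+        // Column-major layout example: a 2x2 weight W is stored as {W(0,0), W(1,0), W(0,1), W(1,1)},
+        // so the unrolled paths below read W(o, i) from weight_ptr[i * out_ch + o].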
+        if (out_ch == 2 && in_ch == 2)
+        {
+          // 2x2 fully unrolled
+          const float w00 = weight_ptr[0], w10 = weight_ptr[1];
+          const float w01 = weight_ptr[2], w11 = weight_ptr[3];
+          for (int f = 0; f < num_frames; f++)
+          {
+            const float i0 = input_ptr[f * 2];
+            const float i1 = input_ptr[f * 2 + 1];
+            output_ptr[f * 2] += w00 * i0 + w01 * i1;
+            output_ptr[f * 2 + 1] += w10 * i0 + w11 * i1;
+          }
+        }
+        else if (out_ch == 2 && in_ch == 4)
+        {
+          // 2x4 fully unrolled
+          const float w00 = weight_ptr[0], w10 = weight_ptr[1];
+          const float w01 = weight_ptr[2], w11 = weight_ptr[3];
+          const float w02 = weight_ptr[4], w12 = weight_ptr[5];
+          const float w03 = weight_ptr[6], w13 = weight_ptr[7];
+          for (int f = 0; f < num_frames; f++)
+          {
+            const float i0 = input_ptr[f * 4];
+            const float i1 = input_ptr[f * 4 + 1];
+            const float i2 = input_ptr[f * 4 + 2];
+            const float i3 = input_ptr[f * 4 + 3];
+            output_ptr[f * 2] += w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3;
+            output_ptr[f * 2 + 1] += w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3;
+          }
+        }
+        else if (out_ch == 4 && in_ch == 1)
+        {
+          // 4x1 fully unrolled
+          const float w0 = weight_ptr[0], w1 = weight_ptr[1];
+          const float w2 = weight_ptr[2], w3 = weight_ptr[3];
+          for (int f = 0; f < num_frames; f++)
+          {
+            const float in_val = input_ptr[f];
+            output_ptr[f * 4] += w0 * in_val;
+            output_ptr[f * 4 + 1] += w1 * in_val;
+            output_ptr[f * 4 + 2] += w2 * in_val;
+            output_ptr[f * 4 + 3] += w3 * in_val;
+          }
+        }
+        else if (out_ch == 4 && in_ch == 4)
+        {
+          // 4x4 fully unrolled - cache weights in registers
+          const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3];
+          const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7];
+          const float w02 = weight_ptr[8], w12 = weight_ptr[9], w22 = weight_ptr[10], w32 = weight_ptr[11];
+          const float w03 = weight_ptr[12], w13 = weight_ptr[13], w23 = weight_ptr[14], w33 = weight_ptr[15];
+          for (int f = 0; f < num_frames; f++)
+          {
+            const int in_off = f * 4;
+            const int out_off = f * 4;
+            const float i0 = input_ptr[in_off];
+            const float i1 = input_ptr[in_off + 1];
+            const float i2 = input_ptr[in_off + 2];
+            const float i3 = input_ptr[in_off + 3];
+            output_ptr[out_off] += w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3;
+            output_ptr[out_off + 1] += w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3;
+            output_ptr[out_off + 2] += w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3;
+            output_ptr[out_off + 3] += w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3;
+          }
+        }
+        else if (out_ch == 3 && in_ch == 1)
+        {
+          // 3x1 fully unrolled
+          const float w0 = weight_ptr[0], w1 = weight_ptr[1], w2 = weight_ptr[2];
+          for (int f = 0; f < num_frames; f++)
+          {
+            const float in_val = input_ptr[f];
+            output_ptr[f * 3] += w0 * in_val;
+            output_ptr[f * 3 + 1] += w1 * in_val;
+            output_ptr[f * 3 + 2] += w2 * in_val;
+          }
+        }
+        else if (out_ch == 3 && in_ch == 3)
+        {
+          // 3x3 fully unrolled
+          const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2];
+          const float w01 = weight_ptr[3], w11 = weight_ptr[4], w21 = weight_ptr[5];
+          const float w02 = weight_ptr[6], w12 = weight_ptr[7], w22 = weight_ptr[8];
+          for (int f = 0; f < num_frames; f++)
+          {
+            const int off = f * 3;
+            const float i0 = input_ptr[off];
+            const float i1 = input_ptr[off + 1];
+            const float i2 = input_ptr[off + 2];
+            output_ptr[off] += w00 * i0 + w01 * i1 + w02 * i2;
+            output_ptr[off + 1] += w10 * i0 + w11 * i1 + w12 * i2;
+            output_ptr[off + 2] += w20 * i0 + w21 * i1 + w22 * i2;
+          }
+        }
+        else if (out_ch == 4 && in_ch == 3)
+        {
+          // 4x3 fully unrolled
+          const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3];
+          const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7];
+          const float w02 = weight_ptr[8], w12 = weight_ptr[9], w22 = weight_ptr[10], w32 = weight_ptr[11];
+          for (int f = 0; f < num_frames; f++)
+          {
+            const float i0 = input_ptr[f * 3];
+            const float i1 = input_ptr[f * 3 + 1];
+            const float i2 = input_ptr[f * 3 + 2];
+            output_ptr[f * 4] += w00 * i0 + w01 * i1 + w02 * i2;
+            output_ptr[f * 4 + 1] += w10 * i0 + w11 * i1 + w12 * i2;
+            output_ptr[f * 4 + 2] += w20 * i0 + w21 * i1 + w22 * i2;
+            output_ptr[f * 4 + 3] += w30 * i0 + w31 * i1 + w32 * i2;
+          }
+        }
+        else if (out_ch == 3 && in_ch == 4)
+        {
+          // 3x4 fully unrolled
+          const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2];
+          const float w01 = weight_ptr[3], w11 = weight_ptr[4], w21 = weight_ptr[5];
+          const float w02 = weight_ptr[6], w12 = weight_ptr[7], w22 = weight_ptr[8];
+          const float w03 = weight_ptr[9], w13 = weight_ptr[10], w23 = weight_ptr[11];
+          for (int f = 0; f < num_frames; f++)
+          {
+            const float i0 = input_ptr[f * 4];
+            const float i1 = input_ptr[f * 4 + 1];
+            const float i2 = input_ptr[f * 4 + 2];
+            const float i3 = input_ptr[f * 4 + 3];
+            output_ptr[f * 3] += w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3;
+            output_ptr[f * 3 + 1] += w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3;
+            output_ptr[f * 3 + 2] += w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3;
+          }
+        }
+        else if (out_ch == 6 && in_ch == 1)
+        {
+          // 6x1 fully unrolled
+          const float w0 = weight_ptr[0], w1 = weight_ptr[1], w2 = weight_ptr[2];
+          const float w3 = weight_ptr[3], w4 = weight_ptr[4], w5 = weight_ptr[5];
+          for (int f = 0; f < num_frames; f++)
+          {
+            const float in_val = input_ptr[f];
+            const int off = f * 6;
+            output_ptr[off] += w0 * in_val;
+            output_ptr[off + 1] += w1 * in_val;
+            output_ptr[off + 2] += w2 * in_val;
+            output_ptr[off + 3] += w3 * in_val;
+            output_ptr[off + 4] += w4 * in_val;
+            output_ptr[off + 5] += w5 * in_val;
+          }
+        }
+        else if (out_ch == 6 && in_ch == 6)
+        {
+          // 6x6 - unroll weights, loop over frames
+          for (int f = 0; f < num_frames; f++)
+          {
+            const float* __restrict__ in_col = input_ptr + f * 6;
+            float* __restrict__ out_col = output_ptr + f * 6;
+            const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2];
+            const float i3 = in_col[3], i4 = in_col[4], i5 = in_col[5];
+            for (int o = 0; o < 6; o++)
+            {
+              out_col[o] += weight_ptr[o] * i0 + weight_ptr[6 + o] * i1 + weight_ptr[12 + o] * i2
+                            + weight_ptr[18 + o] * i3 + weight_ptr[24 + o] * i4 + weight_ptr[30 + o] * i5;
+            }
+          }
+        }
+        else if (out_ch == 8 && in_ch == 8)
+        {
+          // 8x8 - unroll weights, loop over frames
+          for (int f = 0; f < num_frames; f++)
+          {
+            const float* __restrict__ in_col = input_ptr + f * 8;
+            float* __restrict__ out_col = output_ptr + f * 8;
+            const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2], i3 = in_col[3];
+            const float i4 = in_col[4], i5 = in_col[5], i6 = in_col[6], i7 = in_col[7];
+            for (int o = 0; o < 8; o++)
+            {
+              out_col[o] += weight_ptr[o] * i0 + weight_ptr[8 + o] * i1 + weight_ptr[16 + o] * i2
+                            + weight_ptr[24 + o] * i3 + weight_ptr[32 + o] * i4 + weight_ptr[40 + o] * i5
+                            + weight_ptr[48 + o] * i6 + weight_ptr[56 + o] * i7;
+            }
+          }
+        }
+        else
+        {
+          // Fall back to Eigen for larger matrices where it's more efficient
+          _output.leftCols(num_frames).noalias() += this->_weight[k] * input_block;
+        }
+      }
+    } // end else (general GEMM path)
+#else
+    // Eigen fallback uses += accumulation, so needs setZero
+    _output.leftCols(num_frames).setZero();
+
+    // Eigen fallback when inline GEMM is disabled
     // Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal),
     // so we can use a single GEMM for all cases. A more advanced implementation could store
     // compact per-group weight matrices and loop over groups, but at typical model sizes
@@ -184,12 +570,94 @@
       auto input_block = _input_buffer.Read(num_frames, lookback);
 
       _output.leftCols(num_frames).noalias() += this->_weight[k] * input_block;
     }
+#endif
   }
 
   // Add bias if present
   if (this->_bias.size() > 0)
   {
+#ifdef NAM_USE_INLINE_GEMM
+    // Inline bias addition for small channel counts
+    const int out_ch = (int)get_out_channels();
+    float* __restrict__ output_ptr = _output.data();
+    const float* __restrict__ bias_ptr = this->_bias.data();
+
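+    // Same idea as "_output.leftCols(num_frames).colwise() += _bias", hand-unrolled
+    // per channel count so the bias values stay in registers across frames.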
+    if (out_ch == 2)
+    {
+      const float b0 = bias_ptr[0], b1 = bias_ptr[1];
+      for (int f = 0; f < num_frames; f++)
+      {
+        output_ptr[f * 2] += b0;
+        output_ptr[f * 2 + 1] += b1;
+      }
+    }
+    else if (out_ch == 3)
+    {
+      const float b0 = bias_ptr[0], b1 = bias_ptr[1], b2 = bias_ptr[2];
+      for (int f = 0; f < num_frames; f++)
+      {
+        output_ptr[f * 3] += b0;
+        output_ptr[f * 3 + 1] += b1;
+        output_ptr[f * 3 + 2] += b2;
+      }
+    }
+    else if (out_ch == 4)
+    {
+      const float b0 = bias_ptr[0], b1 = bias_ptr[1];
+      const float b2 = bias_ptr[2], b3 = bias_ptr[3];
+      for (int f = 0; f < num_frames; f++)
+      {
+        output_ptr[f * 4] += b0;
+        output_ptr[f * 4 + 1] += b1;
+        output_ptr[f * 4 + 2] += b2;
+        output_ptr[f * 4 + 3] += b3;
+      }
+    }
+    else if (out_ch == 6)
+    {
+      const float b0 = bias_ptr[0], b1 = bias_ptr[1], b2 = bias_ptr[2];
+      const float b3 = bias_ptr[3], b4 = bias_ptr[4], b5 = bias_ptr[5];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const int off = f * 6;
+        output_ptr[off] += b0;
+        output_ptr[off + 1] += b1;
+        output_ptr[off + 2] += b2;
+        output_ptr[off + 3] += b3;
+        output_ptr[off + 4] += b4;
+        output_ptr[off + 5] += b5;
+      }
+    }
+    else if (out_ch == 8)
+    {
+      const float b0 = bias_ptr[0], b1 = bias_ptr[1], b2 = bias_ptr[2], b3 = bias_ptr[3];
+      const float b4 = bias_ptr[4], b5 = bias_ptr[5], b6 = bias_ptr[6], b7 = bias_ptr[7];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const int off = f * 8;
+        output_ptr[off] += b0;
+        output_ptr[off + 1] += b1;
+        output_ptr[off + 2] += b2;
+        output_ptr[off + 3] += b3;
+        output_ptr[off + 4] += b4;
+        output_ptr[off + 5] += b5;
+        output_ptr[off + 6] += b6;
+        output_ptr[off + 7] += b7;
+      }
+    }
+    else
+    {
+      for (int f = 0; f < num_frames; f++)
+      {
+        for (int o = 0; o < out_ch; o++)
+        {
+          output_ptr[f * out_ch + o] += bias_ptr[o];
+        }
+      }
+    }
+#else
     _output.leftCols(num_frames).colwise() += this->_bias;
+#endif
   }
 
   // Advance ring buffer write pointer after processing
diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp
index 05dab09..4ba90c6 100644
--- a/NAM/dsp.cpp
+++ b/NAM/dsp.cpp
@@ -1,5 +1,6 @@
 #include <algorithm> // std::max_element
 #include <cmath> // pow, tanh, expf
+#include
 #include
 #include
 #include
@@ -453,10 +454,256 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, const int num_frames)
   }
   else
   {
+#ifdef NAM_USE_INLINE_GEMM
+    // Hand-optimized GEMM for small matrices (1x1 convolution)
+    // output(out_ch, frames) = weight(out_ch, in_ch) * input(in_ch, frames)
+    const int out_ch = (int)get_out_channels();
+    const int in_ch = (int)get_in_channels();
+    const float* __restrict__ input_ptr = input.data();
+    const float* __restrict__ weight_ptr = this->_weight.data();
+    float* __restrict__ output_ptr = _output.data();
+    // Use outerStride() instead of in_ch to correctly handle non-contiguous
+    // block expressions (e.g. topRows()) where outerStride > rows
+    const int in_stride = (int)input.outerStride();
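+    // Example of why the stride matters: if the caller passes m.topRows(4) of an
+    // 8-row matrix, in_stride == 8 while in_ch == 4; indexing columns with in_ch
+    // would silently read the wrong elements.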
+
+    // Specialized paths for common small sizes
+    if (out_ch == 2 && in_ch == 1)
+    {
+      const float w0 = weight_ptr[0], w1 = weight_ptr[1];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float in_val = input_ptr[f * in_stride];
+        output_ptr[f * 2] = w0 * in_val;
+        output_ptr[f * 2 + 1] = w1 * in_val;
+      }
+    }
+    else if (out_ch == 4 && in_ch == 1)
+    {
+      const float w0 = weight_ptr[0], w1 = weight_ptr[1];
+      const float w2 = weight_ptr[2], w3 = weight_ptr[3];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float in_val = input_ptr[f * in_stride];
+        output_ptr[f * 4] = w0 * in_val;
+        output_ptr[f * 4 + 1] = w1 * in_val;
+        output_ptr[f * 4 + 2] = w2 * in_val;
+        output_ptr[f * 4 + 3] = w3 * in_val;
+      }
+    }
+    else if (out_ch == 1 && in_ch == 2)
+    {
+      const float w0 = weight_ptr[0], w1 = weight_ptr[1];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * in_stride;
+        output_ptr[f] = w0 * in_col[0] + w1 * in_col[1];
+      }
+    }
+    else if (out_ch == 2 && in_ch == 2)
+    {
+      // 2x2 fully unrolled
+      const float w00 = weight_ptr[0], w10 = weight_ptr[1];
+      const float w01 = weight_ptr[2], w11 = weight_ptr[3];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * in_stride;
+        const float i0 = in_col[0];
+        const float i1 = in_col[1];
+        output_ptr[f * 2] = w00 * i0 + w01 * i1;
+        output_ptr[f * 2 + 1] = w10 * i0 + w11 * i1;
+      }
+    }
+    else if (out_ch == 2 && in_ch == 4)
+    {
+      const float w00 = weight_ptr[0], w10 = weight_ptr[1];
+      const float w01 = weight_ptr[2], w11 = weight_ptr[3];
+      const float w02 = weight_ptr[4], w12 = weight_ptr[5];
+      const float w03 = weight_ptr[6], w13 = weight_ptr[7];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * in_stride;
+        const float i0 = in_col[0];
+        const float i1 = in_col[1];
+        const float i2 = in_col[2];
+        const float i3 = in_col[3];
+        output_ptr[f * 2] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3;
+        output_ptr[f * 2 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3;
+      }
+    }
+    else if (out_ch == 1 && in_ch == 4)
+    {
+      const float w0 = weight_ptr[0], w1 = weight_ptr[1];
+      const float w2 = weight_ptr[2], w3 = weight_ptr[3];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * in_stride;
+        output_ptr[f] = w0 * in_col[0] + w1 * in_col[1] + w2 * in_col[2] + w3 * in_col[3];
+      }
+    }
+    else if (out_ch == 4 && in_ch == 2)
+    {
+      const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3];
+      const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * in_stride;
+        const float i0 = in_col[0];
+        const float i1 = in_col[1];
+        output_ptr[f * 4] = w00 * i0 + w01 * i1;
+        output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1;
+        output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1;
+        output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1;
+      }
+    }
+    else if (out_ch == 3 && in_ch == 3)
+    {
+      const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2];
+      const float w01 = weight_ptr[3], w11 = weight_ptr[4], w21 = weight_ptr[5];
+      const float w02 = weight_ptr[6], w12 = weight_ptr[7], w22 = weight_ptr[8];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * in_stride;
+        const float i0 = in_col[0];
+        const float i1 = in_col[1];
+        const float i2 = in_col[2];
+        output_ptr[f * 3] = w00 * i0 + w01 * i1 + w02 * i2;
+        output_ptr[f * 3 + 1] = w10 * i0 + w11 * i1 + w12 * i2;
+        output_ptr[f * 3 + 2] = w20 * i0 + w21 * i1 + w22 * i2;
+      }
+    }
+    else if (out_ch == 4 && in_ch == 4)
+    {
+      const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3];
+      const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7];
+      const float w02 = weight_ptr[8], w12 = weight_ptr[9], w22 = weight_ptr[10], w32 = weight_ptr[11];
+      const float w03 = weight_ptr[12], w13 = weight_ptr[13], w23 = weight_ptr[14], w33 = weight_ptr[15];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * in_stride;
+        const float i0 = in_col[0];
+        const float i1 = in_col[1];
+        const float i2 = in_col[2];
+        const float i3 = in_col[3];
+        output_ptr[f * 4] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3;
+        output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3;
+        output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3;
+        output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3;
+      }
+    }
+    else if (out_ch == 6 && in_ch == 6)
+    {
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * in_stride;
+        float* __restrict__ out_col = output_ptr + f * 6;
+        const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2];
+        const float i3 = in_col[3], i4 = in_col[4], i5 = in_col[5];
+        for (int o = 0; o < 6; o++)
+        {
+          out_col[o] = weight_ptr[o] * i0 + weight_ptr[6 + o] * i1 + weight_ptr[12 + o] * i2
+                       + weight_ptr[18 + o] * i3 + weight_ptr[24 + o] * i4 + weight_ptr[30 + o] * i5;
+        }
+      }
+    }
+    else if (out_ch == 8 && in_ch == 8)
+    {
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * in_stride;
+        float* __restrict__ out_col = output_ptr + f * 8;
+        const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2], i3 = in_col[3];
+        const float i4 = in_col[4], i5 = in_col[5], i6 = in_col[6], i7 = in_col[7];
+        for (int o = 0; o < 8; o++)
+        {
+          out_col[o] = weight_ptr[o] * i0 + weight_ptr[8 + o] * i1 + weight_ptr[16 + o] * i2 + weight_ptr[24 + o] * i3
+                       + weight_ptr[32 + o] * i4 + weight_ptr[40 + o] * i5 + weight_ptr[48 + o] * i6 + weight_ptr[56 + o] * i7;
+        }
+      }
+    }
+    else if (out_ch == 4 && in_ch == 8)
+    {
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * in_stride;
+        float* __restrict__ out_col = output_ptr + f * 4;
+        const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2], i3 = in_col[3];
+        const float i4 = in_col[4], i5 = in_col[5], i6 = in_col[6], i7 = in_col[7];
+        for (int o = 0; o < 4; o++)
+        {
+          out_col[o] = weight_ptr[o] * i0 + weight_ptr[4 + o] * i1 + weight_ptr[8 + o] * i2 + weight_ptr[12 + o] * i3
+                       + weight_ptr[16 + o] * i4 + weight_ptr[20 + o] * i5 + weight_ptr[24 + o] * i6 + weight_ptr[28 + o] * i7;
+        }
+      }
+    }
+    else if (out_ch == 8 && in_ch == 4)
+    {
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * in_stride;
+        float* __restrict__ out_col = output_ptr + f * 8;
+        const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2], i3 = in_col[3];
+        for (int o = 0; o < 8; o++)
+        {
+          out_col[o] = weight_ptr[o] * i0 + weight_ptr[8 + o] * i1 + weight_ptr[16 + o] * i2 + weight_ptr[24 + o] * i3;
+        }
+      }
+    }
+    else
+    {
+      // Fall back to Eigen for larger matrices where it's more efficient
+      _output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames);
+    }
+#else
     // Single GEMM for all cases - block-diagonal zero structure handles grouping
     _output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames);
+#endif
   }
   if (this->_do_bias)
+  {
+#ifdef NAM_USE_INLINE_GEMM
+    const int out_ch = (int)get_out_channels();
+    float* __restrict__ output_ptr = _output.data();
+    const float* __restrict__ bias_ptr = this->_bias.data();
+
+    // Specialized paths for common small channel counts
+    if (out_ch == 2)
+    {
+      const float b0 = bias_ptr[0], b1 = bias_ptr[1];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const int off = f * 2;
+        output_ptr[off] += b0;
+        output_ptr[off + 1] += b1;
+      }
+    }
+    else if (out_ch == 4)
+    {
+      const float b0 = bias_ptr[0], b1 = bias_ptr[1];
+      const float b2 = bias_ptr[2], b3 = bias_ptr[3];
+      for (int f = 0; f < num_frames; f++)
+      {
+        const int off = f * 4;
+        output_ptr[off] += b0;
+        output_ptr[off + 1] += b1;
+        output_ptr[off + 2] += b2;
+        output_ptr[off + 3] += b3;
+      }
+    }
+    else
+    {
+      for (int f = 0; f < num_frames; f++)
+      {
+        float* __restrict__ out_col = output_ptr + f * out_ch;
+        for (int o = 0; o < out_ch; o++)
+        {
+          out_col[o] += bias_ptr[o];
+        }
+      }
+    }
+#else
     _output.leftCols(num_frames).colwise() += this->_bias;
+#endif
+  }
 }
diff --git a/NAM/film.h b/NAM/film.h
index f0f86fb..559de13 100644
--- a/NAM/film.h
+++ b/NAM/film.h
@@ -84,6 +84,65 @@ class FiLM
     _cond_to_scale_shift.process_(condition, num_frames);
     const auto& scale_shift = _cond_to_scale_shift.GetOutput();
 
+#ifdef NAM_USE_INLINE_GEMM
+    // Optimized inline FiLM operation
+    const int input_dim = (int)get_input_dim();
+    const float* __restrict__ input_ptr = input.data();
+    const float* __restrict__ scale_shift_ptr = scale_shift.data();
+    float* __restrict__ output_ptr = _output.data();
+    const int scale_shift_rows = (int)scale_shift.rows();
+    // Use outerStride() instead of rows() to correctly handle non-contiguous
+    // block expressions (e.g. topRows()) where outerStride > rows
+    const int input_stride = (int)input.outerStride();
+
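+    // FiLM recap: out[c, f] = in[c, f] * scale[c, f] (+ shift[c, f] when _do_shift),
+    // where scale is the top input_dim rows of scale_shift and shift the next input_dim rows.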
+    if (_do_shift)
+    {
+      // scale = top input_dim rows, shift = bottom input_dim rows
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * input_stride;
+        const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows;
+        const float* __restrict__ shift_col = scale_col + input_dim;
+        float* __restrict__ out_col = output_ptr + f * input_dim;
+
+        int i = 0;
+        for (; i + 3 < input_dim; i += 4)
+        {
+          out_col[i] = in_col[i] * scale_col[i] + shift_col[i];
+          out_col[i + 1] = in_col[i + 1] * scale_col[i + 1] + shift_col[i + 1];
+          out_col[i + 2] = in_col[i + 2] * scale_col[i + 2] + shift_col[i + 2];
+          out_col[i + 3] = in_col[i + 3] * scale_col[i + 3] + shift_col[i + 3];
+        }
+        for (; i < input_dim; i++)
+        {
+          out_col[i] = in_col[i] * scale_col[i] + shift_col[i];
+        }
+      }
+    }
+    else
+    {
+      // scale only
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ in_col = input_ptr + f * input_stride;
+        const float* __restrict__ scale_col = scale_shift_ptr + f * scale_shift_rows;
+        float* __restrict__ out_col = output_ptr + f * input_dim;
+
+        int i = 0;
+        for (; i + 3 < input_dim; i += 4)
+        {
+          out_col[i] = in_col[i] * scale_col[i];
+          out_col[i + 1] = in_col[i + 1] * scale_col[i + 1];
+          out_col[i + 2] = in_col[i + 2] * scale_col[i + 2];
+          out_col[i + 3] = in_col[i + 3] * scale_col[i + 3];
+        }
+        for (; i < input_dim; i++)
+        {
+          out_col[i] = in_col[i] * scale_col[i];
+        }
+      }
+    }
+#else
     const auto scale = scale_shift.topRows(get_input_dim()).leftCols(num_frames);
     if (_do_shift)
     {
@@ -95,6 +154,7 @@
     {
       _output.leftCols(num_frames).array() = input.leftCols(num_frames).array() * scale.array();
     }
+#endif
   }
 
   /// \brief Process input with conditioning (in-place)
diff --git a/NAM/gating_activations.h b/NAM/gating_activations.h
index 335a984..0d52298 100644
--- a/NAM/gating_activations.h
+++ b/NAM/gating_activations.h
@@ -59,14 +59,44 @@ class GatingActivation
   void apply(const Eigen::MatrixBase<InputDerived>& input, Eigen::MatrixBase<OutputDerived>& output)
   {
     // Validate input dimensions (assert for real-time performance)
-    const int total_channels = 2 * num_channels;
-    assert(input.rows() == total_channels);
+    assert(input.rows() == 2 * num_channels);
     assert(output.rows() == num_channels);
     assert(output.cols() == input.cols());
 
-    // Process column-by-column to ensure memory contiguity (important for column-major matrices)
-    // Uses pre-allocated buffers to avoid allocations in the loop (real-time safe)
     const int num_samples = input.cols();
+
+#ifdef NAM_USE_INLINE_GEMM
+    // Optimized path: direct memory access with activation applied per-element
+    // Use outerStride() instead of rows() to correctly handle non-contiguous
+    // block expressions (e.g. topRows()) where outerStride > rows
+    const int input_stride = (int)input.outerStride();
+    const float* __restrict__ input_ptr = input.derived().data();
+    float* __restrict__ output_ptr = output.derived().data();
+    const int output_stride = (int)output.outerStride(); // Column stride for output
+
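+    // Gating recap: out[c, f] = act(in[c, f]) * gate(in[c + num_channels, f]);
+    // the first num_channels rows carry the signal, the second num_channels rows the gate.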
+    for (int f = 0; f < num_samples; f++)
+    {
+      const float* __restrict__ in_col = input_ptr + f * input_stride;
+      float* __restrict__ out_col = output_ptr + f * output_stride;
+
+      // Copy input and gating channels to buffers, apply activations, multiply
+      for (int c = 0; c < num_channels; c++)
+      {
+        input_buffer(c, 0) = in_col[c];
+        gating_buffer(c, 0) = in_col[c + num_channels];
+      }
+
+      input_activation->apply(input_buffer);
+      gating_activation->apply(gating_buffer);
+
+      // Element-wise multiply and store
+      for (int c = 0; c < num_channels; c++)
+      {
+        out_col[c] = input_buffer(c, 0) * gating_buffer(c, 0);
+      }
+    }
+#else
+    // Original Eigen path
     for (int i = 0; i < num_samples; i++)
     {
       // Copy to pre-allocated buffers and apply activations in-place
@@ -77,9 +107,9 @@
       gating_activation->apply(gating_buffer);
 
       // Element-wise multiplication and store result
-      // For wavenet compatibility, we assume one-to-one mapping
       output.block(0, i, num_channels, 1) = input_buffer.array() * gating_buffer.array();
     }
+#endif
   }
 
   /**
@@ -135,14 +165,47 @@ class BlendingActivation
   void apply(const Eigen::MatrixBase<InputDerived>& input, Eigen::MatrixBase<OutputDerived>& output)
   {
     // Validate input dimensions (assert for real-time performance)
-    const int total_channels = num_channels * 2; // 2*channels in, channels out
-    assert(input.rows() == total_channels);
+    assert(input.rows() == num_channels * 2);
    assert(output.rows() == num_channels);
     assert(output.cols() == input.cols());
 
-    // Process column-by-column to ensure memory contiguity
-    // Uses pre-allocated buffers to avoid allocations in the loop (real-time safe)
     const int num_samples = input.cols();
+
+#ifdef NAM_USE_INLINE_GEMM
+    // Optimized path: direct memory access
+    // Use outerStride() instead of rows() to correctly handle non-contiguous
+    // block expressions (e.g. topRows()) where outerStride > rows
+    const int input_stride = (int)input.outerStride();
+    const float* __restrict__ input_ptr = input.derived().data();
+    float* __restrict__ output_ptr = output.derived().data();
+    const int output_stride = (int)output.outerStride(); // Column stride for output
+
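+    // Blending recap: out = alpha * act(x) + (1 - alpha) * x, where x is the top
+    // num_channels rows and alpha = blend(bottom num_channels rows).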
+    for (int f = 0; f < num_samples; f++)
+    {
+      const float* __restrict__ in_col = input_ptr + f * input_stride;
+      float* __restrict__ out_col = output_ptr + f * output_stride;
+
+      // Copy channels to buffers
+      for (int c = 0; c < num_channels; c++)
+      {
+        pre_activation_buffer(c, 0) = in_col[c];
+        input_buffer(c, 0) = in_col[c];
+        blend_buffer(c, 0) = in_col[c + num_channels];
+      }
+
+      // Apply activations
+      input_activation->apply(input_buffer);
+      blending_activation->apply(blend_buffer);
+
+      // Weighted blending: alpha * activated + (1 - alpha) * pre_activation
+      for (int c = 0; c < num_channels; c++)
+      {
+        const float alpha = blend_buffer(c, 0);
+        out_col[c] = alpha * input_buffer(c, 0) + (1.0f - alpha) * pre_activation_buffer(c, 0);
+      }
+    }
+#else
+    // Original Eigen path
     for (int i = 0; i < num_samples; i++)
     {
       // Store pre-activation input values in buffer
@@ -160,6 +223,7 @@
       output.block(0, i, num_channels, 1) =
         blend_buffer.array() * input_buffer.array() + (1.0f - blend_buffer.array()) * pre_activation_buffer.array();
     }
+#endif
   }
 
   /**
diff --git a/NAM/ring_buffer.cpp b/NAM/ring_buffer.cpp
index 64518b4..8f0919a 100644
--- a/NAM/ring_buffer.cpp
+++ b/NAM/ring_buffer.cpp
@@ -35,22 +35,10 @@ void RingBuffer::Write(const Eigen::MatrixXf& input, const int num_frames)
   if (NeedsRewind(num_frames))
     Rewind();
 
-  // Write the input data at the write position
-  // NOTE: This function assumes that `input` is a full, pre-allocated MatrixXf
-  // covering the entire valid buffer range. Callers should not pass Block
-  // expressions across the API boundary; instead, pass the full buffer and
-  // slice inside the callee. This avoids Eigen evaluating Blocks into
-  // temporaries (which would allocate) when binding to MatrixXf.
-  const int channels = _storage.rows();
-  const int copy_cols = num_frames;
-
-  for (int col = 0; col < copy_cols; ++col)
-  {
-    for (int row = 0; row < channels; ++row)
-    {
-      _storage(row, _write_pos + col) = input(row, col);
-    }
-  }
+  // Write the input data at the write position using Eigen block operations
+  // This is more efficient than element-by-element copy as it allows
+  // the compiler to vectorize the operation.
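+  // (Callers still pass full matrices here: binding a Block expression to the
+  // const Eigen::MatrixXf& parameter would evaluate it into a temporary and allocate.)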
+  _storage.middleCols(_write_pos, num_frames).noalias() = input.leftCols(num_frames);
 }
 
 Eigen::Block<Eigen::MatrixXf> RingBuffer::Read(const int num_frames, const long lookback)
diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp
index 6eb74a3..d867996 100644
--- a/NAM/wavenet.cpp
+++ b/NAM/wavenet.cpp
@@ -1,4 +1,5 @@
 #include
+#include <cstring>
 #include
 #include
 #include
@@ -125,6 +126,7 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition,
   }
 
   this->_z.leftCols(num_frames).noalias() =
     _conv.GetOutput().leftCols(num_frames) + _input_mixin.GetOutput().leftCols(num_frames);
+
   if (this->_activation_pre_film)
   {
     this->_activation_pre_film->Process_(this->_z, condition, num_frames);
@@ -207,16 +209,49 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition,
       Eigen::MatrixXf& head1x1_output = this->_head1x1->GetOutput();
       this->_head1x1_post_film->Process_(head1x1_output, condition, num_frames);
     }
+#ifdef NAM_USE_INLINE_GEMM
+    {
+      // Pure copy - use memcpy
+      const int total = (int)this->_head1x1->get_out_channels() * num_frames;
+      std::memcpy(this->_output_head.data(), this->_head1x1->GetOutput().data(), total * sizeof(float));
+    }
+#else
     this->_output_head.leftCols(num_frames).noalias() = this->_head1x1->GetOutput().leftCols(num_frames);
+#endif
   }
   else // No head 1x1
   {
     // (No FiLM)
     // Store output to head (skip connection: activated conv output)
+#ifdef NAM_USE_INLINE_GEMM
+    if (this->_gating_mode == GatingMode::NONE)
+    {
+      // _z has bottleneck rows, data is contiguous - use memcpy
+      const int total = (int)bottleneck * num_frames;
+      std::memcpy(this->_output_head.data(), this->_z.data(), total * sizeof(float));
+    }
+    else
+    {
+      // _z has 2*bottleneck rows but we only want top bottleneck rows
+      // Column-major: need to copy column by column with stride
+      const int out_rows = (int)bottleneck;
+      const int z_rows = (int)this->_z.rows(); // 2*bottleneck for gated
+      const float* __restrict__ src = this->_z.data();
+      float* __restrict__ dst = this->_output_head.data();
+      for (int f = 0; f < num_frames; f++)
+      {
+        const float* __restrict__ src_col = src + f * z_rows;
+        float* __restrict__ dst_col = dst + f * out_rows;
+        for (int r = 0; r < out_rows; r++)
+          dst_col[r] = src_col[r];
+      }
+    }
+#else
     if (this->_gating_mode == GatingMode::NONE)
       this->_output_head.leftCols(num_frames).noalias() = this->_z.leftCols(num_frames);
     else
       this->_output_head.leftCols(num_frames).noalias() = this->_z.topRows(bottleneck).leftCols(num_frames);
+#endif
   }
 
   // Store output to next layer (residual connection: input + layer1x1 output, or just input if layer1x1 inactive)
@@ -228,7 +263,15 @@ void nam::wavenet::_Layer::Process(const Eigen::MatrixXf& input, const Eigen::MatrixXf& condition,
   else
   {
     // If layer1x1 is inactive, residual connection is just the input (identity)
+#ifdef NAM_USE_INLINE_GEMM
+    {
+      // Pure copy - use memcpy
+      const int total = (int)this->get_channels() * num_frames;
+      std::memcpy(this->_output_next_layer.data(), input.data(), total * sizeof(float));
+    }
+#else
     this->_output_next_layer.leftCols(num_frames).noalias() = input.leftCols(num_frames);
+#endif
   }
 }
 
@@ -290,8 +333,15 @@ void nam::wavenet::_LayerArray::Process(const Eigen::MatrixXf& layer_inputs, const Eigen::MatrixXf& condition,
 void nam::wavenet::_LayerArray::Process(const Eigen::MatrixXf& layer_inputs, const Eigen::MatrixXf& condition,
                                         const Eigen::MatrixXf& head_inputs, const int num_frames)
 {
-  // Copy head inputs from previous layer array
+  // Copy head inputs from previous layer array - use memcpy for pure copy
+#ifdef NAM_USE_INLINE_GEMM
+  {
+    const int total = (int)this->_head_output_size * num_frames;
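+    // Assumes head_inputs has exactly _head_output_size rows, so its first
+    // num_frames columns form one contiguous block and a flat memcpy is safe.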
+    std::memcpy(this->_head_inputs.data(), head_inputs.data(), total * sizeof(float));
+  }
+#else
   this->_head_inputs.leftCols(num_frames).noalias() = head_inputs.leftCols(num_frames);
+#endif
   ProcessInner(layer_inputs, condition, num_frames);
 }
 
@@ -323,10 +373,18 @@ void nam::wavenet::_LayerArray::ProcessInner(const Eigen::MatrixXf& layer_inputs, const Eigen::MatrixXf& condition, const int num_frames)
     this->_head_inputs.leftCols(num_frames).noalias() += this->_layers[i].GetOutputHead().leftCols(num_frames);
   }
 
-  // Store output from last layer
+  // Store output from last layer - use memcpy for pure copy
   const size_t last_layer = this->_layers.size() - 1;
+#ifdef NAM_USE_INLINE_GEMM
+  {
+    const int total = (int)this->_get_channels() * num_frames;
+    std::memcpy(this->_layer_outputs.data(), this->_layers[last_layer].GetOutputNextLayer().data(),
+                total * sizeof(float));
+  }
+#else
   this->_layer_outputs.leftCols(num_frames).noalias() =
     this->_layers[last_layer].GetOutputNextLayer().leftCols(num_frames);
+#endif
 
   // Process head rechannel
   _head_rechannel.process_(this->_head_inputs, num_frames);
@@ -558,12 +616,27 @@ void nam::wavenet::WaveNet::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num_frames)
   auto& final_head_outputs = this->_layer_arrays.back().GetHeadOutputs();
   assert(final_head_outputs.rows() == out_channels);
 
-  for (int ch = 0; ch < out_channels; ch++)
+  // Optimized output copy with head_scale multiplication
+  if (out_channels == 1)
   {
+    // Single channel: data is contiguous
+    const float scale = this->_head_scale;
+    const float* __restrict__ src = final_head_outputs.data();
+    NAM_SAMPLE* __restrict__ dst = output[0];
     for (int s = 0; s < num_frames; s++)
     {
-      const float out = this->_head_scale * final_head_outputs(ch, s);
-      output[ch][s] = out;
+      dst[s] = scale * src[s];
+    }
+  }
+  else
+  {
+    // Multi-channel: rows are not contiguous in column-major
+    for (int ch = 0; ch < out_channels; ch++)
+    {
+      for (int s = 0; s < num_frames; s++)
+      {
+        output[ch][s] = this->_head_scale * final_head_outputs(ch, s);
+      }
     }
   }
 }
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
index 8118e08..cf5b64f 100644
--- a/tools/CMakeLists.txt
+++ b/tools/CMakeLists.txt
@@ -12,6 +12,7 @@ include_directories(tools ${NAM_DEPS_PATH}/nlohmann)
 
 add_executable(loadmodel loadmodel.cpp ${NAM_SOURCES})
 add_executable(benchmodel benchmodel.cpp ${NAM_SOURCES})
+add_executable(benchmodel_bufsize benchmodel_bufsize.cpp ${NAM_SOURCES})
 add_executable(run_tests run_tests.cpp test/allocation_tracking.cpp ${NAM_SOURCES})
 
 # Compile run_tests without optimizations to ensure allocation tracking works correctly
 # Also ensure assertions are enabled (NDEBUG is not defined) so tests actually run
diff --git a/tools/benchmodel.cpp b/tools/benchmodel.cpp
index 39c14b0..5e8c45d 100644
--- a/tools/benchmodel.cpp
+++ b/tools/benchmodel.cpp
@@ -1,5 +1,6 @@
 #include <chrono>
 #include <filesystem>
+#include <cstring>
 #include <iostream>
 
 #include "NAM/dsp.h"
@@ -17,74 +18,85 @@ double outputBuffer[AUDIO_BUFFER_SIZE];
 
 int main(int argc, char* argv[])
 {
-  if (argc > 1)
+  if (argc < 2)
   {
-    const char* modelPath = argv[1];
-
-    std::cout << "Loading model " << modelPath << "\n";
+    std::cerr << "Usage: benchmodel <model_path> [--no-fast-tanh]\n";
+    exit(1);
+  }
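+  // Example invocations (the first two match the CI workflow):
+  //   ./build/tools/benchmodel ./example_models/wavenet.nam
+  //   ./build/tools/benchmodel ./example_models/lstm.nam
+  //   ./build/tools/benchmodel ./example_models/lstm.nam --no-fast-tanh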
 
-    // Turn on fast tanh approximation
-    nam::activations::Activation::enable_fast_tanh();
+  const char* modelPath = argv[1];
 
-    std::unique_ptr<nam::DSP> model;
+  // Check for --no-fast-tanh flag
+  bool useFastTanh = true;
+  for (int i = 2; i < argc; i++)
+  {
+    if (std::strcmp(argv[i], "--no-fast-tanh") == 0)
+      useFastTanh = false;
+  }
 
-    model.reset();
-    model = nam::get_dsp(std::filesystem::path(modelPath));
+  if (useFastTanh)
+  {
+    nam::activations::Activation::enable_fast_tanh();
+    std::cout << "Fast tanh: enabled\n";
+  }
+  else
+  {
+    nam::activations::Activation::disable_fast_tanh();
+    std::cout << "Fast tanh: disabled\n";
+  }
 
-    if (model == nullptr)
-    {
-      std::cerr << "Failed to load model\n";
+  std::cout << "Loading model " << modelPath << "\n";
 
-      exit(1);
-    }
+  std::unique_ptr<nam::DSP> model;
+  model = nam::get_dsp(std::filesystem::path(modelPath));
 
-    size_t bufferSize = AUDIO_BUFFER_SIZE;
-    model->Reset(model->GetExpectedSampleRate(), bufferSize);
-    size_t numBuffers = (48000 / bufferSize) * 2;
+  if (model == nullptr)
+  {
+    std::cerr << "Failed to load model\n";
+    exit(1);
+  }
 
-    // Allocate multi-channel buffers
-    const int in_channels = model->NumInputChannels();
-    const int out_channels = model->NumOutputChannels();
+  size_t bufferSize = AUDIO_BUFFER_SIZE;
+  model->Reset(model->GetExpectedSampleRate(), bufferSize);
+  size_t numBuffers = (48000 / bufferSize) * 2;
 
-    std::vector<std::vector<NAM_SAMPLE>> inputBuffers(in_channels);
-    std::vector<std::vector<NAM_SAMPLE>> outputBuffers(out_channels);
-    std::vector<NAM_SAMPLE*> inputPtrs(in_channels);
-    std::vector<NAM_SAMPLE*> outputPtrs(out_channels);
+  // Allocate multi-channel buffers
+  const int in_channels = model->NumInputChannels();
+  const int out_channels = model->NumOutputChannels();
 
-    for (int ch = 0; ch < in_channels; ch++)
-    {
-      inputBuffers[ch].resize(AUDIO_BUFFER_SIZE, 0.0);
-      inputPtrs[ch] = inputBuffers[ch].data();
-    }
-    for (int ch = 0; ch < out_channels; ch++)
-    {
-      outputBuffers[ch].resize(AUDIO_BUFFER_SIZE, 0.0);
-      outputPtrs[ch] = outputBuffers[ch].data();
-    }
+  std::vector<std::vector<NAM_SAMPLE>> inputBuffers(in_channels);
+  std::vector<std::vector<NAM_SAMPLE>> outputBuffers(out_channels);
+  std::vector<NAM_SAMPLE*> inputPtrs(in_channels);
+  std::vector<NAM_SAMPLE*> outputPtrs(out_channels);
 
-    std::cout << "Running benchmark\n";
-    auto t1 = high_resolution_clock::now();
-    for (size_t i = 0; i < numBuffers; i++)
-    {
-      model->process(inputPtrs.data(), outputPtrs.data(), AUDIO_BUFFER_SIZE);
-    }
-    auto t2 = high_resolution_clock::now();
-    std::cout << "Finished\n";
+  for (int ch = 0; ch < in_channels; ch++)
+  {
+    inputBuffers[ch].resize(AUDIO_BUFFER_SIZE, 0.0);
+    inputPtrs[ch] = inputBuffers[ch].data();
+  }
+  for (int ch = 0; ch < out_channels; ch++)
+  {
+    outputBuffers[ch].resize(AUDIO_BUFFER_SIZE, 0.0);
+    outputPtrs[ch] = outputBuffers[ch].data();
+  }
 
+  std::cout << "Running benchmark\n";
+  auto t1 = high_resolution_clock::now();
+  for (size_t i = 0; i < numBuffers; i++)
+  {
+    model->process(inputPtrs.data(), outputPtrs.data(), AUDIO_BUFFER_SIZE);
+  }
+  auto t2 = high_resolution_clock::now();
+  std::cout << "Finished\n";
 
-    /* Getting number of milliseconds as an integer. */
-    auto ms_int = duration_cast<milliseconds>(t2 - t1);
+  /* Getting number of milliseconds as an integer. */
+  auto ms_int = duration_cast<milliseconds>(t2 - t1);
 
-    /* Getting number of milliseconds as a double. */
-    duration<double, std::milli> ms_double = t2 - t1;
+  /* Getting number of milliseconds as a double. */
+  duration<double, std::milli> ms_double = t2 - t1;
 
-    std::cout << ms_int.count() << "ms\n";
-    std::cout << ms_double.count() << "ms\n";
-  }
-  else
-  {
-    std::cerr << "Usage: benchmodel <model_path>\n";
-  }
+  std::cout << ms_int.count() << "ms\n";
+  std::cout << ms_double.count() << "ms\n";
 
-  exit(0);
+  return 0;
 }
diff --git a/tools/benchmodel_bufsize.cpp b/tools/benchmodel_bufsize.cpp
new file mode 100644
index 0000000..355b605
--- /dev/null
+++ b/tools/benchmodel_bufsize.cpp
@@ -0,0 +1,112 @@
+#include <chrono>
+#include <cstdlib>
+#include <cstring>
+#include <filesystem>
+#include <iostream>
+
+#include "NAM/dsp.h"
+#include "NAM/get_dsp.h"
+
+using std::chrono::duration;
+using std::chrono::duration_cast;
+using std::chrono::high_resolution_clock;
+using std::chrono::microseconds;
+
+/* A version of benchmodel that accepts an arbitrary buffer size between 1 and 4096.
+ * Useful for testing the effect of smaller/larger buffers on performance.
+ */
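+/* Example sweep (assumes a POSIX shell; "model.nam" is a placeholder path;
+ * output per run is "buffer_size,avg_microseconds"):
+ *   for n in 16 32 64 128 256; do ./build/tools/benchmodel_bufsize model.nam $n; done
+ */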
"test/test_wavenet_configurable_gating.cpp" +#include "test/test_noncontiguous_blocks.cpp" #include "test/test_extensible.cpp" int main() @@ -239,6 +240,20 @@ int main() // Configurable gating/blending tests run_configurable_gating_tests(); + // Non-contiguous block correctness tests (outerStride != rows) + test_noncontiguous_blocks::test_conv1x1_process_toprows(); + test_noncontiguous_blocks::test_conv1x1_process_toprows_with_bias(); + test_noncontiguous_blocks::test_conv1x1_process_toprows_2x2(); + test_noncontiguous_blocks::test_conv1x1_process_toprows_4x4(); + test_noncontiguous_blocks::test_conv1x1_toprows_matches_contiguous(); + test_noncontiguous_blocks::test_film_process_toprows_with_shift(); + test_noncontiguous_blocks::test_film_process_toprows_scale_only(); + test_noncontiguous_blocks::test_film_toprows_matches_contiguous(); + test_noncontiguous_blocks::test_film_process_inplace_toprows(); + test_noncontiguous_blocks::test_gating_output_toprows(); + test_noncontiguous_blocks::test_gating_toprows_matches_contiguous(); + test_noncontiguous_blocks::test_blending_output_toprows(); + test_get_dsp::test_gets_input_level(); test_get_dsp::test_gets_output_level(); test_get_dsp::test_null_input_level(); diff --git a/tools/test/test_noncontiguous_blocks.cpp b/tools/test/test_noncontiguous_blocks.cpp new file mode 100644 index 0000000..7044fd1 --- /dev/null +++ b/tools/test/test_noncontiguous_blocks.cpp @@ -0,0 +1,529 @@ +// Tests for correct handling of non-contiguous Eigen block expressions +// (e.g. topRows()) in inline GEMM paths. +// +// When a matrix expression like matrix.topRows(n) is passed to a function +// that accesses raw pointers via .data(), the outerStride() (distance between +// columns) may be larger than rows(). Code that assumes stride == rows will +// read/write wrong memory locations. + +#include +#include +#include +#include +#include + +#include "NAM/activations.h" +#include "NAM/dsp.h" +#include "NAM/film.h" +#include "NAM/gating_activations.h" + +namespace test_noncontiguous_blocks +{ + +// Helper: create identity Conv1x1 weights (identity matrix + zero bias) +static std::vector make_identity_weights(int channels, bool bias) +{ + std::vector weights(channels * channels + (bias ? channels : 0), 0.0f); + // Column-major identity matrix + for (int i = 0; i < channels; i++) + weights[i * channels + i] = 1.0f; + return weights; +} + +// Helper: create scaled Conv1x1 weights (diagonal scale matrix + zero bias) +static std::vector make_scale_weights(int in_ch, int out_ch, float scale, bool bias) +{ + std::vector weights(out_ch * in_ch + (bias ? out_ch : 0), 0.0f); + // Set all weight elements to scale (for simple testing) + for (int o = 0; o < out_ch; o++) + for (int i = 0; i < in_ch; i++) + weights[i * out_ch + o] = (i == o) ? 
+// ============================================================
+// Conv1x1::process_() with non-contiguous input (topRows)
+// ============================================================
+
+void test_conv1x1_process_toprows()
+{
+  // Create a matrix with 8 rows, but only pass topRows(4) to Conv1x1.
+  // This simulates the gated activation path in wavenet.cpp.
+  const int total_rows = 8;
+  const int bottleneck = 4;
+  const int num_frames = 3;
+
+  // Create Conv1x1: 4 in -> 4 out, identity weights
+  nam::Conv1x1 conv(bottleneck, bottleneck, /*bias=*/false);
+  auto weights = make_identity_weights(bottleneck, false);
+  auto it = weights.begin();
+  conv.set_weights_(it);
+  conv.SetMaxBufferSize(64);
+
+  // Create full matrix with known values
+  Eigen::MatrixXf full_matrix(total_rows, num_frames);
+  for (int c = 0; c < total_rows; c++)
+    for (int f = 0; f < num_frames; f++)
+      full_matrix(c, f) = (float)(c * 10 + f);
+
+  // Process only topRows(bottleneck) - this has outerStride = 8, rows = 4
+  auto top_block = full_matrix.topRows(bottleneck);
+  assert(top_block.outerStride() == total_rows); // Verify non-contiguous
+  assert(top_block.rows() == bottleneck);
+
+  conv.process_(top_block, num_frames);
+  const auto& output = conv.GetOutput();
+
+  // With identity weights, output should equal the top rows exactly
+  for (int c = 0; c < bottleneck; c++)
+  {
+    for (int f = 0; f < num_frames; f++)
+    {
+      const float expected = full_matrix(c, f);
+      const float actual = output(c, f);
+      assert(std::abs(actual - expected) < 1e-6f);
+    }
+  }
+}
+
+void test_conv1x1_process_toprows_with_bias()
+{
+  const int total_rows = 6;
+  const int bottleneck = 3;
+  const int num_frames = 4;
+
+  nam::Conv1x1 conv(bottleneck, bottleneck, /*bias=*/true);
+  auto weights = make_identity_weights(bottleneck, false);
+  // Add bias values
+  weights.push_back(10.0f);
+  weights.push_back(20.0f);
+  weights.push_back(30.0f);
+  auto it = weights.begin();
+  conv.set_weights_(it);
+  conv.SetMaxBufferSize(64);
+
+  Eigen::MatrixXf full_matrix(total_rows, num_frames);
+  for (int c = 0; c < total_rows; c++)
+    for (int f = 0; f < num_frames; f++)
+      full_matrix(c, f) = (float)(c + 1);
+
+  conv.process_(full_matrix.topRows(bottleneck), num_frames);
+  const auto& output = conv.GetOutput();
+
+  const float biases[] = {10.0f, 20.0f, 30.0f};
+  for (int c = 0; c < bottleneck; c++)
+  {
+    for (int f = 0; f < num_frames; f++)
+    {
+      const float expected = full_matrix(c, f) + biases[c];
+      assert(std::abs(output(c, f) - expected) < 1e-6f);
+    }
+  }
+}
+
+void test_conv1x1_process_toprows_2x2()
+{
+  // Test the specialized 2x2 path with non-contiguous input
+  const int total_rows = 4; // doubled for gating
+  const int bottleneck = 2;
+  const int num_frames = 3;
+
+  nam::Conv1x1 conv(bottleneck, bottleneck, /*bias=*/false);
+  // Weight: [[2, 0], [0, 3]] (column-major: [2, 0, 0, 3])
+  std::vector<float> weights = {2.0f, 0.0f, 0.0f, 3.0f};
+  auto it = weights.begin();
+  conv.set_weights_(it);
+  conv.SetMaxBufferSize(64);
+
+  Eigen::MatrixXf full_matrix(total_rows, num_frames);
+  full_matrix << 1.0f, 2.0f, 3.0f, // row 0 (top, used)
+    4.0f, 5.0f, 6.0f, // row 1 (top, used)
+    99.0f, 99.0f, 99.0f, // row 2 (bottom, NOT used)
+    99.0f, 99.0f, 99.0f; // row 3 (bottom, NOT used)
+
+  conv.process_(full_matrix.topRows(bottleneck), num_frames);
+  const auto& output = conv.GetOutput();
+
+  // output = [[2,0],[0,3]] * topRows(2)
+  // Frame 0: [2*1, 3*4] = [2, 12]
+  // Frame 1: [2*2, 3*5] = [4, 15]
+  // Frame 2: [2*3, 3*6] = [6, 18]
+  assert(std::abs(output(0, 0) - 2.0f) < 1e-6f);
+  assert(std::abs(output(1, 0) - 12.0f) < 1e-6f);
+  assert(std::abs(output(0, 1) - 4.0f) < 1e-6f);
+  assert(std::abs(output(1, 1) - 15.0f) < 1e-6f);
+  assert(std::abs(output(0, 2) - 6.0f) < 1e-6f);
+  assert(std::abs(output(1, 2) - 18.0f) < 1e-6f);
+}
+
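+// Column-major refresher for the flat weight vectors used above (a worked
+// example, not extra coverage): Eigen stores a 2x2 matrix column by column,
+// so the vector {2, 0, 0, 3} is read as
+//
+//   col 0 = (2, 0), col 1 = (0, 3)  =>  [[2, 0],
+//                                        [0, 3]]
+//
+// which is also why make_identity_weights() writes weights[i * channels + i].
+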
+void test_conv1x1_process_toprows_4x4()
+{
+  // Test the specialized 4x4 path with non-contiguous input
+  const int total_rows = 8;
+  const int bottleneck = 4;
+  const int num_frames = 2;
+
+  nam::Conv1x1 conv(bottleneck, bottleneck, /*bias=*/false);
+  auto weights = make_identity_weights(bottleneck, false);
+  auto it = weights.begin();
+  conv.set_weights_(it);
+  conv.SetMaxBufferSize(64);
+
+  Eigen::MatrixXf full_matrix(total_rows, num_frames);
+  // Top 4 rows: known values; bottom 4 rows: poison values
+  for (int c = 0; c < bottleneck; c++)
+    for (int f = 0; f < num_frames; f++)
+      full_matrix(c, f) = (float)(c * 10 + f + 1);
+  for (int c = bottleneck; c < total_rows; c++)
+    for (int f = 0; f < num_frames; f++)
+      full_matrix(c, f) = -999.0f; // poison
+
+  conv.process_(full_matrix.topRows(bottleneck), num_frames);
+  const auto& output = conv.GetOutput();
+
+  for (int c = 0; c < bottleneck; c++)
+    for (int f = 0; f < num_frames; f++)
+      assert(std::abs(output(c, f) - full_matrix(c, f)) < 1e-6f);
+}
+
+// Verify Conv1x1 with topRows matches the result from a contiguous copy
+void test_conv1x1_toprows_matches_contiguous()
+{
+  const int total_rows = 8;
+  const int bottleneck = 4;
+  const int num_frames = 5;
+
+  nam::Conv1x1 conv(bottleneck, 2, /*bias=*/true);
+  // Random-ish weights for 4 -> 2 with bias
+  std::vector<float> weights = {
+    1.0f, 0.5f, -1.0f, 0.5f, 2.0f, -0.5f, 0.0f, 1.5f, // 2x4 weights (column-major)
+    3.0f, -2.0f // biases
+  };
+  auto it = weights.begin();
+  conv.set_weights_(it);
+  conv.SetMaxBufferSize(64);
+
+  Eigen::MatrixXf full_matrix(total_rows, num_frames);
+  full_matrix.setRandom();
+
+  // Reference: copy topRows to a contiguous matrix, then process
+  Eigen::MatrixXf contiguous_input = full_matrix.topRows(bottleneck).eval();
+  conv.process_(contiguous_input, num_frames);
+  Eigen::MatrixXf expected = conv.GetOutput().leftCols(num_frames).eval();
+
+  // Test: process the non-contiguous topRows directly
+  conv.process_(full_matrix.topRows(bottleneck), num_frames);
+  const auto& actual = conv.GetOutput();
+
+  for (int c = 0; c < 2; c++)
+    for (int f = 0; f < num_frames; f++)
+      assert(std::abs(actual(c, f) - expected(c, f)) < 1e-5f);
+}
+
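+// The FiLM tests below rely on one closed-form trick: FiLM computes
+//
+//   out(c, f) = in(c, f) * scale(c, f) + shift(c, f)
+//
+// where scale/shift come from a Conv1x1 over the condition (W * cond + b).
+// Setting W = 0 makes scale and shift equal to the biases b, independent of
+// the (random) condition, so expected outputs can be written down directly.
+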
+// ============================================================
+// FiLM::Process() with non-contiguous input (topRows)
+// ============================================================
+
+void test_film_process_toprows_with_shift()
+{
+  const int condition_dim = 2;
+  const int input_dim = 3;
+  const int total_rows = 6; // 2x input_dim, simulating the gated _z matrix
+
+  nam::FiLM film(condition_dim, input_dim, /*shift=*/true);
+  film.SetMaxBufferSize(64);
+
+  // Configure the Conv1x1 with zero weights and fixed biases for scale/shift
+  std::vector<float> weights((2 * input_dim) * condition_dim + (2 * input_dim), 0.0f);
+  const int bias_offset = (2 * input_dim) * condition_dim;
+  weights[bias_offset + 0] = 2.0f; // scale[0]
+  weights[bias_offset + 1] = -1.0f; // scale[1]
+  weights[bias_offset + 2] = 0.5f; // scale[2]
+  weights[bias_offset + 3] = 10.0f; // shift[0]
+  weights[bias_offset + 4] = -5.0f; // shift[1]
+  weights[bias_offset + 5] = 3.0f; // shift[2]
+  auto it = weights.begin();
+  film.set_weights_(it);
+
+  const int num_frames = 4;
+
+  // Create a taller matrix and pass its topRows as input
+  Eigen::MatrixXf full_matrix(total_rows, num_frames);
+  for (int c = 0; c < total_rows; c++)
+    for (int f = 0; f < num_frames; f++)
+      full_matrix(c, f) = (float)(c + 1) * (f + 1);
+
+  Eigen::MatrixXf condition(condition_dim, num_frames);
+  condition.setRandom();
+
+  // Process with non-contiguous topRows(input_dim)
+  auto top_block = full_matrix.topRows(input_dim);
+  assert(top_block.outerStride() == total_rows); // Verify non-contiguous
+
+  film.Process(top_block, condition, num_frames);
+  const auto& output = film.GetOutput();
+
+  const float scales[] = {2.0f, -1.0f, 0.5f};
+  const float shifts[] = {10.0f, -5.0f, 3.0f};
+  for (int c = 0; c < input_dim; c++)
+  {
+    for (int f = 0; f < num_frames; f++)
+    {
+      const float expected = full_matrix(c, f) * scales[c] + shifts[c];
+      assert(std::abs(output(c, f) - expected) < 1e-5f);
+    }
+  }
+}
+
+void test_film_process_toprows_scale_only()
+{
+  const int condition_dim = 2;
+  const int input_dim = 4;
+  const int total_rows = 8;
+
+  nam::FiLM film(condition_dim, input_dim, /*shift=*/false);
+  film.SetMaxBufferSize(64);
+
+  std::vector<float> weights(input_dim * condition_dim + input_dim, 0.0f);
+  const int bias_offset = input_dim * condition_dim;
+  weights[bias_offset + 0] = 2.0f;
+  weights[bias_offset + 1] = 3.0f;
+  weights[bias_offset + 2] = -1.0f;
+  weights[bias_offset + 3] = 0.5f;
+  auto it = weights.begin();
+  film.set_weights_(it);
+
+  const int num_frames = 3;
+  Eigen::MatrixXf full_matrix(total_rows, num_frames);
+  full_matrix.setRandom();
+
+  Eigen::MatrixXf condition(condition_dim, num_frames);
+  condition.setRandom();
+
+  film.Process(full_matrix.topRows(input_dim), condition, num_frames);
+  const auto& output = film.GetOutput();
+
+  const float scales[] = {2.0f, 3.0f, -1.0f, 0.5f};
+  for (int c = 0; c < input_dim; c++)
+    for (int f = 0; f < num_frames; f++)
+    {
+      const float expected = full_matrix(c, f) * scales[c];
+      assert(std::abs(output(c, f) - expected) < 1e-5f);
+    }
+}
+
+// Verify FiLM with topRows matches the result from a contiguous copy
+void test_film_toprows_matches_contiguous()
+{
+  const int condition_dim = 2;
+  const int input_dim = 3;
+  const int total_rows = 6;
+  const int num_frames = 4;
+
+  nam::FiLM film(condition_dim, input_dim, /*shift=*/true);
+  film.SetMaxBufferSize(64);
+
+  std::vector<float> weights((2 * input_dim) * condition_dim + (2 * input_dim), 0.0f);
+  const int bias_offset = (2 * input_dim) * condition_dim;
+  weights[bias_offset + 0] = 2.0f;
+  weights[bias_offset + 1] = -1.0f;
+  weights[bias_offset + 2] = 0.5f;
+  weights[bias_offset + 3] = 10.0f;
+  weights[bias_offset + 4] = -5.0f;
+  weights[bias_offset + 5] = 3.0f;
+  auto it = weights.begin();
+  film.set_weights_(it);
+
+  Eigen::MatrixXf full_matrix(total_rows, num_frames);
+  full_matrix.setRandom();
+  Eigen::MatrixXf condition(condition_dim, num_frames);
+  condition.setRandom();
+
+  // Reference: contiguous copy
+  Eigen::MatrixXf contiguous = full_matrix.topRows(input_dim).eval();
+  film.Process(contiguous, condition, num_frames);
+  Eigen::MatrixXf expected = film.GetOutput().leftCols(num_frames).eval();
+
+  // Test: non-contiguous topRows
+  film.Process(full_matrix.topRows(input_dim), condition, num_frames);
+  const auto& actual = film.GetOutput();
+
+  for (int c = 0; c < input_dim; c++)
+    for (int f = 0; f < num_frames; f++)
+      assert(std::abs(actual(c, f) - expected(c, f)) < 1e-6f);
+}
+
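+// Note on the in-place variant below: Process() writes to FiLM's internal
+// output buffer, while Process_() writes back through the block it was given.
+// For a topRows() block that means honoring the parent's outerStride, e.g.
+//
+//   auto block = full_matrix.topRows(3); // outerStride() == full_matrix.rows()
+//   // block(c, f) lives at data()[f * full_matrix.rows() + c]
+//
+// so the bottom rows of full_matrix must come through untouched (checked below).
+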
+// ============================================================
+// FiLM::Process_() (in-place) with non-contiguous input (topRows)
+// ============================================================
+
+void test_film_process_inplace_toprows()
+{
+  const int condition_dim = 2;
+  const int input_dim = 3;
+  const int total_rows = 6;
+  const int num_frames = 4;
+
+  nam::FiLM film(condition_dim, input_dim, /*shift=*/true);
+  film.SetMaxBufferSize(64);
+
+  std::vector<float> weights((2 * input_dim) * condition_dim + (2 * input_dim), 0.0f);
+  const int bias_offset = (2 * input_dim) * condition_dim;
+  weights[bias_offset + 0] = 2.0f;
+  weights[bias_offset + 1] = -1.0f;
+  weights[bias_offset + 2] = 0.5f;
+  weights[bias_offset + 3] = 10.0f;
+  weights[bias_offset + 4] = -5.0f;
+  weights[bias_offset + 5] = 3.0f;
+  auto it = weights.begin();
+  film.set_weights_(it);
+
+  Eigen::MatrixXf full_matrix(total_rows, num_frames);
+  for (int c = 0; c < total_rows; c++)
+    for (int f = 0; f < num_frames; f++)
+      full_matrix(c, f) = (float)(c + 1) * (f + 1);
+  Eigen::MatrixXf original = full_matrix;
+
+  Eigen::MatrixXf condition(condition_dim, num_frames);
+  condition.setRandom();
+
+  // In-place process on the topRows block
+  film.Process_(full_matrix.topRows(input_dim), condition, num_frames);
+
+  const float scales[] = {2.0f, -1.0f, 0.5f};
+  const float shifts[] = {10.0f, -5.0f, 3.0f};
+  for (int c = 0; c < input_dim; c++)
+    for (int f = 0; f < num_frames; f++)
+    {
+      const float expected = original(c, f) * scales[c] + shifts[c];
+      assert(std::abs(full_matrix(c, f) - expected) < 1e-5f);
+    }
+
+  // Bottom rows should be unchanged
+  for (int c = input_dim; c < total_rows; c++)
+    for (int f = 0; f < num_frames; f++)
+      assert(std::abs(full_matrix(c, f) - original(c, f)) < 1e-6f);
+}
+
+// ============================================================
+// GatingActivation with non-contiguous output (topRows)
+// ============================================================
+
+// Helper to create a non-owning shared_ptr for stack-allocated activations in tests
+template <typename T>
+static nam::activations::Activation::Ptr make_test_ptr(T& activation)
+{
+  return nam::activations::Activation::Ptr(&activation, [](nam::activations::Activation*) {});
+}
+
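+// Why the no-op deleter (a usage sketch, not part of the API under test):
+//
+//   nam::activations::ActivationSigmoid sigmoid_act; // automatic storage
+//   auto ptr = make_test_ptr(sigmoid_act);           // shared_ptr; deleter does nothing
+//
+// GatingActivation expects an Activation::Ptr (a shared_ptr); wrapping a stack
+// object with the default deleter would call delete on it at destruction.
+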
+void test_gating_output_toprows()
+{
+  // Simulate the wavenet.cpp pattern:
+  //   input_block = _z.leftCols(num_frames) (contiguous, 2*bottleneck rows)
+  //   output_block = _z.topRows(bottleneck).leftCols(num_frames) (non-contiguous output)
+  const int bottleneck = 3;
+  const int total_rows = 2 * bottleneck;
+  const int num_frames = 4;
+
+  nam::activations::ActivationIdentity identity_act;
+  nam::activations::ActivationSigmoid sigmoid_act;
+  nam::gating_activations::GatingActivation gating(make_test_ptr(identity_act), make_test_ptr(sigmoid_act), bottleneck);
+
+  // Input: contiguous (2*bottleneck x num_frames)
+  Eigen::MatrixXf input(total_rows, num_frames);
+  for (int c = 0; c < total_rows; c++)
+    for (int f = 0; f < num_frames; f++)
+      input(c, f) = (float)(c + 1) * 0.5f;
+
+  // Output goes into topRows of a larger matrix
+  Eigen::MatrixXf output_matrix(total_rows, num_frames);
+  output_matrix.setConstant(-999.0f); // poison
+  auto output_block = output_matrix.topRows(bottleneck).leftCols(num_frames);
+
+  gating.apply(input, output_block);
+
+  // Verify output values are correct (not reading poison from wrong stride)
+  for (int f = 0; f < num_frames; f++)
+  {
+    for (int c = 0; c < bottleneck; c++)
+    {
+      const float input_val = input(c, f); // identity activation
+      const float gate_val = 1.0f / (1.0f + expf(-input(c + bottleneck, f))); // sigmoid
+      const float expected = input_val * gate_val;
+      assert(std::abs(output_matrix(c, f) - expected) < 1e-5f);
+    }
+  }
+
+  // Bottom rows should still be poison (untouched)
+  for (int c = bottleneck; c < total_rows; c++)
+    for (int f = 0; f < num_frames; f++)
+      assert(std::abs(output_matrix(c, f) - (-999.0f)) < 1e-6f);
+}
+
+void test_gating_toprows_matches_contiguous()
+{
+  const int bottleneck = 2;
+  const int total_rows = 2 * bottleneck;
+  const int num_frames = 3;
+
+  nam::activations::ActivationReLU relu_act;
+  nam::activations::ActivationSigmoid sigmoid_act;
+  nam::gating_activations::GatingActivation gating(make_test_ptr(relu_act), make_test_ptr(sigmoid_act), bottleneck);
+
+  Eigen::MatrixXf input(total_rows, num_frames);
+  input.setRandom();
+
+  // Reference: contiguous output
+  Eigen::MatrixXf contiguous_output(bottleneck, num_frames);
+  gating.apply(input, contiguous_output);
+
+  // Test: non-contiguous output (topRows of larger matrix)
+  Eigen::MatrixXf full_output(total_rows, num_frames);
+  full_output.setZero();
+  auto output_block = full_output.topRows(bottleneck).leftCols(num_frames);
+  gating.apply(input, output_block);
+
+  for (int c = 0; c < bottleneck; c++)
+    for (int f = 0; f < num_frames; f++)
+      assert(std::abs(full_output(c, f) - contiguous_output(c, f)) < 1e-6f);
+}
+
+// ============================================================
+// BlendingActivation with non-contiguous output (topRows)
+// ============================================================
+
+void test_blending_output_toprows()
+{
+  const int bottleneck = 2;
+  const int total_rows = 2 * bottleneck;
+  const int num_frames = 3;
+
+  nam::activations::ActivationIdentity identity_act;
+  nam::activations::ActivationSigmoid sigmoid_act;
+  nam::gating_activations::BlendingActivation blending(
+    make_test_ptr(identity_act), make_test_ptr(sigmoid_act), bottleneck);
+
+  Eigen::MatrixXf input(total_rows, num_frames);
+  input.setRandom();
+
+  // Reference: contiguous output
+  Eigen::MatrixXf contiguous_output(bottleneck, num_frames);
+  blending.apply(input, contiguous_output);
+
+  // Test: non-contiguous output (topRows of larger matrix)
+  Eigen::MatrixXf full_output(total_rows, num_frames);
+  full_output.setConstant(-999.0f);
+  auto output_block = full_output.topRows(bottleneck).leftCols(num_frames);
+  blending.apply(input, output_block);
+
+  for (int c = 0; c < bottleneck; c++)
+    for (int f = 0; f < num_frames; f++)
+      assert(std::abs(full_output(c, f) - contiguous_output(c, f)) < 1e-5f);
+
+  // Bottom rows should be untouched
+  for (int c = bottleneck; c < total_rows; c++)
+    for (int f = 0; f < num_frames; f++)
+      assert(std::abs(full_output(c, f) - (-999.0f)) < 1e-6f);
+}
+
+} // namespace test_noncontiguous_blocks