From 7cceadb19ddad4f0d29e014f0267f83aa4e3ab4a Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Tue, 23 Jun 2026 14:35:34 -0700 Subject: [PATCH] Optimize additional small convolution paths --- NAM/conv1d.cpp | 42 ++++++++++++++- NAM/dsp.cpp | 70 ++++++++++++++++++++++--- tools/run_tests.cpp | 5 ++ tools/test/test_conv1d.cpp | 99 ++++++++++++++++++++++++++++++++++++ tools/test/test_conv_1x1.cpp | 79 ++++++++++++++++++++++++++++ 5 files changed, 285 insertions(+), 10 deletions(-) diff --git a/NAM/conv1d.cpp b/NAM/conv1d.cpp index b561786c..d6f1a84c 100644 --- a/NAM/conv1d.cpp +++ b/NAM/conv1d.cpp @@ -264,7 +264,6 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames) const int out_ch = (int)get_out_channels(); const int in_ch = (int)get_in_channels(); const size_t kernel_size = this->_weight.size(); - const size_t weight_matrix_size = out_ch * in_ch; // Fused kernel optimization for kernel_size=3 // Instead of 3 separate passes over output, fuse into single pass @@ -282,7 +281,6 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames) float* __restrict__ output_ptr = _output.data(); // Get weight pointers for all 3 taps - const size_t wsize = 16; // 4x4 const float* __restrict__ w0 = this->_weight[0].data(); const float* __restrict__ w1 = this->_weight[1].data(); const float* __restrict__ w2 = this->_weight[2].data(); @@ -600,6 +598,46 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames) } } } + else if (out_ch == 8 && in_ch == 4) + { + // 8x4 fully unrolled + const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3]; + const float w40 = weight_ptr[4], w50 = weight_ptr[5], w60 = weight_ptr[6], w70 = weight_ptr[7]; + const float w01 = weight_ptr[8], w11 = weight_ptr[9], w21 = weight_ptr[10], w31 = weight_ptr[11]; + const float w41 = weight_ptr[12], w51 = weight_ptr[13], w61 = weight_ptr[14], w71 = weight_ptr[15]; + const float w02 = weight_ptr[16], w12 = 
weight_ptr[17], w22 = weight_ptr[18], w32 = weight_ptr[19]; + const float w42 = weight_ptr[20], w52 = weight_ptr[21], w62 = weight_ptr[22], w72 = weight_ptr[23]; + const float w03 = weight_ptr[24], w13 = weight_ptr[25], w23 = weight_ptr[26], w33 = weight_ptr[27]; + const float w43 = weight_ptr[28], w53 = weight_ptr[29], w63 = weight_ptr[30], w73 = weight_ptr[31]; + for (int f = 0; f < num_frames; f++) + { + const int in_off = f * 4; + const int out_off = f * 8; + const float i0 = input_ptr[in_off]; + const float i1 = input_ptr[in_off + 1]; + const float i2 = input_ptr[in_off + 2]; + const float i3 = input_ptr[in_off + 3]; + output_ptr[out_off] += w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3; + output_ptr[out_off + 1] += w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3; + output_ptr[out_off + 2] += w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3; + output_ptr[out_off + 3] += w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3; + output_ptr[out_off + 4] += w40 * i0 + w41 * i1 + w42 * i2 + w43 * i3; + output_ptr[out_off + 5] += w50 * i0 + w51 * i1 + w52 * i2 + w53 * i3; + output_ptr[out_off + 6] += w60 * i0 + w61 * i1 + w62 * i2 + w63 * i3; + output_ptr[out_off + 7] += w70 * i0 + w71 * i1 + w72 * i2 + w73 * i3; + } + } + else if (out_ch == 1 && in_ch == 4) + { + // 1x4 fully unrolled + const float w0 = weight_ptr[0], w1 = weight_ptr[1], w2 = weight_ptr[2], w3 = weight_ptr[3]; + for (int f = 0; f < num_frames; f++) + { + const int in_off = f * 4; + output_ptr[f] += w0 * input_ptr[in_off] + w1 * input_ptr[in_off + 1] + w2 * input_ptr[in_off + 2] + + w3 * input_ptr[in_off + 3]; + } + } else { // Fall back to Eigen for larger matrices where it's more efficient diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index 0aea735d..54e674a8 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -622,17 +622,56 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7]; const float w02 = weight_ptr[8], w12 = weight_ptr[9], w22 = 
weight_ptr[10], w32 = weight_ptr[11]; const float w03 = weight_ptr[12], w13 = weight_ptr[13], w23 = weight_ptr[14], w33 = weight_ptr[15]; + if (this->_do_bias) + { + const float b0 = this->_bias(0), b1 = this->_bias(1), b2 = this->_bias(2), b3 = this->_bias(3); + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = input_ptr + f * in_stride; + const float i0 = in_col[0]; + const float i1 = in_col[1]; + const float i2 = in_col[2]; + const float i3 = in_col[3]; + output_ptr[f * 4] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3 + b0; + output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3 + b1; + output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3 + b2; + output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3 + b3; + } + bias_fused = true; + } + else + { + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = input_ptr + f * in_stride; + const float i0 = in_col[0]; + const float i1 = in_col[1]; + const float i2 = in_col[2]; + const float i3 = in_col[3]; + output_ptr[f * 4] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3; + output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3; + output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3; + output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3; + } + } + } + else if (out_ch == 4 && in_ch == 6) + { + const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3]; + const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7]; + const float w02 = weight_ptr[8], w12 = weight_ptr[9], w22 = weight_ptr[10], w32 = weight_ptr[11]; + const float w03 = weight_ptr[12], w13 = weight_ptr[13], w23 = weight_ptr[14], w33 = weight_ptr[15]; + const float w04 = weight_ptr[16], w14 = weight_ptr[17], w24 = weight_ptr[18], w34 = weight_ptr[19]; + const float w05 = weight_ptr[20], w15 = weight_ptr[21], w25 = weight_ptr[22], w35 = weight_ptr[23]; for (int f = 0; f < 
num_frames; f++) { const float* __restrict__ in_col = input_ptr + f * in_stride; - const float i0 = in_col[0]; - const float i1 = in_col[1]; - const float i2 = in_col[2]; - const float i3 = in_col[3]; - output_ptr[f * 4] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3; - output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3; - output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3; - output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3; + const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2]; + const float i3 = in_col[3], i4 = in_col[4], i5 = in_col[5]; + output_ptr[f * 4] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3 + w04 * i4 + w05 * i5; + output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3 + w14 * i4 + w15 * i5; + output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3 + w24 * i4 + w25 * i5; + output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3 + w34 * i4 + w35 * i5; } } else if (out_ch == 6 && in_ch == 6) @@ -650,6 +689,21 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons } } } + else if (out_ch == 8 && in_ch == 6) + { + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = input_ptr + f * in_stride; + float* __restrict__ out_col = output_ptr + f * 8; + const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2]; + const float i3 = in_col[3], i4 = in_col[4], i5 = in_col[5]; + for (int o = 0; o < 8; o++) + { + out_col[o] = weight_ptr[o] * i0 + weight_ptr[8 + o] * i1 + weight_ptr[16 + o] * i2 + + weight_ptr[24 + o] * i3 + weight_ptr[32 + o] * i4 + weight_ptr[40 + o] * i5; + } + } + } else if (out_ch == 8 && in_ch == 8) { for (int f = 0; f < num_frames; f++) diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index fab71f8e..d74bc5f3 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -129,6 +129,8 @@ int main() test_conv1d::test_process_grouped_dilation(); test_conv1d::test_process_grouped_channel_isolation(); 
test_conv1d::test_get_num_weights_grouped(); + test_conv1d::test_process_8x4_kernel6_matches_reference(); + test_conv1d::test_process_1x4_kernel16_matches_reference(); test_conv_1x1::test_construct(); test_conv_1x1::test_construct_with_groups(); @@ -144,6 +146,9 @@ int main() test_conv_1x1::test_process_underscore_grouped(); test_conv_1x1::test_set_max_buffer_size(); test_conv_1x1::test_process_multiple_calls(); + test_conv_1x1::test_process_underscore_4x6_matches_reference(); + test_conv_1x1::test_process_underscore_8x6_matches_reference(); + test_conv_1x1::test_process_underscore_4x4_with_bias_matches_reference(); test_film::test_set_max_buffer_size(); test_film::test_process_bias_only(); diff --git a/tools/test/test_conv1d.cpp b/tools/test/test_conv1d.cpp index 900eea0c..cc217403 100644 --- a/tools/test/test_conv1d.cpp +++ b/tools/test/test_conv1d.cpp @@ -10,6 +10,93 @@ namespace test_conv1d { +void assert_close(const float actual, const float expected) +{ + assert(std::abs(actual - expected) < 1.0e-4f); +} + +void test_process_matches_reference(const int in_channels, const int out_channels, const int kernel_size, + const bool do_bias, const int dilation, const int num_frames) +{ + nam::Conv1D conv; + conv.set_size_(in_channels, out_channels, kernel_size, do_bias, dilation); + + std::vector<Eigen::MatrixXf> reference_weights; + reference_weights.reserve(kernel_size); + for (int k = 0; k < kernel_size; k++) + reference_weights.emplace_back(out_channels, in_channels); + Eigen::VectorXf reference_bias(out_channels); + + std::vector<float> weights; + weights.reserve(out_channels * in_channels * kernel_size + (do_bias ? 
out_channels : 0)); + for (int o = 0; o < out_channels; o++) + { + for (int i = 0; i < in_channels; i++) + { + for (int k = 0; k < kernel_size; k++) + { + const float value = 0.011f * static_cast<float>(o + 1) + 0.007f * static_cast<float>(i + 1) + - 0.003f * static_cast<float>(k + 1); + reference_weights[k](o, i) = value; + weights.push_back(value); + } + } + } + for (int o = 0; o < out_channels; o++) + { + const float value = -0.05f + 0.019f * static_cast<float>(o + 1); + reference_bias(o) = value; + if (do_bias) + weights.push_back(value); + } + + auto it = weights.begin(); + conv.set_weights_(it); + conv.SetMaxBufferSize(64); + + Eigen::MatrixXf input(in_channels, num_frames); + for (int f = 0; f < num_frames; f++) + { + for (int i = 0; i < in_channels; i++) + { + input(i, f) = 0.21f * static_cast<float>(i + 1) - 0.037f * static_cast<float>(f + 1) + + 0.004f * static_cast<float>((i + 1) * (f + 1)); + } + } + + Eigen::MatrixXf expected(out_channels, num_frames); + expected.setZero(); + for (int f = 0; f < num_frames; f++) + { + for (int o = 0; o < out_channels; o++) + { + float sum = do_bias ? 
reference_bias(o) : 0.0f; + for (int k = 0; k < kernel_size; k++) + { + const int source_frame = f - dilation * (kernel_size - 1 - k); + if (source_frame < 0) + continue; + for (int i = 0; i < in_channels; i++) + sum += reference_weights[k](o, i) * input(i, source_frame); + } + expected(o, f) = sum; + } + } + + conv.Process(input, num_frames); + auto output = conv.GetOutput().leftCols(num_frames); + + assert(output.rows() == out_channels); + assert(output.cols() == num_frames); + for (int f = 0; f < num_frames; f++) + { + for (int o = 0; o < out_channels; o++) + { + assert_close(output(o, f), expected(o, f)); + } + } +} + // Test basic construction void test_construct() { @@ -848,4 +935,16 @@ void test_get_num_weights_grouped() actual = conv_4groups.get_num_weights(); assert(actual == expected); } + +void test_process_8x4_kernel6_matches_reference() +{ + test_process_matches_reference( + /*in_channels=*/4, /*out_channels=*/8, /*kernel_size=*/6, /*do_bias=*/true, /*dilation=*/3, /*num_frames=*/23); +} + +void test_process_1x4_kernel16_matches_reference() +{ + test_process_matches_reference( + /*in_channels=*/4, /*out_channels=*/1, /*kernel_size=*/16, /*do_bias=*/true, /*dilation=*/1, /*num_frames=*/23); +} }; // namespace test_conv1d diff --git a/tools/test/test_conv_1x1.cpp b/tools/test/test_conv_1x1.cpp index cb3e2348..a46bc307 100644 --- a/tools/test/test_conv_1x1.cpp +++ b/tools/test/test_conv_1x1.cpp @@ -11,6 +11,70 @@ namespace test_conv_1x1 { +void assert_close(const float actual, const float expected) +{ + assert(std::abs(actual - expected) < 1.0e-5f); +} + +void test_process_underscore_matches_reference(const int in_channels, const int out_channels, const bool do_bias) +{ + nam::Conv1x1 conv(in_channels, out_channels, do_bias); + const int num_frames = 5; + + Eigen::MatrixXf reference_weight(out_channels, in_channels); + Eigen::VectorXf reference_bias(out_channels); + std::vector<float> weights; + weights.reserve(out_channels * in_channels + (do_bias ? 
out_channels : 0)); + + for (int o = 0; o < out_channels; o++) + { + for (int i = 0; i < in_channels; i++) + { + const float value = 0.17f * static_cast<float>(o + 1) - 0.031f * static_cast<float>(i + 1); + reference_weight(o, i) = value; + weights.push_back(value); + } + } + for (int o = 0; o < out_channels; o++) + { + const float value = -0.09f + 0.023f * static_cast<float>(o + 1); + reference_bias(o) = value; + if (do_bias) + weights.push_back(value); + } + + auto it = weights.begin(); + conv.set_weights_(it); + conv.SetMaxBufferSize(64); + + Eigen::MatrixXf input(in_channels, num_frames); + for (int f = 0; f < num_frames; f++) + { + for (int i = 0; i < in_channels; i++) + { + input(i, f) = 0.25f * static_cast<float>(i + 1) - 0.11f * static_cast<float>(f + 1) + + 0.013f * static_cast<float>((i + 1) * (f + 1)); + } + } + + Eigen::MatrixXf expected = reference_weight * input; + if (do_bias) + expected.colwise() += reference_bias; + + conv.process_(input, num_frames); + auto output = conv.GetOutput().leftCols(num_frames); + + assert(output.rows() == out_channels); + assert(output.cols() == num_frames); + for (int f = 0; f < num_frames; f++) + { + for (int o = 0; o < out_channels; o++) + { + assert_close(output(o, f), expected(o, f)); + } + } +} + // Test basic construction void test_construct() { @@ -492,4 +556,19 @@ void test_process_multiple_calls() assert(std::abs(output2(0, 0) - 3.0f) < 0.01f); assert(std::abs(output2(1, 0) - 4.0f) < 0.01f); } + +void test_process_underscore_4x6_matches_reference() +{ + test_process_underscore_matches_reference(/*in_channels=*/6, /*out_channels=*/4, /*do_bias=*/false); +} + +void test_process_underscore_8x6_matches_reference() +{ + test_process_underscore_matches_reference(/*in_channels=*/6, /*out_channels=*/8, /*do_bias=*/false); +} + +void test_process_underscore_4x4_with_bias_matches_reference() +{ + test_process_underscore_matches_reference(/*in_channels=*/4, /*out_channels=*/4, /*do_bias=*/true); +} } // namespace test_conv_1x1