Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions NAM/conv1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
const int out_ch = (int)get_out_channels();
const int in_ch = (int)get_in_channels();
const size_t kernel_size = this->_weight.size();
const size_t weight_matrix_size = out_ch * in_ch;

// Fused kernel optimization for kernel_size=3
// Instead of 3 separate passes over output, fuse into single pass
Expand All @@ -282,7 +281,6 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
float* __restrict__ output_ptr = _output.data();

// Get weight pointers for all 3 taps
const size_t wsize = 16; // 4x4
const float* __restrict__ w0 = this->_weight[0].data();
const float* __restrict__ w1 = this->_weight[1].data();
const float* __restrict__ w2 = this->_weight[2].data();
Expand Down Expand Up @@ -600,6 +598,46 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
}
}
}
else if (out_ch == 8 && in_ch == 4)
{
// 8x4 fully unrolled
const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3];
const float w40 = weight_ptr[4], w50 = weight_ptr[5], w60 = weight_ptr[6], w70 = weight_ptr[7];
const float w01 = weight_ptr[8], w11 = weight_ptr[9], w21 = weight_ptr[10], w31 = weight_ptr[11];
const float w41 = weight_ptr[12], w51 = weight_ptr[13], w61 = weight_ptr[14], w71 = weight_ptr[15];
const float w02 = weight_ptr[16], w12 = weight_ptr[17], w22 = weight_ptr[18], w32 = weight_ptr[19];
const float w42 = weight_ptr[20], w52 = weight_ptr[21], w62 = weight_ptr[22], w72 = weight_ptr[23];
const float w03 = weight_ptr[24], w13 = weight_ptr[25], w23 = weight_ptr[26], w33 = weight_ptr[27];
const float w43 = weight_ptr[28], w53 = weight_ptr[29], w63 = weight_ptr[30], w73 = weight_ptr[31];
for (int f = 0; f < num_frames; f++)
{
const int in_off = f * 4;
const int out_off = f * 8;
const float i0 = input_ptr[in_off];
const float i1 = input_ptr[in_off + 1];
const float i2 = input_ptr[in_off + 2];
const float i3 = input_ptr[in_off + 3];
output_ptr[out_off] += w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3;
output_ptr[out_off + 1] += w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3;
output_ptr[out_off + 2] += w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3;
output_ptr[out_off + 3] += w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3;
output_ptr[out_off + 4] += w40 * i0 + w41 * i1 + w42 * i2 + w43 * i3;
output_ptr[out_off + 5] += w50 * i0 + w51 * i1 + w52 * i2 + w53 * i3;
output_ptr[out_off + 6] += w60 * i0 + w61 * i1 + w62 * i2 + w63 * i3;
output_ptr[out_off + 7] += w70 * i0 + w71 * i1 + w72 * i2 + w73 * i3;
}
}
else if (out_ch == 1 && in_ch == 4)
{
// 1x4 fully unrolled
const float w0 = weight_ptr[0], w1 = weight_ptr[1], w2 = weight_ptr[2], w3 = weight_ptr[3];
for (int f = 0; f < num_frames; f++)
{
const int in_off = f * 4;
output_ptr[f] += w0 * input_ptr[in_off] + w1 * input_ptr[in_off + 1] + w2 * input_ptr[in_off + 2]
+ w3 * input_ptr[in_off + 3];
}
}
else
{
// Fall back to Eigen for larger matrices where it's more efficient
Expand Down
70 changes: 62 additions & 8 deletions NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -622,17 +622,56 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7];
const float w02 = weight_ptr[8], w12 = weight_ptr[9], w22 = weight_ptr[10], w32 = weight_ptr[11];
const float w03 = weight_ptr[12], w13 = weight_ptr[13], w23 = weight_ptr[14], w33 = weight_ptr[15];
if (this->_do_bias)
{
const float b0 = this->_bias(0), b1 = this->_bias(1), b2 = this->_bias(2), b3 = this->_bias(3);
for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = input_ptr + f * in_stride;
const float i0 = in_col[0];
const float i1 = in_col[1];
const float i2 = in_col[2];
const float i3 = in_col[3];
output_ptr[f * 4] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3 + b0;
output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3 + b1;
output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3 + b2;
output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3 + b3;
}
bias_fused = true;
}
else
{
for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = input_ptr + f * in_stride;
const float i0 = in_col[0];
const float i1 = in_col[1];
const float i2 = in_col[2];
const float i3 = in_col[3];
output_ptr[f * 4] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3;
output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3;
output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3;
output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3;
}
}
}
else if (out_ch == 4 && in_ch == 6)
{
const float w00 = weight_ptr[0], w10 = weight_ptr[1], w20 = weight_ptr[2], w30 = weight_ptr[3];
const float w01 = weight_ptr[4], w11 = weight_ptr[5], w21 = weight_ptr[6], w31 = weight_ptr[7];
const float w02 = weight_ptr[8], w12 = weight_ptr[9], w22 = weight_ptr[10], w32 = weight_ptr[11];
const float w03 = weight_ptr[12], w13 = weight_ptr[13], w23 = weight_ptr[14], w33 = weight_ptr[15];
const float w04 = weight_ptr[16], w14 = weight_ptr[17], w24 = weight_ptr[18], w34 = weight_ptr[19];
const float w05 = weight_ptr[20], w15 = weight_ptr[21], w25 = weight_ptr[22], w35 = weight_ptr[23];
for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = input_ptr + f * in_stride;
const float i0 = in_col[0];
const float i1 = in_col[1];
const float i2 = in_col[2];
const float i3 = in_col[3];
output_ptr[f * 4] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3;
output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3;
output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3;
output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3;
const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2];
const float i3 = in_col[3], i4 = in_col[4], i5 = in_col[5];
output_ptr[f * 4] = w00 * i0 + w01 * i1 + w02 * i2 + w03 * i3 + w04 * i4 + w05 * i5;
output_ptr[f * 4 + 1] = w10 * i0 + w11 * i1 + w12 * i2 + w13 * i3 + w14 * i4 + w15 * i5;
output_ptr[f * 4 + 2] = w20 * i0 + w21 * i1 + w22 * i2 + w23 * i3 + w24 * i4 + w25 * i5;
output_ptr[f * 4 + 3] = w30 * i0 + w31 * i1 + w32 * i2 + w33 * i3 + w34 * i4 + w35 * i5;
}
}
else if (out_ch == 6 && in_ch == 6)
Expand All @@ -650,6 +689,21 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
}
}
}
else if (out_ch == 8 && in_ch == 6)
{
for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = input_ptr + f * in_stride;
float* __restrict__ out_col = output_ptr + f * 8;
const float i0 = in_col[0], i1 = in_col[1], i2 = in_col[2];
const float i3 = in_col[3], i4 = in_col[4], i5 = in_col[5];
for (int o = 0; o < 8; o++)
{
out_col[o] = weight_ptr[o] * i0 + weight_ptr[8 + o] * i1 + weight_ptr[16 + o] * i2
+ weight_ptr[24 + o] * i3 + weight_ptr[32 + o] * i4 + weight_ptr[40 + o] * i5;
}
}
}
else if (out_ch == 8 && in_ch == 8)
{
for (int f = 0; f < num_frames; f++)
Expand Down
5 changes: 5 additions & 0 deletions tools/run_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ int main()
test_conv1d::test_process_grouped_dilation();
test_conv1d::test_process_grouped_channel_isolation();
test_conv1d::test_get_num_weights_grouped();
test_conv1d::test_process_8x4_kernel6_matches_reference();
test_conv1d::test_process_1x4_kernel16_matches_reference();

test_conv_1x1::test_construct();
test_conv_1x1::test_construct_with_groups();
Expand All @@ -144,6 +146,9 @@ int main()
test_conv_1x1::test_process_underscore_grouped();
test_conv_1x1::test_set_max_buffer_size();
test_conv_1x1::test_process_multiple_calls();
test_conv_1x1::test_process_underscore_4x6_matches_reference();
test_conv_1x1::test_process_underscore_8x6_matches_reference();
test_conv_1x1::test_process_underscore_4x4_with_bias_matches_reference();

test_film::test_set_max_buffer_size();
test_film::test_process_bias_only();
Expand Down
99 changes: 99 additions & 0 deletions tools/test/test_conv1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,93 @@

namespace test_conv1d
{
void assert_close(const float actual, const float expected)
{
assert(std::abs(actual - expected) < 1.0e-4f);
}

void test_process_matches_reference(const int in_channels, const int out_channels, const int kernel_size,
const bool do_bias, const int dilation, const int num_frames)
{
nam::Conv1D conv;
conv.set_size_(in_channels, out_channels, kernel_size, do_bias, dilation);

std::vector<Eigen::MatrixXf> reference_weights;
reference_weights.reserve(kernel_size);
for (int k = 0; k < kernel_size; k++)
reference_weights.emplace_back(out_channels, in_channels);
Eigen::VectorXf reference_bias(out_channels);

std::vector<float> weights;
weights.reserve(out_channels * in_channels * kernel_size + (do_bias ? out_channels : 0));
for (int o = 0; o < out_channels; o++)
{
for (int i = 0; i < in_channels; i++)
{
for (int k = 0; k < kernel_size; k++)
{
const float value = 0.011f * static_cast<float>(o + 1) + 0.007f * static_cast<float>(i + 1)
- 0.003f * static_cast<float>(k + 1);
reference_weights[k](o, i) = value;
weights.push_back(value);
}
}
}
for (int o = 0; o < out_channels; o++)
{
const float value = -0.05f + 0.019f * static_cast<float>(o + 1);
reference_bias(o) = value;
if (do_bias)
weights.push_back(value);
}

auto it = weights.begin();
conv.set_weights_(it);
conv.SetMaxBufferSize(64);

Eigen::MatrixXf input(in_channels, num_frames);
for (int f = 0; f < num_frames; f++)
{
for (int i = 0; i < in_channels; i++)
{
input(i, f) = 0.21f * static_cast<float>(i + 1) - 0.037f * static_cast<float>(f + 1)
+ 0.004f * static_cast<float>((i + 1) * (f + 1));
}
}

Eigen::MatrixXf expected(out_channels, num_frames);
expected.setZero();
for (int f = 0; f < num_frames; f++)
{
for (int o = 0; o < out_channels; o++)
{
float sum = do_bias ? reference_bias(o) : 0.0f;
for (int k = 0; k < kernel_size; k++)
{
const int source_frame = f - dilation * (kernel_size - 1 - k);
if (source_frame < 0)
continue;
for (int i = 0; i < in_channels; i++)
sum += reference_weights[k](o, i) * input(i, source_frame);
}
expected(o, f) = sum;
}
}

conv.Process(input, num_frames);
auto output = conv.GetOutput().leftCols(num_frames);

assert(output.rows() == out_channels);
assert(output.cols() == num_frames);
for (int f = 0; f < num_frames; f++)
{
for (int o = 0; o < out_channels; o++)
{
assert_close(output(o, f), expected(o, f));
}
}
}

// Test basic construction
void test_construct()
{
Expand Down Expand Up @@ -848,4 +935,16 @@ void test_get_num_weights_grouped()
actual = conv_4groups.get_num_weights();
assert(actual == expected);
}

void test_process_8x4_kernel6_matches_reference()
{
test_process_matches_reference(
/*in_channels=*/4, /*out_channels=*/8, /*kernel_size=*/6, /*do_bias=*/true, /*dilation=*/3, /*num_frames=*/23);
}

void test_process_1x4_kernel16_matches_reference()
{
test_process_matches_reference(
/*in_channels=*/4, /*out_channels=*/1, /*kernel_size=*/16, /*do_bias=*/true, /*dilation=*/1, /*num_frames=*/23);
}
}; // namespace test_conv1d
79 changes: 79 additions & 0 deletions tools/test/test_conv_1x1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,70 @@

namespace test_conv_1x1
{
void assert_close(const float actual, const float expected)
{
assert(std::abs(actual - expected) < 1.0e-5f);
}

void test_process_underscore_matches_reference(const int in_channels, const int out_channels, const bool do_bias)
{
nam::Conv1x1 conv(in_channels, out_channels, do_bias);
const int num_frames = 5;

Eigen::MatrixXf reference_weight(out_channels, in_channels);
Eigen::VectorXf reference_bias(out_channels);
std::vector<float> weights;
weights.reserve(out_channels * in_channels + (do_bias ? out_channels : 0));

for (int o = 0; o < out_channels; o++)
{
for (int i = 0; i < in_channels; i++)
{
const float value = 0.17f * static_cast<float>(o + 1) - 0.031f * static_cast<float>(i + 1);
reference_weight(o, i) = value;
weights.push_back(value);
}
}
for (int o = 0; o < out_channels; o++)
{
const float value = -0.09f + 0.023f * static_cast<float>(o + 1);
reference_bias(o) = value;
if (do_bias)
weights.push_back(value);
}

auto it = weights.begin();
conv.set_weights_(it);
conv.SetMaxBufferSize(64);

Eigen::MatrixXf input(in_channels, num_frames);
for (int f = 0; f < num_frames; f++)
{
for (int i = 0; i < in_channels; i++)
{
input(i, f) = 0.25f * static_cast<float>(i + 1) - 0.11f * static_cast<float>(f + 1)
+ 0.013f * static_cast<float>((i + 1) * (f + 1));
}
}

Eigen::MatrixXf expected = reference_weight * input;
if (do_bias)
expected.colwise() += reference_bias;

conv.process_(input, num_frames);
auto output = conv.GetOutput().leftCols(num_frames);

assert(output.rows() == out_channels);
assert(output.cols() == num_frames);
for (int f = 0; f < num_frames; f++)
{
for (int o = 0; o < out_channels; o++)
{
assert_close(output(o, f), expected(o, f));
}
}
}

// Test basic construction
void test_construct()
{
Expand Down Expand Up @@ -492,4 +556,19 @@ void test_process_multiple_calls()
assert(std::abs(output2(0, 0) - 3.0f) < 0.01f);
assert(std::abs(output2(1, 0) - 4.0f) < 0.01f);
}

void test_process_underscore_4x6_matches_reference()
{
test_process_underscore_matches_reference(/*in_channels=*/6, /*out_channels=*/4, /*do_bias=*/false);
}

void test_process_underscore_8x6_matches_reference()
{
test_process_underscore_matches_reference(/*in_channels=*/6, /*out_channels=*/8, /*do_bias=*/false);
}

void test_process_underscore_4x4_with_bias_matches_reference()
{
test_process_underscore_matches_reference(/*in_channels=*/4, /*out_channels=*/4, /*do_bias=*/true);
}
} // namespace test_conv_1x1
Loading