Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions NAM/convnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,25 +322,40 @@ void nam::convnet::ConvNet::_rewind_buffers_()
this->Buffer::_rewind_buffers_();
}

// Factory
std::unique_ptr<nam::DSP> nam::convnet::Factory(const nlohmann::json& config, std::vector<float>& weights,
const double expectedSampleRate)
// Config parser
nam::convnet::ConvNetConfig nam::convnet::parse_config_json(const nlohmann::json& config)
{
const int channels = config["channels"];
const std::vector<int> dilations = config["dilations"];
const bool batchnorm = config["batchnorm"];
ConvNetConfig c;
c.channels = config["channels"];
c.dilations = config["dilations"].get<std::vector<int>>();
c.batchnorm = config["batchnorm"];
// Parse JSON into typed ActivationConfig at model loading boundary
const activations::ActivationConfig activation_config =
activations::ActivationConfig::from_json(config["activation"]);
const int groups = config.value("groups", 1); // defaults to 1
c.activation = activations::ActivationConfig::from_json(config["activation"]);
c.groups = config.value("groups", 1); // defaults to 1
// Default to 1 channel in/out for backward compatibility
const int in_channels = config.value("in_channels", 1);
const int out_channels = config.value("out_channels", 1);
return std::make_unique<nam::convnet::ConvNet>(
in_channels, out_channels, channels, dilations, batchnorm, activation_config, weights, expectedSampleRate, groups);
c.in_channels = config.value("in_channels", 1);
c.out_channels = config.value("out_channels", 1);
return c;
}

// ConvNetConfig::create()
// Builds a ConvNet from the fields stored in this configuration, passing the
// supplied weights and sample rate through to the ConvNet constructor.
std::unique_ptr<nam::DSP> nam::convnet::ConvNetConfig::create(std::vector<float> weights, double sampleRate)
{
  auto model = std::make_unique<nam::convnet::ConvNet>(
    in_channels, out_channels, channels, dilations, batchnorm, activation, weights, sampleRate, groups);
  return model;
}

// Config parser for ConfigParserRegistry.
// Parses `config` with parse_config_json() and returns the result as a
// heap-allocated ConvNetConfig behind a ModelConfig pointer.
// `sampleRate` is accepted to satisfy the parser signature but is not used.
std::unique_ptr<nam::ModelConfig> nam::convnet::create_config(const nlohmann::json& config, double sampleRate)
{
  (void)sampleRate;
  // Construct the heap object directly from the parsed value rather than
  // default-constructing it and copy-assigning afterwards (avoids an extra
  // copy of the dilations vector).
  return std::make_unique<ConvNetConfig>(parse_config_json(config));
}

namespace
{
static nam::factory::Helper _register_ConvNet("ConvNet", nam::convnet::Factory);
static nam::ConfigParserHelper _register_ConvNet("ConvNet", nam::convnet::create_config);
}
26 changes: 20 additions & 6 deletions NAM/convnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,27 @@ class ConvNet : public Buffer
int PrewarmSamples() override { return mPrewarmSamples; };
};

/// \brief Factory function to instantiate ConvNet from JSON
/// \brief Configuration for a ConvNet model
///
/// Holds the values read from the model's JSON config by parse_config_json().
/// Scalar members carry default initializers so that a default-constructed
/// instance never exposes indeterminate values; groups, in_channels and
/// out_channels use the same defaults (1) the JSON parser falls back to.
struct ConvNetConfig : public ModelConfig
{
  int channels = 0; ///< Value of the "channels" key
  std::vector<int> dilations; ///< Value of the "dilations" key
  bool batchnorm = false; ///< Value of the "batchnorm" key
  activations::ActivationConfig activation; ///< Parsed from the "activation" key
  int groups = 1; ///< Value of the "groups" key (parser default: 1)
  int in_channels = 1; ///< Value of the "in_channels" key (parser default: 1)
  int out_channels = 1; ///< Value of the "out_channels" key (parser default: 1)

  /// \brief Instantiate a ConvNet from this configuration
  /// \param weights Model weights vector
  /// \param sampleRate Sample rate passed through to the ConvNet constructor
  /// \return Unique pointer to a DSP object (ConvNet instance)
  std::unique_ptr<DSP> create(std::vector<float> weights, double sampleRate) override;
};

/// \brief Parse ConvNet configuration from JSON
/// \param config JSON configuration object
/// \param weights Model weights vector
/// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown)
/// \return Unique pointer to a DSP object (ConvNet instance)
std::unique_ptr<DSP> Factory(const nlohmann::json& config, std::vector<float>& weights,
const double expectedSampleRate);
/// \return ConvNetConfig
ConvNetConfig parse_config_json(const nlohmann::json& config);

/// \brief Config parser for ConfigParserRegistry
/// \param config JSON configuration object
/// \param sampleRate Expected sample rate in Hz (not used by the current implementation)
/// \return unique_ptr<ModelConfig> wrapping a ConvNetConfig
std::unique_ptr<ModelConfig> create_config(const nlohmann::json& config, double sampleRate);

}; // namespace convnet
}; // namespace nam
38 changes: 30 additions & 8 deletions NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,16 +300,38 @@ void nam::Linear::process(NAM_SAMPLE** input, NAM_SAMPLE** output, const int num
nam::Buffer::_advance_input_buffer_(num_frames);
}

// Factory
std::unique_ptr<nam::DSP> nam::linear::Factory(const nlohmann::json& config, std::vector<float>& weights,
const double expectedSampleRate)
// Config parser
nam::linear::LinearConfig nam::linear::parse_config_json(const nlohmann::json& config)
{
const int receptive_field = config["receptive_field"];
const bool bias = config["bias"];
LinearConfig c;
c.receptive_field = config["receptive_field"];
c.bias = config["bias"];
// Default to 1 channel in/out for backward compatibility
const int in_channels = config.value("in_channels", 1);
const int out_channels = config.value("out_channels", 1);
return std::make_unique<nam::Linear>(in_channels, out_channels, receptive_field, bias, weights, expectedSampleRate);
c.in_channels = config.value("in_channels", 1);
c.out_channels = config.value("out_channels", 1);
return c;
}

// LinearConfig::create()
// Builds a Linear model from the fields stored in this configuration, passing
// the supplied weights and sample rate through to the Linear constructor.
std::unique_ptr<nam::DSP> nam::linear::LinearConfig::create(std::vector<float> weights, double sampleRate)
{
  auto model = std::make_unique<nam::Linear>(in_channels, out_channels, receptive_field, bias, weights, sampleRate);
  return model;
}

// Config parser for ConfigParserRegistry.
// Parses `config` with parse_config_json() and returns the result as a
// heap-allocated LinearConfig behind a ModelConfig pointer.
// `sampleRate` is accepted to satisfy the parser signature but is not used.
std::unique_ptr<nam::ModelConfig> nam::linear::create_config(const nlohmann::json& config, double sampleRate)
{
  (void)sampleRate;
  // Construct the heap object directly from the parsed value rather than
  // default-constructing it and copy-assigning afterwards.
  return std::make_unique<LinearConfig>(parse_config_json(config));
}

// Register the config parser
// File-local object constructed during static initialization with the
// architecture name "Linear" and the nam::linear::create_config parser.
// NOTE(review): presumably ConfigParserHelper's constructor adds this entry to
// ConfigParserRegistry — confirm against its definition.
namespace
{
static nam::ConfigParserHelper _register_Linear("Linear", nam::linear::create_config);
}

// NN modules =================================================================
Expand Down
28 changes: 22 additions & 6 deletions NAM/dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "activations.h"
#include "json.hpp"
#include "model_config.h"

#ifdef NAM_SAMPLE_FLOAT
#define NAM_SAMPLE float
Expand Down Expand Up @@ -258,13 +259,28 @@ class Linear : public Buffer

namespace linear
{
/// \brief Factory function to instantiate Linear model from JSON

/// \brief Configuration for a Linear model
///
/// Holds the values read from the model's JSON config by parse_config_json().
/// Members carry default initializers so that a default-constructed instance
/// never exposes indeterminate values; in_channels and out_channels use the
/// same default (1) the JSON parser falls back to.
struct LinearConfig : public ModelConfig
{
  int receptive_field = 0; ///< Value of the "receptive_field" key
  bool bias = false; ///< Value of the "bias" key
  int in_channels = 1; ///< Value of the "in_channels" key (parser default: 1)
  int out_channels = 1; ///< Value of the "out_channels" key (parser default: 1)

  /// \brief Instantiate a Linear model from this configuration
  /// \param weights Model weights vector
  /// \param sampleRate Sample rate passed through to the Linear constructor
  /// \return Unique pointer to a DSP object (Linear instance)
  std::unique_ptr<DSP> create(std::vector<float> weights, double sampleRate) override;
};

/// \brief Parse Linear configuration from JSON
/// \param config JSON configuration object
/// \return LinearConfig
LinearConfig parse_config_json(const nlohmann::json& config);

/// \brief Config parser for ConfigParserRegistry
/// \param config JSON configuration object
/// \param weights Model weights vector
/// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown)
/// \return Unique pointer to a DSP object (Linear instance)
std::unique_ptr<DSP> Factory(const nlohmann::json& config, std::vector<float>& weights,
const double expectedSampleRate);
/// \param sampleRate Expected sample rate in Hz
/// \return unique_ptr<ModelConfig> wrapping a LinearConfig
std::unique_ptr<ModelConfig> create_config(const nlohmann::json& config, double sampleRate);
} // namespace linear

// NN modules =================================================================
Expand Down
102 changes: 53 additions & 49 deletions NAM/get_dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <unordered_set>

#include "dsp.h"
#include "registry.h"
#include "json.hpp"
#include "lstm.h"
#include "convnet.h"
#include "wavenet.h"
#include "get_dsp.h"
#include "model_config.h"

namespace nam
{
Expand Down Expand Up @@ -146,62 +143,69 @@ std::unique_ptr<DSP> get_dsp(const nlohmann::json& config, dspData& returnedConf
return get_dsp(conf);
}

struct OptionalValue
// =============================================================================
// Unified construction path
// =============================================================================

// Forwards the architecture name, JSON config and sample rate to
// ConfigParserRegistry::instance().parse() and returns its result unchanged.
std::unique_ptr<ModelConfig> parse_model_config_json(const std::string& architecture, const nlohmann::json& config,
                                                     double sample_rate)
{
  auto&& registry = ConfigParserRegistry::instance();
  return registry.parse(architecture, config, sample_rate);
}

namespace
{
bool have = false;
double value = 0.0;
};

// Pushes each optional value that is present in `metadata` (loudness, input
// level, output level) onto `dsp` via the matching setter. Values that are
// absent result in no setter call.
void apply_metadata(DSP& dsp, const ModelMetadata& metadata)
{
  if (metadata.loudness)
  {
    dsp.SetLoudness(*metadata.loudness);
  }
  if (metadata.input_level)
  {
    dsp.SetInputLevel(*metadata.input_level);
  }
  if (metadata.output_level)
  {
    dsp.SetOutputLevel(*metadata.output_level);
  }
}

} // anonymous namespace

std::unique_ptr<DSP> create_dsp(std::unique_ptr<ModelConfig> config, std::vector<float> weights,
const ModelMetadata& metadata)
{
auto out = config->create(std::move(weights), metadata.sample_rate);
apply_metadata(*out, metadata);
// "pre-warm" the model to settle initial conditions
// Can this be removed now that it's part of Reset()?
out->prewarm();
return out;
}

// =============================================================================
// get_dsp(dspData&) — now uses unified path
// =============================================================================

std::unique_ptr<DSP> get_dsp(dspData& conf)
{
verify_config_version(conf.version);

auto& architecture = conf.architecture;
nlohmann::json& config = conf.config;
std::vector<float>& weights = conf.weights;
OptionalValue loudness, inputLevel, outputLevel;

auto AssignOptional = [&conf](const std::string key, OptionalValue& v) {
if (conf.metadata.find(key) != conf.metadata.end())
{
if (!conf.metadata[key].is_null())
{
v.value = conf.metadata[key];
v.have = true;
}
}
};
// Extract metadata from JSON
ModelMetadata metadata;
metadata.version = conf.version;
metadata.sample_rate = conf.expected_sample_rate;

if (!conf.metadata.is_null())
{
AssignOptional("loudness", loudness);
AssignOptional("input_level_dbu", inputLevel);
AssignOptional("output_level_dbu", outputLevel);
}
const double expectedSampleRate = conf.expected_sample_rate;

// Initialize using registry-based factory
std::unique_ptr<DSP> out =
nam::factory::FactoryRegistry::instance().create(architecture, config, weights, expectedSampleRate);

if (loudness.have)
{
out->SetLoudness(loudness.value);
}
if (inputLevel.have)
{
out->SetInputLevel(inputLevel.value);
}
if (outputLevel.have)
{
out->SetOutputLevel(outputLevel.value);
auto extract = [&conf](const std::string& key) -> std::optional<double> {
if (conf.metadata.find(key) != conf.metadata.end() && !conf.metadata[key].is_null())
return conf.metadata[key].get<double>();
return std::nullopt;
};
metadata.loudness = extract("loudness");
metadata.input_level = extract("input_level_dbu");
metadata.output_level = extract("output_level_dbu");
}

// "pre-warm" the model to settle initial conditions
// Can this be removed now that it's part of Reset()?
out->prewarm();

return out;
auto model_config = ConfigParserRegistry::instance().parse(conf.architecture, conf.config, conf.expected_sample_rate);
return create_dsp(std::move(model_config), std::move(conf.weights), metadata);
}

double get_sample_rate_from_nam_file(const nlohmann::json& j)
Expand Down
40 changes: 28 additions & 12 deletions NAM/lstm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,22 +163,38 @@ void nam::lstm::LSTM::_process_sample()
this->_output.noalias() += this->_head_bias;
}

// Factory to instantiate from nlohmann json
std::unique_ptr<nam::DSP> nam::lstm::Factory(const nlohmann::json& config, std::vector<float>& weights,
const double expectedSampleRate)
// Config parser
nam::lstm::LSTMConfig nam::lstm::parse_config_json(const nlohmann::json& config)
{
const int num_layers = config["num_layers"];
const int input_size = config["input_size"];
const int hidden_size = config["hidden_size"];
LSTMConfig c;
c.num_layers = config["num_layers"];
c.input_size = config["input_size"];
c.hidden_size = config["hidden_size"];
// Default to 1 channel in/out for backward compatibility
const int in_channels = config.value("in_channels", 1);
const int out_channels = config.value("out_channels", 1);
return std::make_unique<nam::lstm::LSTM>(
in_channels, out_channels, num_layers, input_size, hidden_size, weights, expectedSampleRate);
c.in_channels = config.value("in_channels", 1);
c.out_channels = config.value("out_channels", 1);
return c;
}

// Register the factory
// LSTMConfig::create()
// Builds an LSTM from the fields stored in this configuration, passing the
// supplied weights and sample rate through to the LSTM constructor.
std::unique_ptr<nam::DSP> nam::lstm::LSTMConfig::create(std::vector<float> weights, double sampleRate)
{
  auto model = std::make_unique<nam::lstm::LSTM>(
    in_channels, out_channels, num_layers, input_size, hidden_size, weights, sampleRate);
  return model;
}

// Config parser for ConfigParserRegistry.
// Parses `config` with parse_config_json() and returns the result as a
// heap-allocated LSTMConfig behind a ModelConfig pointer.
// `sampleRate` is accepted to satisfy the parser signature but is not used.
std::unique_ptr<nam::ModelConfig> nam::lstm::create_config(const nlohmann::json& config, double sampleRate)
{
  (void)sampleRate;
  // Construct the heap object directly from the parsed value rather than
  // default-constructing it and copy-assigning afterwards.
  return std::make_unique<LSTMConfig>(parse_config_json(config));
}

// Register the config parser
namespace
{
static nam::factory::Helper _register_LSTM("LSTM", nam::lstm::Factory);
static nam::ConfigParserHelper _register_LSTM("LSTM", nam::lstm::create_config);
}
24 changes: 18 additions & 6 deletions NAM/lstm.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,25 @@ class LSTM : public DSP
Eigen::VectorXf _output;
};

/// \brief Factory function to instantiate LSTM from JSON
/// \brief Configuration for an LSTM model
///
/// Holds the values read from the model's JSON config by parse_config_json().
/// Members carry default initializers so that a default-constructed instance
/// never exposes indeterminate values; in_channels and out_channels use the
/// same default (1) the JSON parser falls back to.
struct LSTMConfig : public ModelConfig
{
  int num_layers = 0; ///< Value of the "num_layers" key
  int input_size = 0; ///< Value of the "input_size" key
  int hidden_size = 0; ///< Value of the "hidden_size" key
  int in_channels = 1; ///< Value of the "in_channels" key (parser default: 1)
  int out_channels = 1; ///< Value of the "out_channels" key (parser default: 1)

  /// \brief Instantiate an LSTM from this configuration
  /// \param weights Model weights vector
  /// \param sampleRate Sample rate passed through to the LSTM constructor
  /// \return Unique pointer to a DSP object (LSTM instance)
  std::unique_ptr<DSP> create(std::vector<float> weights, double sampleRate) override;
};

/// \brief Parse LSTM configuration from JSON
/// \param config JSON configuration object
/// \param weights Model weights vector
/// \param expectedSampleRate Expected sample rate in Hz (-1.0 if unknown)
/// \return Unique pointer to a DSP object (LSTM instance)
std::unique_ptr<DSP> Factory(const nlohmann::json& config, std::vector<float>& weights,
const double expectedSampleRate);
/// \return LSTMConfig
LSTMConfig parse_config_json(const nlohmann::json& config);

/// \brief Config parser for ConfigParserRegistry
/// \param config JSON configuration object
/// \param sampleRate Expected sample rate in Hz (not used by the current implementation)
/// \return unique_ptr<ModelConfig> wrapping an LSTMConfig
std::unique_ptr<ModelConfig> create_config(const nlohmann::json& config, double sampleRate);

}; // namespace lstm
}; // namespace nam
Loading