23 changes: 23 additions & 0 deletions .github/workflows/build.yml
@@ -41,3 +41,26 @@ jobs:
./build/tools/run_tests
./build/tools/benchmodel ./example_models/wavenet.nam
./build/tools/benchmodel ./example_models/lstm.nam

build-ubuntu-inline-gemm:
name: Build Ubuntu (Inline GEMM)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3.3.0
with:
submodules: recursive

- name: Build Tools
working-directory: ${{github.workspace}}/build
env:
CXX: clang++
run: |
cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS="-DNAM_USE_INLINE_GEMM"
cmake --build . -j4

- name: Run tests
working-directory: ${{github.workspace}}
run: |
./build/tools/run_tests
./build/tools/benchmodel ./example_models/wavenet.nam
./build/tools/benchmodel ./example_models/lstm.nam
15 changes: 13 additions & 2 deletions NAM/activations.cpp
@@ -35,8 +35,10 @@ std::unordered_map<std::string, nam::activations::Activation::Ptr> nam::activati
{"PReLU", make_singleton_ptr(_PRELU)},
{"Softsign", make_singleton_ptr(_SOFTSIGN)}};

// Variables to hold previous instances of activations when we replace them with fast or LUT implementations
nam::activations::Activation::Ptr tanh_bak = nullptr;
nam::activations::Activation::Ptr sigmoid_bak = nullptr;
nam::activations::Activation::Ptr silu_bak = nullptr;

nam::activations::Activation::Ptr nam::activations::Activation::get_activation(const std::string name)
{
@@ -197,9 +199,14 @@ void nam::activations::Activation::enable_lut(std::string function_name, float m
fn = sigmoid;
sigmoid_bak = _activations["Sigmoid"];
}
else if (function_name == "SiLU")
{
fn = swish;
silu_bak = _activations["SiLU"];
}
else
{
throw std::runtime_error("Tried to enable LUT for a function other than Tanh or Sigmoid");
throw std::runtime_error("Tried to enable LUT for a function other than Tanh, Sigmoid, or SiLU");
}
_activations[function_name] = std::make_shared<FastLUTActivation>(min, max, n_points, fn);
}
@@ -214,8 +221,12 @@ void nam::activations::Activation::disable_lut(std::string function_name)
{
_activations["Sigmoid"] = sigmoid_bak;
}
else if (function_name == "SiLU")
{
_activations["SiLU"] = silu_bak;
}
else
{
throw std::runtime_error("Tried to disable LUT for a function other than Tanh or Sigmoid");
throw std::runtime_error("Tried to disable LUT for a function other than Tanh, Sigmoid, or SiLU");
}
}
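For reference, a minimal usage sketch of the extended LUT switch, assuming enable_lut/disable_lut are static members as the qualified definitions above suggest; the range and point count below are illustrative assumptions, not values taken from this PR:

#include "NAM/activations.h"

void silu_lut_example()
{
  // Hypothetical values: tabulate SiLU over [-8, 8] with 4096 points.
  // Out-of-range inputs are clamped by the LUT's lookup().
  nam::activations::Activation::enable_lut("SiLU", -8.0f, 8.0f, 4096);

  // ... process audio with the model ...

  // Restore the exact SiLU implementation saved in silu_bak.
  nam::activations::Activation::disable_lut("SiLU");
}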
43 changes: 18 additions & 25 deletions NAM/activations.h
@@ -117,18 +117,12 @@ inline float swish(float x)

inline float hardswish(float x)
{
if (x <= -3.0)
{
return 0;
}
else if (x >= 3.0)
{
return x;
}
else
{
return x * (x + 3.0) / 6.0;
}
// Branchless implementation using clamp
// hardswish(x) = x * relu6(x + 3) / 6
// = x * clamp(x + 3, 0, 6) / 6
const float t = x + 3.0f;
const float clamped = t < 0.0f ? 0.0f : (t > 6.0f ? 6.0f : t);
Owner: Interesting if this really is better? I'd be surprised if a compiler wouldn't figure out that these are the same.

Contributor Author: I was surprised by this too, but it does make a difference. I can share the microbenchmark if you are interested in having a look.

Owner: nah I trust ya. sgtm

return x * clamped * (1.0f / 6.0f);
}
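The microbenchmark mentioned in the thread above is not included in the PR; purely as an illustration, a standalone comparison of the two hardswish variants could be sketched like this (buffer size, repetition count, and timing harness are assumptions, not the contributor's actual benchmark):

// Hypothetical microbenchmark sketch: branchy vs. branchless hardswish.
#include <chrono>
#include <cstdio>
#include <vector>

static float hardswish_branchy(float x)
{
  if (x <= -3.0f) return 0.0f;
  if (x >= 3.0f) return x;
  return x * (x + 3.0f) / 6.0f;
}

static float hardswish_branchless(float x)
{
  const float t = x + 3.0f;
  const float clamped = t < 0.0f ? 0.0f : (t > 6.0f ? 6.0f : t);
  return x * clamped * (1.0f / 6.0f);
}

int main()
{
  // Inputs span [-6, 6] so every branch of the branchy version is exercised.
  std::vector<float> data(1 << 20);
  for (size_t i = 0; i < data.size(); i++)
    data[i] = -6.0f + 12.0f * float(i) / float(data.size());

  for (auto fn : {hardswish_branchy, hardswish_branchless})
  {
    float sum = 0.0f;
    auto t0 = std::chrono::steady_clock::now();
    for (int rep = 0; rep < 1000; rep++)
      for (float x : data)
        sum += fn(x);
    auto t1 = std::chrono::steady_clock::now();
    // The checksum keeps the compiler from optimizing the loops away.
    std::printf("%.3f ms (checksum %f)\n",
                std::chrono::duration<double, std::milli>(t1 - t0).count(), sum);
  }
  return 0;
}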

inline float softsign(float x)
@@ -145,9 +139,18 @@ class Activation
Activation() = default;
virtual ~Activation() = default;
virtual void apply(Eigen::MatrixXf& matrix) { apply(matrix.data(), matrix.rows() * matrix.cols()); }
virtual void apply(Eigen::Block<Eigen::MatrixXf> block) { apply(block.data(), block.rows() * block.cols()); }
virtual void apply(Eigen::Block<Eigen::MatrixXf> block)
{
// Block must be contiguous in memory (outerStride == rows) for flat data() access.
// Non-contiguous blocks (e.g. topRows() of a wider matrix) would read/write wrong elements.
assert(block.outerStride() == block.rows());
apply(block.data(), block.rows() * block.cols());
}
virtual void apply(Eigen::Block<Eigen::MatrixXf, -1, -1, true> block)
{
// Inner-panel blocks (e.g. leftCols()) are always contiguous for column-major matrices,
// but assert anyway for safety.
assert(block.outerStride() == block.rows());
apply(block.data(), block.rows() * block.cols());
}
virtual void apply(float* data, long size) {}
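To illustrate the contiguity requirement documented in the asserts above, a small sketch (assuming Eigen's default column-major storage; the activation instance and include path are placeholders):

#include <Eigen/Dense>
#include "NAM/activations.h"

void contiguity_example(nam::activations::Activation& act)
{
  Eigen::MatrixXf m(8, 16);
  m.setRandom();

  // OK: leftCols() keeps all 8 rows, so outerStride() == rows() and data() is contiguous.
  act.apply(m.leftCols(4));

  // Would trip the assert: topRows() keeps outerStride() == 8 while rows() == 4,
  // so a flat walk over data() would touch elements outside the block.
  // act.apply(m.topRows(4));
}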
@@ -242,9 +245,7 @@ class ActivationReLU : public Activation
void apply(float* data, long size) override
{
for (long pos = 0; pos < size; pos++)
{
data[pos] = relu(data[pos]);
}
}
};

@@ -309,9 +310,7 @@ class ActivationSigmoid : public Activation
void apply(float* data, long size) override
{
for (long pos = 0; pos < size; pos++)
{
data[pos] = sigmoid(data[pos]);
}
}
};

@@ -321,9 +320,7 @@ class ActivationSwish : public Activation
void apply(float* data, long size) override
{
for (long pos = 0; pos < size; pos++)
{
data[pos] = swish(data[pos]);
}
}
};

@@ -333,9 +330,7 @@ class ActivationHardSwish : public Activation
void apply(float* data, long size) override
{
for (long pos = 0; pos < size; pos++)
{
data[pos] = hardswish(data[pos]);
}
}
};

@@ -345,9 +340,7 @@ class ActivationSoftsign : public Activation
void apply(float* data, long size) override
{
for (long pos = 0; pos < size; pos++)
{
data[pos] = softsign(data[pos]);
}
}
};

@@ -373,8 +366,8 @@ class FastLUTActivation : public Activation
// Fast lookup with linear interpolation
inline float lookup(float x) const
{
// Clamp input to range
x = std::clamp(x, min_x_, max_x_);
// Clamp input to range (inline to avoid header dependency)
x = x < min_x_ ? min_x_ : (x > max_x_ ? max_x_ : x);

// Calculate float index
float f_idx = (x - min_x_) * inv_step_;
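The remainder of lookup() is collapsed in this view. Purely for orientation, a generic linear-interpolation table lookup consistent with the members shown above could be sketched as follows; this is a hypothetical illustration, not the file's actual code, and the table and parameter names are assumptions:

#include <algorithm>
#include <vector>

// Hypothetical sketch of LUT lookup with linear interpolation.
// Assumes `table` holds uniformly spaced samples of the function over [min_x, max_x]
// and inv_step == (table.size() - 1) / (max_x - min_x).
inline float lut_lookup(const std::vector<float>& table, float min_x, float max_x,
                        float inv_step, float x)
{
  // Clamp input to the tabulated range.
  x = x < min_x ? min_x : (x > max_x ? max_x : x);
  const float f_idx = (x - min_x) * inv_step;
  // Keep idx + 1 in range even when x == max_x.
  const int idx = std::min(static_cast<int>(f_idx), static_cast<int>(table.size()) - 2);
  const float frac = f_idx - static_cast<float>(idx);
  // Linear blend between the two nearest samples.
  return table[idx] + frac * (table[idx + 1] - table[idx]);
}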