4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -24,12 +24,12 @@ jobs:
uses: actions/checkout@v3

- name: Set Up the Python
uses: actions/setup-python@v2
uses: actions/setup-python@v3
with:
python-version: 3.9

- name: Download distribution files
uses: actions/download-artifact@v2
uses: actions/download-artifact@v3
with:
name: dist
path: dist
4 changes: 2 additions & 2 deletions .github/workflows/build_whl.yml
@@ -31,7 +31,7 @@ jobs:
uses: actions/checkout@v3

- name: Login to DockerHub
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -45,7 +45,7 @@ jobs:
docker run -e BUILD_DOCKER_ENV=1 -e CUDACXX=/usr/local/cuda-11.3/bin/nvcc -e PATH="/opt/rh/devtoolset-9/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-9/root/usr/lib64:/opt/rh/devtoolset-9/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i pytorch/manylinux-cuda113:latest /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/pip install build; /opt/python/cp${version}*/bin/python -m build .;for file in dist/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done"

- name: Archive distribution files
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: dist
path: |
4 changes: 2 additions & 2 deletions .github/workflows/publish.yaml
@@ -19,15 +19,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Set Up the Python
uses: actions/setup-python@v2
uses: actions/setup-python@v3
with:
python-version: 3.9

- name: Install twine
run: python -m pip install twine

- name: Download distribution files
uses: actions/download-artifact@v2
uses: actions/download-artifact@v3
with:
name: dist
path: dist
4 changes: 2 additions & 2 deletions .github/workflows/release.yml
@@ -22,12 +22,12 @@ jobs:
uses: actions/checkout@v3

- name: Set Up the Python
uses: actions/setup-python@v2
uses: actions/setup-python@v3
with:
python-version: 3.9

- name: Download distribution files
uses: actions/download-artifact@v2
uses: actions/download-artifact@v3
with:
name: dist
path: dist
50 changes: 50 additions & 0 deletions bmtrain/optim/_function.py
@@ -216,3 +216,53 @@ def adam_bf16(
bias_correction2,
stream,
)

def adam_fp32(
param_fp32: torch.Tensor,
g_fp32: torch.Tensor,
m_fp32: torch.Tensor,
v_fp32: torch.Tensor,
beta1: float,
beta2: float,
eps: float,
lr: float,
scale: float,
weight_decay: float,
step: int,
) -> None:
assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
assert CHECK_INPUT(g_fp32), "g_fp32 must be contiguous and on cuda"
assert CHECK_INPUT(m_fp32), "m_fp32 must be contiguous and on cuda"
assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda"
assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
assert g_fp32.dtype == torch.float32, "g_fp32 must be float32 tensor"
assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor"
assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
assert (
param_fp32.numel() == g_fp32.numel()
), "param_fp32 and g_fp16 must have the same number of elements"
assert (
param_fp32.numel() == m_fp32.numel()
), "param_fp32 and m_fp32 must have the same number of elements"
assert (
param_fp32.numel() == v_fp32.numel()
), "param_fp32 and v_fp32 must have the same number of elements"
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
stream = torch.cuda.current_stream().cuda_stream
C.adam_fp32_launcher(
param_fp32.numel(),
param_fp32.data_ptr(),
g_fp32.data_ptr(),
m_fp32.data_ptr(),
v_fp32.data_ptr(),
beta1,
beta2,
eps,
lr,
scale,
weight_decay,
bias_correction1,
bias_correction2,
stream,
)
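
For reviewers, a minimal usage sketch of the new adam_fp32 wrapper (not part of the PR): it assumes the compiled C extension and a CUDA device are available, and the tensor names and hyperparameter values are illustrative placeholders.

import torch
from bmtrain.optim import _function as F

n = 1024
param = torch.randn(n, dtype=torch.float32, device="cuda")
grad = torch.randn(n, dtype=torch.float32, device="cuda")
m = torch.zeros(n, dtype=torch.float32, device="cuda")
v = torch.zeros(n, dtype=torch.float32, device="cuda")

# One in-place Adam step on param, m and v; scale divides the gradient,
# step is 1-based and feeds the bias-correction terms.
F.adam_fp32(
    param, grad, m, v,
    0.9,    # beta1
    0.999,  # beta2
    1e-8,   # eps
    1e-3,   # lr
    1.0,    # scale
    0.0,    # weight_decay
    1,      # step
)
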
42 changes: 14 additions & 28 deletions bmtrain/optim/adam.py
@@ -10,7 +10,6 @@
from itertools import chain
from collections import defaultdict


class AdamOptimizer(torch.optim.Optimizer):
"""
Adam optimizer support fp16 and bf16.
@@ -112,34 +111,21 @@ def step(self, closure=None, scale=1):
grad = p.grad

if p.dtype == torch.float32:
other_kwargs = {}
if (
"maximize"
in inspect.signature(
torch.optim._functional.adam
).parameters
):
other_kwargs["maximize"] = False
torch.optim._functional.adam(
[p],
[grad / scale],
[state["exp_avg"]],
[state["exp_avg_sq"]],
[],
(
[state["step"]]
if check_torch_version("1.12.0") < 0
else [torch.tensor(state["step"])]
),
amsgrad=False,
beta1=group["betas"][0],
beta2=group["betas"][1],
lr=0.0 if state["step"] < self._hold_steps else group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
**other_kwargs
f = F.adam_fp32
state["step"] += 1
f(
p, # fp32
grad, # fp32
state["exp_avg"], # fp32: m
state["exp_avg_sq"], # fp32: v
group["betas"][0],
group["betas"][1],
group["eps"],
0.0 if state["step"] < self._hold_steps else group["lr"],
scale,
group["weight_decay"],
state["step"],
)
state["step"] += 1
else:
f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16
state["step"] += 1
5 changes: 3 additions & 2 deletions csrc/bind.cpp
@@ -6,8 +6,9 @@ PYBIND11_MODULE(C, m) {
m.def("is_bf16_supported", &is_bf16_supported, "whether bf16 supported");
m.def("has_nan_inf_fp16_launcher", &has_nan_inf_fp16_launcher, "has nan inf");
m.def("has_nan_inf_bf16_launcher", &has_nan_inf_bf16_launcher, "has nan inf bf16");
m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function cpu");
m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu");
m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function");
m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function");
m.def("adam_fp32_launcher", &adam_fp32_launcher, "adam function");
m.def("adam_cpu_fp16_launcher", &adam_cpu_fp16_launcher, "adam function cpu");
m.def("adam_cpu_bf16_launcher", &adam_cpu_bf16_launcher, "adam function cpu");
m.def("cross_entropy_forward_fp16_launcher", &cross_entropy_forward_fp16_launcher, "cross entropy forward");
55 changes: 55 additions & 0 deletions csrc/cuda/adam_cuda.cu
@@ -2,6 +2,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include "bfloat16.cuh"
#include <stdio.h>

namespace {
// blocks <n // 1024>, threads<min(n, 1024)>
@@ -71,6 +72,35 @@ __global__ void adam_fp32_accum_bf16(
#endif
}

__global__ void adam_fp32_accum_fp32(
int32_t n,
const float *g, // (n)
float *m, // (n)
float *v, // (n)
float *param, // (n)
float beta1,
float beta2,
float eps,
float lr,
float scale,
float weight_decay,
float bias_correction1,
float bias_correction2
) {
int32_t col = blockIdx.x * blockDim.x + threadIdx.x;
if (col < n) {
float local_g = g[col] / scale;
float local_m = beta1 * m[col] + (1 - beta1) * local_g;
float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g;
float local_p = param[col];
local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v / bias_correction2) + eps) - lr * weight_decay * local_p;

param[col] = local_p;
v[col] = local_v;
m[col] = local_m;
}
}

}

void adam_fp16_launcher(
@@ -124,3 +154,28 @@ void adam_bf16_launcher(
dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
adam_fp32_accum_bf16<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_bf16, m_ptr, v_fp32_ptr, param_fp32_ptr, param_bf16, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
}

void adam_fp32_launcher(
int n,
std::uintptr_t param_fp32,
std::uintptr_t g_fp32,
std::uintptr_t m_fp32,
std::uintptr_t v_fp32,
float beta1, float beta2,
float eps, float lr,
float scale,
float weight_decay,
float bias_correction1,
float bias_correction2,
uintptr_t stream
) {
if (n <= 0) return;
auto g_ptr = reinterpret_cast<float*>(g_fp32);
auto m_ptr = reinterpret_cast<float*>(m_fp32);
auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
int32_t threads = 1024;
dim3 block_size = dim3(threads, 1, 1);
dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
adam_fp32_accum_fp32<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
}
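
For reference, the per-element update implemented by adam_fp32_accum_fp32 can be written as a short PyTorch sketch (reviewer-side reference only, not part of the PR; bias_correction1 and bias_correction2 are the 1 - beta**step factors computed by the Python wrapper):

import torch

def adam_fp32_reference(param, g, m, v, beta1, beta2, eps, lr, scale,
                        weight_decay, bias_correction1, bias_correction2):
    # Mirrors the kernel: the gradient is unscaled first, weight decay is
    # applied directly to the parameter (decoupled form), and the moment
    # buffers and parameter are updated in place.
    g = g / scale
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    param.sub_(lr * weight_decay * param)
    param.sub_(lr * (m / bias_correction1) / ((v / bias_correction2).sqrt() + eps))
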
14 changes: 14 additions & 0 deletions csrc/include/bind.hpp
@@ -109,3 +109,17 @@ void adam_bf16_launcher(
float bias_correction2,
uintptr_t stream
);
void adam_fp32_launcher(
int n,
std::uintptr_t param_fp32,
std::uintptr_t g_fp32,
std::uintptr_t m_fp32,
std::uintptr_t v_fp32,
float beta1, float beta2,
float eps, float lr,
float scale,
float weight_decay,
float bias_correction1,
float bias_correction2,
uintptr_t stream
);
9 changes: 7 additions & 2 deletions tests/test_optim.py
@@ -57,9 +57,14 @@ def main(dtype):
optim_manager.add_optimizer(opt4)
optim_manager.add_optimizer(opt5)

# fp16 bmt.adam
# fp16 bmt.Offload
# fp32 torch.adam
# fp32 bmt.adam
# fp32 bmt.Offload

for _ in range(100):
optim_manager.zero_grad()

for p1, p2, p3, p4, p5 in zip(model1.parameters(), model2.parameters(), model3.parameters(), model4.parameters(), model5.parameters()):
grad = torch.randn_like(p1)
p1.grad = grad.to(dtype)
@@ -81,7 +86,7 @@ def main(dtype):
assert_lt(diff1, 1)
assert_lt(diff2, 1)
assert_lt(diff3, 1)
assert_eq(diff4, 0)
assert_lt(diff4, 0.001)
assert_lt(diff5, 0.00001)

if __name__ == "__main__":
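
Side note on the relaxed bound for diff4: before this PR the fp32 path called torch.optim._functional.adam and could match torch's result exactly, while the new CUDA kernel should agree only up to floating-point differences, so assert_eq(diff4, 0) becoming assert_lt(diff4, 0.001) looks consistent with the change. A standalone sanity check along these lines, sketched with weight_decay=0 because the kernel applies decoupled weight decay while torch.optim.Adam folds it into the gradient (not part of the PR; names and sizes are illustrative):

import torch
from bmtrain.optim import _function as F

torch.manual_seed(0)
p_ref = torch.randn(1024, device="cuda", requires_grad=True)
p_bmt = p_ref.detach().clone()
m = torch.zeros_like(p_bmt)
v = torch.zeros_like(p_bmt)
opt = torch.optim.Adam([p_ref], lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0)

for step in range(1, 101):
    g = torch.randn_like(p_bmt)
    p_ref.grad = g.clone()
    opt.step()
    F.adam_fp32(p_bmt, g, m, v, 0.9, 0.999, 1e-8, 1e-3, 1.0, 0.0, step)

# The two trajectories should track each other well below the 1e-3 test bound.
print((p_ref.detach() - p_bmt).abs().max().item())
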