diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 65fe064..5b3bc7d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,12 +16,20 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + submodules: recursive - name: Install Rust uses: dtolnay/rust-toolchain@stable with: components: rustfmt, clippy + - name: Set up CUDA + uses: Jimver/cuda-toolkit@v0.2.23 + with: + cuda: '12.8.0' + log-file-suffix: '-${{ github.job }}' + - name: Cache cargo uses: actions/cache@v4 with: @@ -31,7 +39,7 @@ jobs: ~/.cargo/registry/cache/ ~/.cargo/git/db/ target/ - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-cargo-check-${{ hashFiles('**/Cargo.lock') }} - name: Check formatting run: cargo fmt --all -- --check @@ -44,10 +52,18 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + submodules: recursive - name: Install Rust uses: dtolnay/rust-toolchain@stable + - name: Set up CUDA + uses: Jimver/cuda-toolkit@v0.2.23 + with: + cuda: '12.8.0' + log-file-suffix: '-${{ github.job }}' + - name: Cache cargo uses: actions/cache@v4 with: @@ -57,7 +73,7 @@ jobs: ~/.cargo/registry/cache/ ~/.cargo/git/db/ target/ - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.lock') }} - name: Run tests run: cargo test --all-features @@ -68,10 +84,18 @@ jobs: needs: [check, test] steps: - uses: actions/checkout@v4 + with: + submodules: recursive - name: Install Rust uses: dtolnay/rust-toolchain@stable + - name: Set up CUDA + uses: Jimver/cuda-toolkit@v0.2.23 + with: + cuda: '12.8.0' + log-file-suffix: '-${{ github.job }}' + - name: Cache cargo uses: actions/cache@v4 with: @@ -81,7 +105,7 @@ jobs: ~/.cargo/registry/cache/ ~/.cargo/git/db/ target/ - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-cargo-build-${{ hashFiles('**/Cargo.lock') }} - name: Build release run: cargo build --release @@ -89,5 +113,5 @@ jobs: - name: Upload binary uses: actions/upload-artifact@v4 with: - name: cubert + name: cubert-binary path: target/release/cubert diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..d2eaa78 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "dependencies/gem-blockset"] + path = dependencies/gem-blockset + url = https://github.com/Rosnaky/gem-blockset.git diff --git a/Cargo.lock b/Cargo.lock index 3b9cf62..4bd748f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + [[package]] name = "allocator-api2" version = "0.2.21" @@ -89,6 +98,26 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "bindgen" +version = "0.72.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", +] + [[package]] name = "bitflags" version = "2.11.0" @@ -143,6 +172,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -168,6 +206,17 @@ dependencies = [ "windows-link", ] +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "cmake" version = "0.1.57" @@ -282,7 +331,9 @@ name = "cubert" version = "0.1.0" dependencies = [ "async-trait", + "bindgen", "chrono", + "cmake", "dotenv", "reqwest", "serde", @@ -575,6 +626,12 @@ dependencies = [ "wasip3", ] +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + [[package]] name = "h2" version = "0.4.13" @@ -921,6 +978,15 @@ dependencies = [ "serde", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.17" @@ -990,6 +1056,16 @@ version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link", +] + [[package]] name = "libm" version = "0.2.16" @@ -1074,6 +1150,12 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "mio" version = "1.1.1" @@ -1085,6 +1167,16 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "num-bigint-dig" version = "0.8.6" @@ -1423,6 +1515,35 @@ dependencies = [ "bitflags", ] +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + [[package]] name = "reqwest" version = "0.13.2" diff --git a/Cargo.toml b/Cargo.toml index de8bfe6..dc04877 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,7 @@ uuid = { version = "1.21.0", features = ["v4"] } [dev-dependencies] tempfile = "3" tokio-test = "0.4" + +[build-dependencies] +bindgen = "0.72.1" +cmake = "0.1.57" diff --git a/build.rs b/build.rs index 28caf63..f0ebe10 100644 --- a/build.rs +++ b/build.rs @@ -1,11 +1,39 @@ +use std::env; +use std::path::PathBuf; + fn main() { - let libpath = format!("{}/model", env!("CARGO_MANIFEST_DIR")); + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + let gem_dir = manifest_dir.join("dependencies/gem-blockset"); + + let cuda_path = env::var("CUDA_PATH") + .or_else(|_| env::var("CUDA_HOME")) + .unwrap_or_else(|_| "/opt/cuda".to_string()); + + let dst = cmake::Config::new(&gem_dir) + .define("CMAKE_CUDA_ARCHITECTURES", "75") + .define("CMAKE_CUDA_COMPILER", format!("{}/bin/nvcc", cuda_path)) + .define("GEM_BUILD_FFI", "ON") + .build_target("gem_blockset") + .build(); + + println!("cargo:rustc-link-search=native={}/build", dst.display()); + println!("cargo:rustc-link-lib=static=gem_blockset"); + + println!("cargo:rustc-link-search=native={}/lib64", cuda_path); + println!("cargo:rustc-link-search=native={}/lib", cuda_path); + println!("cargo:rustc-link-lib=dylib=cudart"); - println!("cargo:rustc-link-search=native={}", libpath); + println!("cargo:rustc-link-lib=dylib=stdc++"); - println!("cargo:rustc-link-lib=cudastats"); + let bindings = bindgen::Builder::default() + .header(gem_dir.join("models/ffi.h").to_str().unwrap()) + .generate() + .expect("failed to generate bindings"); - println!("cargo:rustc-link-arg=-Wl,-rpath,{}", libpath); + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + bindings + .write_to_file(out_path.join("gem_bindings.rs")) + .expect("failed to write bindings"); - println!("cargo:rerun-if-changed=model/libcudastats.so"); + println!("cargo:rerun-if-changed=dependencies/gem-blockset/models"); } diff --git a/dependencies/gem-blockset b/dependencies/gem-blockset new file mode 160000 index 0000000..2678521 --- /dev/null +++ b/dependencies/gem-blockset @@ -0,0 +1 @@ +Subproject commit 2678521e627096db5231e6bdebd0113574f55a32 diff --git a/model/libcudastats.so b/model/libcudastats.so deleted file mode 100755 index 2dac267..0000000 Binary files a/model/libcudastats.so and /dev/null differ diff --git a/model/stats.cu b/model/stats.cu deleted file mode 100644 index b68d24d..0000000 --- a/model/stats.cu +++ /dev/null @@ -1,127 +0,0 @@ - -#include "stats.h" -#include -#include - -__device__ double warp_reduce_double(double val) { - for (int offset = 16; offset > 0; offset >>= 1) { - val += __shfl_down_sync(0xFFFFFFFF, val, offset); - } - return val; -} - -template -__global__ void stats_kernel( - const float* __restrict__ input, - double* __restrict__ sum_out, - double* __restrict__ sum_sq_out, - int n -) { - __shared__ double s_sum[BLOCK_SIZE/32]; - __shared__ double s_sum_sq[BLOCK_SIZE/32]; - - int tid = threadIdx.x; - int gid = blockIdx.x * BLOCK_SIZE * ELEMENTS_PER_THREAD + tid; - - double local_sum = 0, local_sum_sq = 0; - - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_THREAD; i++) { - int idx = gid + i * BLOCK_SIZE; - if (idx < n) { - double val = input[idx]; - local_sum += val; - local_sum_sq += val * val; - } - } - - local_sum = warp_reduce_double(local_sum); - local_sum_sq = warp_reduce_double(local_sum_sq); - - int lane = tid%32; - int warpId = tid/32; - - if (lane == 0) { - s_sum[warpId] = local_sum; - s_sum_sq[warpId] = local_sum_sq; - } - - __syncthreads(); - - if (tid < BLOCK_SIZE/32) { - local_sum = s_sum[tid]; - local_sum_sq = s_sum_sq[tid]; - } - else { - local_sum = 0; - local_sum_sq = 0; - } - - if (warpId == 0) { - local_sum = warp_reduce_double(local_sum); - local_sum_sq = warp_reduce_double(local_sum_sq); - if (lane == 0) { - atomicAdd(sum_out, local_sum); - atomicAdd(sum_sq_out, local_sum_sq); - } - } -} - -extern "C" int compute_stats(const float* data, int n, StatsResult* result) { - const int BLOCK_SIZE = 256; - const int ELEMENTS_PER_THREAD = 8; - const int num_blocks = (n + BLOCK_SIZE * ELEMENTS_PER_THREAD - 1) / (BLOCK_SIZE * ELEMENTS_PER_THREAD); // ceiling division - - float* d_data; - double* d_sum; - double* d_sum_sq; - - if (cudaMalloc(&d_data, n*sizeof(float)) != cudaSuccess) return -1; - if (cudaMalloc(&d_sum, sizeof(double)) != cudaSuccess) { - cudaFree(d_data); - return -1; - } - if (cudaMalloc(&d_sum_sq, sizeof(double)) != cudaSuccess) { - cudaFree(d_data); - cudaFree(d_sum); - return -1; - } - - cudaMemcpy(d_data, data, n*sizeof(float), cudaMemcpyHostToDevice); - cudaMemset(d_sum, 0, sizeof(double)); - cudaMemset(d_sum_sq, 0, sizeof(double)); - - stats_kernel<<>>( - d_data, d_sum, d_sum_sq, n - ); - - double h_sum, h_sum_sq; - - cudaMemcpy(&h_sum, d_sum, sizeof(double), cudaMemcpyDeviceToHost); - cudaMemcpy(&h_sum_sq, d_sum_sq, sizeof(double), cudaMemcpyDeviceToHost); - - result->mean = h_sum / n; - result->variance = h_sum_sq/n - (result->mean * result->mean); - result->stddev = sqrt(result->variance); - - cudaFree(d_data); - cudaFree(d_sum); - cudaFree(d_sum_sq); - - return 0; -} - -extern "C" int compute_sharpe(const float* returns, int n, float risk_free_rate, double* sharpe) { - StatsResult stats; - int ret = compute_stats(returns, n, &stats); - if (ret) return ret; - - if (stats.stddev < 1e-10) { - *sharpe = 0; - } - else { - *sharpe = (stats.mean - risk_free_rate) / stats.stddev; - } - - return 0; -} diff --git a/model/stats.h b/model/stats.h deleted file mode 100644 index b486397..0000000 --- a/model/stats.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -typedef struct { - double mean; - double variance; - double stddev; -} StatsResult; - -int compute_stats(const float* data, int n, StatsResult* result); -int compute_sharpe(const float* returns, int n, float risk_free_rate, double* sharpe); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/src/model/ffi.rs b/src/model/ffi.rs new file mode 100644 index 0000000..bf3e888 --- /dev/null +++ b/src/model/ffi.rs @@ -0,0 +1,83 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +include!(concat!(env!("OUT_DIR"), "/gem_bindings.rs")); + +pub fn zscore(prices: &[f64], window: i32) -> Vec { + let mut out = vec![0.0; prices.len()]; + unsafe { + rolling_zscore_f64( + prices.as_ptr(), + out.as_mut_ptr(), + prices.len() as i32, + window, + ); + } + out +} + +pub fn zscore_f32(prices: &[f32], window: i32) -> Vec { + let mut out = vec![0.0; prices.len()]; + unsafe { + rolling_zscore_f32( + prices.as_ptr(), + out.as_mut_ptr(), + prices.len() as i32, + window, + ); + } + out +} + +pub fn ou_estimate(prices: &[f64], window: i32) -> (Vec, Vec, Vec) { + let n = prices.len(); + let mut speed = vec![0.0; n]; + let mut equilibrium = vec![0.0; n]; + let mut volatility_sq = vec![0.0; n]; + unsafe { + ou_estimation_f64( + prices.as_ptr(), + n as i32, + speed.as_mut_ptr(), + equilibrium.as_mut_ptr(), + volatility_sq.as_mut_ptr(), + window, + ); + } + (speed, equilibrium, volatility_sq) +} + +pub fn ou_estimate_f32(prices: &[f32], window: i32) -> (Vec, Vec, Vec) { + let n = prices.len(); + let mut speed = vec![0.0; n]; + let mut equilibrium = vec![0.0; n]; + let mut volatility_sq = vec![0.0; n]; + unsafe { + ou_estimation_f32( + prices.as_ptr(), + n as i32, + speed.as_mut_ptr(), + equilibrium.as_mut_ptr(), + volatility_sq.as_mut_ptr(), + window, + ); + } + (speed, equilibrium, volatility_sq) +} + +pub fn adf(prices: &[f64], lags: i32) -> ADFResult_f64 { + let mut result = unsafe { std::mem::zeroed::() }; + unsafe { + adf_test_f64(prices.as_ptr(), prices.len() as i32, lags, &mut result); + } + result +} + +pub fn adf_f32(prices: &[f32], lags: i32) -> ADFResult_f32 { + let mut result = unsafe { std::mem::zeroed::() }; + unsafe { + adf_test_f32(prices.as_ptr(), prices.len() as i32, lags, &mut result); + } + result +} diff --git a/src/model/mod.rs b/src/model/mod.rs index 001a1a8..d8a989e 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -1,79 +1 @@ -use std::ffi::{c_float, c_int}; - -#[repr(C)] -pub struct StatsResult { - pub mean: f64, - pub variance: f64, - pub stddev: f64, -} - -#[link(name = "cudastats")] -unsafe extern "C" { - fn compute_stats(data: *const f32, n: c_int, result: *mut StatsResult) -> c_int; - fn compute_sharpe( - returns: *const f32, - n: c_int, - risk_free_rate: f32, - sharpe: *mut f64, - ) -> c_int; -} - -#[derive(Debug)] -pub enum ModelError { - ModelErrorFail, - ModelErrorInvalidInput, -} - -impl std::fmt::Display for ModelError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ModelError::ModelErrorFail => write!(f, "Model execution failed"), - ModelError::ModelErrorInvalidInput => write!(f, "Invalid input data"), - } - } -} - -impl std::error::Error for ModelError {} - -pub fn stats(data: &[f32]) -> Result { - if data.is_empty() { - return Err(ModelError::ModelErrorInvalidInput); - } - - let mut result = StatsResult { - mean: 0.0, - variance: 0.0, - stddev: 0.0, - }; - - let err = unsafe { compute_stats(data.as_ptr(), data.len() as c_int, &mut result) }; - - if err != 0 { - return Err(ModelError::ModelErrorFail); - } - - Ok(result) -} - -pub fn sharpe(returns: &[f32], risk_free_rate: f32) -> Result { - if returns.is_empty() { - return Err(ModelError::ModelErrorInvalidInput); - } - - let mut result = 0.0; - - let err = unsafe { - compute_sharpe( - returns.as_ptr(), - returns.len() as c_int, - risk_free_rate as c_float, - &mut result, - ) - }; - - if err != 0 { - return Err(ModelError::ModelErrorFail); - } - - Ok(result) -} +pub mod ffi;