From 3c5a213a8e334fdcb1babb25a1ba91fbace92dc1 Mon Sep 17 00:00:00 2001 From: mattcieslak Date: Mon, 9 Feb 2026 16:48:45 -0500 Subject: [PATCH 1/3] add tools for benchmarking --- CMakeLists.txt | 16 + bench/CMakeLists.txt | 3 + bench/bench_trx_stream.cpp | 1880 ++++++++++++++++++++++++++++++ bench/plot_bench.py | 214 ++++ docs/_static/benchmarks/.gitkeep | 1 + docs/benchmarks.rst | 118 ++ docs/index.rst | 1 + docs/usage.rst | 168 +++ include/trx/trx.h | 151 ++- include/trx/trx.tpp | 634 +++++++++- src/trx.cpp | 284 ++++- tests/test_trx_mmap.cpp | 2 +- tests/test_trx_trxfile.cpp | 39 + 13 files changed, 3437 insertions(+), 74 deletions(-) create mode 100644 bench/CMakeLists.txt create mode 100644 bench/bench_trx_stream.cpp create mode 100644 bench/plot_bench.py create mode 100644 docs/_static/benchmarks/.gitkeep create mode 100644 docs/benchmarks.rst diff --git a/CMakeLists.txt b/CMakeLists.txt index 2dc74c4..bf07eac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,6 +24,7 @@ endif() option(TRX_USE_CONAN "Should Conan package manager be used?" OFF) option(TRX_BUILD_TESTS "Build trx tests" OFF) option(TRX_BUILD_EXAMPLES "Build trx example commandline programs" ON) +option(TRX_BUILD_BENCHMARKS "Build trx benchmarks" OFF) option(TRX_ENABLE_CLANG_TIDY "Run clang-tidy during builds" OFF) option(TRX_ENABLE_INSTALL "Install trx-cpp targets" ${TRX_IS_TOP_LEVEL}) option(TRX_BUILD_DOCS "Build API documentation with Doxygen/Sphinx" OFF) @@ -148,6 +149,21 @@ if(TRX_BUILD_TESTS) endif() endif() +if(TRX_BUILD_BENCHMARKS) + find_package(benchmark CONFIG QUIET) + if(NOT benchmark_FOUND) + set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "Disable benchmark tests" FORCE) + set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "Disable benchmark gtest" FORCE) + FetchContent_Declare( + benchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG v1.8.3 + ) + FetchContent_MakeAvailable(benchmark) + endif() + add_subdirectory(bench) +endif() + if(TRX_ENABLE_NIFTI) find_package(ZLIB REQUIRED) add_library(trx-nifti diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt new file mode 100644 index 0000000..b493d34 --- /dev/null +++ b/bench/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(bench_trx_stream bench_trx_stream.cpp) +target_link_libraries(bench_trx_stream PRIVATE trx benchmark::benchmark) +target_compile_features(bench_trx_stream PRIVATE cxx_std_17) diff --git a/bench/bench_trx_stream.cpp b/bench/bench_trx_stream.cpp new file mode 100644 index 0000000..e664481 --- /dev/null +++ b/bench/bench_trx_stream.cpp @@ -0,0 +1,1880 @@ +// Benchmark TRX streaming workloads for realistic datasets. +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__unix__) || defined(__APPLE__) +#include +#include +#include +#endif + +#include + +namespace { +using Eigen::half; + +constexpr float kMinLengthMm = 20.0f; +constexpr float kMaxLengthMm = 500.0f; +constexpr float kStepMm = 2.0f; +constexpr float kCurvatureSigma = 0.08f; +constexpr float kSlabThicknessMm = 5.0f; +constexpr size_t kSlabCount = 20; + +constexpr std::array kStreamlineCounts = {100000, 500000, 1000000, 5000000, 10000000}; + +struct Fov { + float min_x; + float max_x; + float min_y; + float max_y; + float min_z; + float max_z; +}; + +constexpr Fov kFov{-70.0f, 70.0f, -108.0f, 79.0f, -60.0f, 75.0f}; +constexpr float kRandomMinMm = 10.0f; +constexpr float kRandomMaxMm = 400.0f; + +enum class GroupScenario : int { None = 0, Bundles = 1, Connectome = 2 }; +enum class LengthProfile : int { Mixed = 0, Short = 1, Medium = 2, Long = 3 }; + +constexpr size_t kBundleCount = 80; +constexpr size_t kConnectomeRegions = 100; + +std::string make_temp_path(const std::string &prefix) { + static std::atomic counter{0}; + const auto id = counter.fetch_add(1, std::memory_order_relaxed); + const auto dir = std::filesystem::temp_directory_path(); + return (dir / (prefix + "_" + std::to_string(id) + ".trx")).string(); +} + +std::string make_temp_dir_name(const std::string &prefix) { + static std::atomic counter{0}; + const auto id = counter.fetch_add(1, std::memory_order_relaxed); + const auto dir = std::filesystem::temp_directory_path(); + return (dir / (prefix + "_" + std::to_string(id))).string(); +} + +std::string make_work_dir_name(const std::string &prefix) { + static std::atomic counter{0}; + const auto id = counter.fetch_add(1, std::memory_order_relaxed); +#if defined(__unix__) || defined(__APPLE__) + const auto pid = static_cast(getpid()); +#else + const auto pid = static_cast(0); +#endif + const auto dir = std::filesystem::current_path(); + return (dir / (prefix + "_" + std::to_string(pid) + "_" + std::to_string(id))).string(); +} + +std::string make_status_path(const std::string &prefix) { + static std::atomic counter{0}; + const auto id = counter.fetch_add(1, std::memory_order_relaxed); + const auto dir = std::filesystem::temp_directory_path(); + return (dir / (prefix + "_" + std::to_string(id) + ".txt")).string(); +} + +std::string make_temp_dir_path(const std::string &prefix) { + return trx::make_temp_dir(prefix); +} + +void register_cleanup(const std::string &path); +std::vector list_files(const std::string &dir); + +std::string find_file_by_prefix(const std::string &dir, const std::string &prefix) { + std::error_code ec; + for (const auto &entry : trx::fs::directory_iterator(dir, ec)) { + if (ec) { + break; + } + if (!entry.is_regular_file()) { + continue; + } + const auto filename = entry.path().filename().string(); + if (filename.rfind(prefix, 0) == 0) { + return entry.path().string(); + } + } + return ""; +} + +std::vector list_files(const std::string &dir) { + std::vector files; + std::error_code ec; + if (!trx::fs::exists(dir, ec)) { + return files; + } + for (const auto &entry : trx::fs::directory_iterator(dir, ec)) { + if (ec) { + break; + } + if (!entry.is_regular_file()) { + continue; + } + files.push_back(entry.path().filename().string()); + } + std::sort(files.begin(), files.end()); + return files; +} + +size_t file_size_bytes(const std::string &path) { + std::error_code ec; + if (!trx::fs::exists(path, ec)) { + return 0; + } + if (trx::fs::is_directory(path, ec)) { + size_t total = 0; + for (trx::fs::recursive_directory_iterator it(path, ec), end; it != end; it.increment(ec)) { + if (ec) { + break; + } + if (!it->is_regular_file(ec)) { + continue; + } + total += static_cast(trx::fs::file_size(it->path(), ec)); + if (ec) { + break; + } + } + return total; + } + return static_cast(trx::fs::file_size(path, ec)); +} + +void wait_for_shard_ok(const std::vector &shard_paths, + const std::vector &status_paths, + size_t timeout_ms) { + const auto start = std::chrono::steady_clock::now(); + while (true) { + bool all_ok = true; + for (size_t i = 0; i < shard_paths.size(); ++i) { + const auto ok_path = trx::fs::path(shard_paths[i]) / "SHARD_OK"; + std::error_code ec; + if (!trx::fs::exists(ok_path, ec)) { + all_ok = false; + break; + } + } + if (all_ok) { + return; + } + const auto now = std::chrono::steady_clock::now(); + const auto elapsed_ms = + std::chrono::duration_cast(now - start).count(); + if (elapsed_ms > static_cast(timeout_ms)) { + std::string detail = "Timed out waiting for SHARD_OK"; + for (size_t i = 0; i < status_paths.size(); ++i) { + std::ifstream in(status_paths[i]); + std::string line; + if (in.is_open()) { + std::getline(in, line); + } + if (!line.empty()) { + detail += " shard_" + std::to_string(i) + "=" + line; + } + } + throw std::runtime_error(detail); + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } +} + +void copy_file_append(const std::string &src, const std::string &dst, std::size_t buffer_bytes = 8 * 1024 * 1024) { + std::ifstream in(src, std::ios::binary); + if (!in.is_open()) { + throw std::runtime_error("Failed to open file for read: " + src); + } + std::ofstream out(dst, std::ios::binary | std::ios::out | std::ios::app); + if (!out.is_open()) { + throw std::runtime_error("Failed to open file for append: " + dst); + } + std::vector buffer(buffer_bytes); + while (in) { + in.read(buffer.data(), static_cast(buffer.size())); + const std::streamsize count = in.gcount(); + if (count > 0) { + out.write(buffer.data(), count); + } + } +} + +std::pair read_header_counts(const std::string &dir) { + const auto header_path = trx::fs::path(dir) / "header.json"; + std::ifstream in; + for (int attempt = 0; attempt < 5; ++attempt) { + in.open(header_path); + if (in.is_open()) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + if (!in.is_open()) { + std::error_code ec; + const bool exists = trx::fs::exists(dir, ec); + const auto files = list_files(dir); + const int open_err = errno; + std::string detail = "Failed to open header.json at: " + header_path.string(); + detail += " exists=" + std::string(exists ? "true" : "false"); + detail += " errno=" + std::to_string(open_err) + " msg=" + std::string(std::strerror(open_err)); + if (!files.empty()) { + detail += " files=["; + for (size_t i = 0; i < files.size(); ++i) { + if (i > 0) { + detail += ","; + } + detail += files[i]; + } + detail += "]"; + } + throw std::runtime_error(detail); + } + std::string contents((std::istreambuf_iterator(in)), std::istreambuf_iterator()); + std::string err; + const auto header = json::parse(contents, err); + if (!err.empty()) { + throw std::runtime_error("Failed to parse header.json: " + err); + } + const auto nb_streamlines = static_cast(header["NB_STREAMLINES"].int_value()); + const auto nb_vertices = static_cast(header["NB_VERTICES"].int_value()); + return {nb_streamlines, nb_vertices}; +} + +json read_header_json(const std::string &dir) { + const auto header_path = trx::fs::path(dir) / "header.json"; + std::ifstream in; + for (int attempt = 0; attempt < 5; ++attempt) { + in.open(header_path); + if (in.is_open()) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + if (!in.is_open()) { + std::error_code ec; + const bool exists = trx::fs::exists(dir, ec); + const auto files = list_files(dir); + const int open_err = errno; + std::string detail = "Failed to open header.json at: " + header_path.string(); + detail += " exists=" + std::string(exists ? "true" : "false"); + detail += " errno=" + std::to_string(open_err) + " msg=" + std::string(std::strerror(open_err)); + if (!files.empty()) { + detail += " files=["; + for (size_t i = 0; i < files.size(); ++i) { + if (i > 0) { + detail += ","; + } + detail += files[i]; + } + detail += "]"; + } + throw std::runtime_error(detail); + } + std::string contents((std::istreambuf_iterator(in)), std::istreambuf_iterator()); + std::string err; + const auto header = json::parse(contents, err); + if (!err.empty()) { + throw std::runtime_error("Failed to parse header.json: " + err); + } + return header; +} + +double get_max_rss_kb() { +#if defined(__unix__) || defined(__APPLE__) + rusage usage{}; + if (getrusage(RUSAGE_SELF, &usage) != 0) { + return 0.0; + } +#if defined(__APPLE__) + return static_cast(usage.ru_maxrss) / 1024.0; +#else + return static_cast(usage.ru_maxrss); +#endif +#else + return 0.0; +#endif +} + +size_t parse_env_size(const char *name, size_t default_value) { + const char *raw = std::getenv(name); + if (!raw || raw[0] == '\0') { + return default_value; + } + char *end = nullptr; + const unsigned long long value = std::strtoull(raw, &end, 10); + if (end == raw) { + return default_value; + } + return static_cast(value); +} + +bool parse_env_bool(const char *name, bool default_value) { + const char *raw = std::getenv(name); + if (!raw || raw[0] == '\0') { + return default_value; + } + return std::string(raw) != "0"; +} + +int parse_env_int(const char *name, int default_value) { + const char *raw = std::getenv(name); + if (!raw || raw[0] == '\0') { + return default_value; + } + char *end = nullptr; + const long value = std::strtol(raw, &end, 10); + if (end == raw) { + return default_value; + } + return static_cast(value); +} + +size_t group_count_for(GroupScenario scenario) { + switch (scenario) { + case GroupScenario::Bundles: + return kBundleCount; + case GroupScenario::Connectome: + return (kConnectomeRegions * (kConnectomeRegions - 1)) / 2; + case GroupScenario::None: + default: + return 0; + } +} + +std::size_t buffer_bytes_for_streamlines(std::size_t streamlines) { + if (streamlines >= 5000000) { + return 2ULL * 1024ULL * 1024ULL * 1024ULL; + } + if (streamlines >= 1000000) { + return 256ULL * 1024ULL * 1024ULL; + } + return 16ULL * 1024ULL * 1024ULL; +} + +std::vector streamlines_for_benchmarks() { + const size_t only = parse_env_size("TRX_BENCH_ONLY_STREAMLINES", 0); + if (only > 0) { + return {only}; + } + const size_t max_val = parse_env_size("TRX_BENCH_MAX_STREAMLINES", 10000000); + std::vector counts = {10000000, 5000000, 1000000, 500000, 100000}; + counts.erase(std::remove_if(counts.begin(), counts.end(), [&](size_t v) { return v > max_val; }), counts.end()); + if (counts.empty()) { + counts.push_back(max_val); + } + return counts; +} + +void log_bench_start(const std::string &name, const std::string &details) { + if (!parse_env_bool("TRX_BENCH_LOG", false)) { + return; + } + std::cerr << "[trx-bench] start " << name << " " << details << std::endl; +} + +void log_bench_end(const std::string &name, const std::string &details) { + if (!parse_env_bool("TRX_BENCH_LOG", false)) { + return; + } + std::cerr << "[trx-bench] end " << name << " " << details << std::endl; +} + +void log_bench_config(const std::string &name, size_t threads, size_t batch_size) { + if (!parse_env_bool("TRX_BENCH_LOG", false)) { + return; + } + std::cerr << "[trx-bench] config " << name << " threads=" << threads << " batch=" << batch_size << std::endl; +} + +const std::vector &group_names_for(GroupScenario scenario) { + static const std::vector empty; + static const std::vector bundle_names = []() { + std::vector names; + names.reserve(kBundleCount); + for (size_t i = 1; i <= kBundleCount; ++i) { + names.push_back("Bundle" + std::to_string(i)); + } + return names; + }(); + static const std::vector connectome_names = []() { + std::vector names; + names.reserve((kConnectomeRegions * (kConnectomeRegions - 1)) / 2); + for (size_t i = 1; i <= kConnectomeRegions; ++i) { + for (size_t j = i + 1; j <= kConnectomeRegions; ++j) { + names.push_back("conn_" + std::to_string(i) + "_" + std::to_string(j)); + } + } + return names; + }(); + + switch (scenario) { + case GroupScenario::Bundles: + return bundle_names; + case GroupScenario::Connectome: + return connectome_names; + case GroupScenario::None: + default: + return empty; + } +} + +float sample_length_mm(std::mt19937 &rng, LengthProfile profile) { + auto sample_uniform = [&](float min_val, float max_val) { + std::uniform_real_distribution dist(min_val, max_val); + return dist(rng); + }; + switch (profile) { + case LengthProfile::Short: + return sample_uniform(20.0f, 120.0f); + case LengthProfile::Medium: + return sample_uniform(80.0f, 260.0f); + case LengthProfile::Long: + return sample_uniform(200.0f, 500.0f); + case LengthProfile::Mixed: + default: + return sample_uniform(kMinLengthMm, kMaxLengthMm); + } +} + +size_t estimate_points_per_streamline(LengthProfile profile) { + float mean_length = 0.0f; + switch (profile) { + case LengthProfile::Short: + mean_length = 70.0f; + break; + case LengthProfile::Medium: + mean_length = 170.0f; + break; + case LengthProfile::Long: + mean_length = 350.0f; + break; + case LengthProfile::Mixed: + default: + mean_length = 260.0f; + break; + } + return static_cast(std::ceil(mean_length / kStepMm)) + 1; +} + +std::array random_unit_vector(std::mt19937 &rng) { + std::normal_distribution dist(0.0f, 1.0f); + std::array v{dist(rng), dist(rng), dist(rng)}; + const float norm = std::sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]); + if (norm < 1e-6f) { + return {1.0f, 0.0f, 0.0f}; + } + v[0] /= norm; + v[1] /= norm; + v[2] /= norm; + return v; +} + +std::vector> generate_streamline_points(std::mt19937 &rng, LengthProfile profile) { + const float length_mm = sample_length_mm(rng, profile); + const size_t point_count = std::max(2, static_cast(std::ceil(length_mm / kStepMm)) + 1); + std::vector> points; + points.reserve(point_count); + + std::uniform_real_distribution dist_x(kRandomMinMm, kRandomMaxMm); + std::uniform_real_distribution dist_y(kRandomMinMm, kRandomMaxMm); + std::uniform_real_distribution dist_z(kRandomMinMm, kRandomMaxMm); + + for (size_t i = 0; i < point_count; ++i) { + points.push_back({dist_x(rng), dist_y(rng), dist_z(rng)}); + } + + return points; +} + +std::vector> generate_streamline_points_seeded(uint32_t seed, LengthProfile profile) { + std::mt19937 rng(seed); + return generate_streamline_points(rng, profile); +} + +size_t bench_threads() { + const size_t requested = parse_env_size("TRX_BENCH_THREADS", 0); + if (requested > 0) { + return requested; + } + const unsigned int hc = std::thread::hardware_concurrency(); + return hc == 0 ? 1U : static_cast(hc); +} + +size_t bench_batch_size() { + return parse_env_size("TRX_BENCH_BATCH", 1000); +} + +template +void generate_streamlines_parallel(size_t streamlines, + LengthProfile profile, + size_t threads, + size_t batch_size, + uint32_t base_seed, + BatchConsumer consumer) { + const size_t total_batches = (streamlines + batch_size - 1) / batch_size; + std::atomic next_batch{0}; + std::mutex mutex; + std::condition_variable cv; + std::map>>> completed; + std::condition_variable cv_producer; + size_t inflight_batches = 0; + const size_t max_inflight = std::max(1, parse_env_size("TRX_BENCH_QUEUE_MAX", 8)); + + auto worker = [&]() { + for (;;) { + size_t batch_idx; + { + // Wait for queue space BEFORE grabbing batch index to avoid missed notifications + std::unique_lock lock(mutex); + cv_producer.wait(lock, [&]() { return inflight_batches < max_inflight || next_batch.load() >= total_batches; }); + batch_idx = next_batch.fetch_add(1); + if (batch_idx >= total_batches) { + return; + } + ++inflight_batches; + } + const size_t start = batch_idx * batch_size; + const size_t count = std::min(batch_size, streamlines - start); + std::vector>> batch; + batch.reserve(count); + for (size_t i = 0; i < count; ++i) { + const uint32_t seed = base_seed + static_cast(start + i); + batch.push_back(generate_streamline_points_seeded(seed, profile)); + } + { + std::lock_guard lock(mutex); + completed.emplace(batch_idx, std::move(batch)); + } + cv.notify_one(); + } + }; + + std::vector workers; + workers.reserve(threads); + for (size_t t = 0; t < threads; ++t) { + workers.emplace_back(worker); + } + + for (size_t expected = 0; expected < total_batches; ++expected) { + std::unique_lock lock(mutex); + cv.wait(lock, [&]() { return completed.find(expected) != completed.end(); }); + auto batch = std::move(completed[expected]); + completed.erase(expected); + if (inflight_batches > 0) { + --inflight_batches; + } + lock.unlock(); + cv_producer.notify_all(); // Wake all waiting workers, not just one, to avoid deadlock + + const size_t start = expected * batch_size; + consumer(start, batch); + } + + for (auto &worker_thread : workers) { + worker_thread.join(); + } +} + +struct TrxWriteStats { + double write_ms = 0.0; + double file_size_bytes = 0.0; +}; + +struct RssSample { + double elapsed_ms = 0.0; + double rss_kb = 0.0; + std::string phase; +}; + +struct FileSizeScenario { + size_t streamlines = 0; + LengthProfile profile = LengthProfile::Mixed; + bool add_dps = false; + bool add_dpv = false; + zip_uint32_t compression = ZIP_CM_STORE; +}; + +std::mutex g_rss_samples_mutex; + +void append_rss_samples(const FileSizeScenario &scenario, const std::vector &samples) { + if (samples.empty()) { + return; + } + const char *path = std::getenv("TRX_RSS_SAMPLES_PATH"); + if (!path || path[0] == '\0') { + return; + } + std::lock_guard lock(g_rss_samples_mutex); + std::ofstream out(path, std::ios::app); + if (!out.is_open()) { + return; + } + + out << "{" + << "\"streamlines\":" << scenario.streamlines << "," + << "\"length_profile\":" << static_cast(scenario.profile) << "," + << "\"dps\":" << (scenario.add_dps ? 1 : 0) << "," + << "\"dpv\":" << (scenario.add_dpv ? 1 : 0) << "," + << "\"compression\":" << (scenario.compression == ZIP_CM_DEFLATE ? 1 : 0) << "," + << "\"samples\":["; + for (size_t i = 0; i < samples.size(); ++i) { + if (i > 0) { + out << ","; + } + out << "{" + << "\"elapsed_ms\":" << samples[i].elapsed_ms << "," + << "\"rss_kb\":" << samples[i].rss_kb << "," + << "\"phase\":\"" << samples[i].phase << "\"" + << "}"; + } + out << "]}\n"; +} + +std::mutex g_cleanup_mutex; +std::vector g_cleanup_paths; +pid_t g_cleanup_owner_pid = 0; +bool g_cleanup_only_on_success = true; +bool g_run_success = false; + +void cleanup_temp_paths() { + if (g_cleanup_only_on_success && !g_run_success) { + return; + } + if (g_cleanup_owner_pid != 0 && getpid() != g_cleanup_owner_pid) { + return; + } + std::error_code ec; + for (const auto &p : g_cleanup_paths) { + std::filesystem::remove_all(p, ec); + } +} + +void register_cleanup(const std::string &path) { + static bool registered = false; + { + std::lock_guard lock(g_cleanup_mutex); + if (g_cleanup_owner_pid == 0) { + g_cleanup_owner_pid = getpid(); + } + g_cleanup_paths.push_back(path); + } + if (!registered) { + registered = true; + std::atexit(cleanup_temp_paths); + } +} + +TrxWriteStats run_trx_file_size(size_t streamlines, + LengthProfile profile, + bool add_dps, + bool add_dpv, + zip_uint32_t compression) { + trx::TrxStream stream("float16"); + stream.set_metadata_mode(trx::TrxStream::MetadataMode::OnDisk); + stream.set_metadata_buffer_max_bytes(64ULL * 1024ULL * 1024ULL); + stream.set_positions_buffer_max_bytes(buffer_bytes_for_streamlines(streamlines)); + + const size_t threads = bench_threads(); + const size_t batch_size = std::max(1, bench_batch_size()); + const uint32_t base_seed = static_cast(1337 + streamlines + static_cast(profile) * 13); + const size_t progress_every = parse_env_size("TRX_BENCH_LOG_PROGRESS_EVERY", 0); + log_bench_config("file_size_generate", threads, batch_size); + + const bool collect_rss = std::getenv("TRX_RSS_SAMPLES_PATH") != nullptr; + const size_t sample_every = parse_env_size("TRX_RSS_SAMPLE_EVERY", 50000); + const int sample_interval_ms = parse_env_int("TRX_RSS_SAMPLE_MS", 500); + std::vector samples; + std::mutex samples_mutex; + const auto bench_start = std::chrono::steady_clock::now(); + auto record_sample = [&](const std::string &phase) { + if (!collect_rss) { + return; + } + const auto now = std::chrono::steady_clock::now(); + const std::chrono::duration elapsed = now - bench_start; + std::lock_guard lock(samples_mutex); + samples.push_back({elapsed.count(), get_max_rss_kb(), phase}); + }; + + std::vector dps; + std::vector dpv; + if (add_dps) { + dps.reserve(streamlines); + } + if (add_dpv) { + const size_t estimated_vertices = streamlines * estimate_points_per_streamline(profile); + dpv.reserve(estimated_vertices); + } + + generate_streamlines_parallel( + streamlines, + profile, + threads, + batch_size, + base_seed, + [&](size_t start, const std::vector>> &batch) { + if (parse_env_bool("TRX_BENCH_LOG", false)) { + std::cerr << "[trx-bench] batch file_size start=" << start << " count=" << batch.size() << std::endl; + } + for (size_t i = 0; i < batch.size(); ++i) { + const auto &points = batch[i]; + stream.push_streamline(points); + if (add_dps) { + dps.push_back(1.0f); + } + if (add_dpv) { + dpv.insert(dpv.end(), points.size(), 0.5f); + } + const size_t global_idx = start + i + 1; + if (progress_every > 0 && (global_idx % progress_every == 0)) { + std::cerr << "[trx-bench] progress file_size streamlines=" << global_idx << " / " << streamlines + << std::endl; + } + if (collect_rss && sample_every > 0 && (global_idx % sample_every == 0)) { + record_sample("generate"); + } + } + }); + + if (add_dps) { + stream.push_dps_from_vector("dps_scalar", "float32", dps); + } + if (add_dpv) { + stream.push_dpv_from_vector("dpv_scalar", "float32", dpv); + } + + const std::string out_path = make_temp_path("trx_size"); + record_sample("before_finalize"); + + std::atomic sampling{false}; + std::thread sampler; + if (collect_rss) { + sampling.store(true, std::memory_order_relaxed); + sampler = std::thread([&]() { + while (sampling.load(std::memory_order_relaxed)) { + record_sample("finalize"); + std::this_thread::sleep_for(std::chrono::milliseconds(sample_interval_ms)); + } + }); + } + + const auto start = std::chrono::steady_clock::now(); + stream.finalize(out_path, compression); + const auto end = std::chrono::steady_clock::now(); + + if (collect_rss) { + sampling.store(false, std::memory_order_relaxed); + if (sampler.joinable()) { + sampler.join(); + } + } + record_sample("after_finalize"); + + TrxWriteStats stats; + stats.write_ms = std::chrono::duration(end - start).count(); + std::error_code size_ec; + const auto size = std::filesystem::file_size(out_path, size_ec); + stats.file_size_bytes = size_ec ? 0.0 : static_cast(size); + std::error_code ec; + std::filesystem::remove(out_path, ec); + + if (collect_rss) { + FileSizeScenario scenario; + scenario.streamlines = streamlines; + scenario.profile = profile; + scenario.add_dps = add_dps; + scenario.add_dpv = add_dpv; + scenario.compression = compression; + append_rss_samples(scenario, samples); + } + return stats; +} + +struct TrxOnDisk { + std::string path; + size_t streamlines = 0; + size_t vertices = 0; + double shard_merge_ms = 0.0; + size_t shard_processes = 1; +}; + +TrxOnDisk build_trx_file_on_disk_single(size_t streamlines, + GroupScenario scenario, + bool add_dps, + bool add_dpv, + LengthProfile profile, + zip_uint32_t compression, + const std::string &out_path_override = "", + bool finalize_to_directory = false) { + trx::TrxStream stream("float16"); + stream.set_metadata_mode(trx::TrxStream::MetadataMode::OnDisk); + stream.set_metadata_buffer_max_bytes(64ULL * 1024ULL * 1024ULL); + stream.set_positions_buffer_max_bytes(buffer_bytes_for_streamlines(streamlines)); + const size_t progress_every = parse_env_size("TRX_BENCH_LOG_PROGRESS_EVERY", 0); + + const auto group_count = group_count_for(scenario); + const auto &group_names = group_names_for(scenario); + std::vector> groups(group_count); + + const size_t threads = bench_threads(); + const size_t batch_size = std::max(1, bench_batch_size()); + const uint32_t base_seed = static_cast(1337 + streamlines + static_cast(scenario) * 31); + log_bench_config("build_trx_generate", threads, batch_size); + + std::vector dps; + std::vector dpv; + if (add_dps) { + dps.reserve(streamlines); + } + if (add_dpv) { + const size_t estimated_vertices = streamlines * estimate_points_per_streamline(profile); + dpv.reserve(estimated_vertices); + } + + size_t total_vertices = 0; + generate_streamlines_parallel( + streamlines, + profile, + threads, + batch_size, + base_seed, + [&](size_t start, const std::vector>> &batch) { + if (parse_env_bool("TRX_BENCH_LOG", false)) { + std::cerr << "[trx-bench] batch build_trx start=" << start << " count=" << batch.size() << std::endl; + } + for (size_t i = 0; i < batch.size(); ++i) { + const auto &points = batch[i]; + total_vertices += points.size(); + stream.push_streamline(points); + if (add_dps) { + dps.push_back(1.0f); + } + if (add_dpv) { + dpv.insert(dpv.end(), points.size(), 0.5f); + } + const size_t global_idx = start + i; + if (group_count > 0) { + groups[global_idx % group_count].push_back(static_cast(global_idx)); + } + if (progress_every > 0 && ((global_idx + 1) % progress_every == 0)) { + if (parse_env_bool("TRX_BENCH_CHILD_LOG", false) || parse_env_bool("TRX_BENCH_LOG", false)) { + const char *shard_env = std::getenv("TRX_BENCH_SHARD_INDEX"); + const std::string shard_prefix = shard_env ? std::string(" shard=") + shard_env : ""; + std::cerr << "[trx-bench] progress build_trx" << shard_prefix << " streamlines=" << (global_idx + 1) + << " / " << streamlines << std::endl; + } + } + } + }); + + if (add_dps) { + stream.push_dps_from_vector("dps_scalar", "float32", dps); + } + if (add_dpv) { + stream.push_dpv_from_vector("dpv_scalar", "float32", dpv); + } + if (group_count > 0) { + for (size_t g = 0; g < group_count; ++g) { + stream.push_group_from_indices(group_names[g], groups[g]); + } + } + + const std::string out_path = out_path_override.empty() ? make_temp_path("trx_input") : out_path_override; + if (finalize_to_directory) { + // Use persistent variant to avoid removing pre-created shard directories + stream.finalize_directory_persistent(out_path); + } else { + stream.finalize(out_path, compression); + } + if (out_path_override.empty() && !finalize_to_directory) { + register_cleanup(out_path); + } + return {out_path, streamlines, total_vertices, 0.0, 1}; +} + +void build_trx_shard(const std::string &out_path, + size_t streamlines, + GroupScenario scenario, + bool add_dps, + bool add_dpv, + LengthProfile profile, + zip_uint32_t compression) { + (void)build_trx_file_on_disk_single(streamlines, + scenario, + add_dps, + add_dpv, + profile, + compression, + out_path, + true); + + // Defensive validation: ensure all required files were written by finalize_directory_persistent + std::error_code ec; + const auto header_path = trx::fs::path(out_path) / "header.json"; + if (!trx::fs::exists(header_path, ec)) { + throw std::runtime_error("Shard missing header.json after finalize_directory_persistent: " + header_path.string()); + } + const auto positions_path = find_file_by_prefix(out_path, "positions."); + if (positions_path.empty()) { + throw std::runtime_error("Shard missing positions after finalize_directory_persistent: " + out_path); + } + const auto offsets_path = find_file_by_prefix(out_path, "offsets."); + if (offsets_path.empty()) { + throw std::runtime_error("Shard missing offsets after finalize_directory_persistent: " + out_path); + } + const auto ok_path = trx::fs::path(out_path) / "SHARD_OK"; + std::ofstream ok(ok_path, std::ios::out | std::ios::trunc); + if (ok.is_open()) { + ok << "ok\n"; + ok.flush(); + ok.close(); + } + + // Force filesystem sync to ensure all shard data is visible to parent process +#if defined(__unix__) || defined(__APPLE__) + sync(); + // Brief sleep to ensure filesystem metadata updates are visible across processes + std::this_thread::sleep_for(std::chrono::milliseconds(50)); +#endif +} + +TrxOnDisk build_trx_file_on_disk(size_t streamlines, + GroupScenario scenario, + bool add_dps, + bool add_dpv, + LengthProfile profile, + zip_uint32_t compression) { + size_t processes = parse_env_size("TRX_BENCH_PROCESSES", 1); + const size_t mp_min_streamlines = parse_env_size("TRX_BENCH_MP_MIN_STREAMLINES", 1000000); + if (streamlines < mp_min_streamlines) { + processes = 1; + } + if (processes <= 1) { + return build_trx_file_on_disk_single(streamlines, scenario, add_dps, add_dpv, profile, compression); + } +#if defined(__unix__) || defined(__APPLE__) + g_cleanup_owner_pid = getpid(); + const std::string shard_root = make_work_dir_name("trx_shards"); + { + std::error_code ec; + trx::fs::create_directories(shard_root, ec); + if (ec) { + throw std::runtime_error("Failed to create shard root: " + shard_root); + } + } + { + const std::string marker = shard_root + trx::SEPARATOR + "SHARD_ROOT_CREATED"; + std::ofstream out(marker, std::ios::out | std::ios::trunc); + if (out.is_open()) { + out << "ok\n"; + out.flush(); + out.close(); + } + } + if (parse_env_bool("TRX_BENCH_LOG", false)) { + std::cerr << "[trx-bench] shard_root " << shard_root << std::endl; + } + std::vector counts(processes, streamlines / processes); + const size_t remainder = streamlines % processes; + for (size_t i = 0; i < remainder; ++i) { + counts[i] += 1; + } + + std::vector shard_paths(processes); + std::vector status_paths(processes); + for (size_t i = 0; i < processes; ++i) { + shard_paths[i] = shard_root + trx::SEPARATOR + "shard_" + std::to_string(i); + status_paths[i] = shard_root + trx::SEPARATOR + "shard_" + std::to_string(i) + ".status"; + + // Pre-create shard directories to validate filesystem writability before forking. + // finalize_directory_persistent() will use these existing directories without + // removing them, avoiding race conditions in the multiprocess workflow. + std::error_code ec; + trx::fs::create_directories(shard_paths[i], ec); + if (ec) { + throw std::runtime_error("Failed to create shard dir: " + shard_paths[i] + " " + ec.message()); + } + std::ofstream status(status_paths[i], std::ios::out | std::ios::trunc); + if (status.is_open()) { + status << "pending\n"; + } + } + if (parse_env_bool("TRX_BENCH_LOG", false)) { + for (size_t i = 0; i < processes; ++i) { + std::cerr << "[trx-bench] shard_path[" << i << "] " << shard_paths[i] << std::endl; + } + } + + std::vector pids; + pids.reserve(processes); + for (size_t i = 0; i < processes; ++i) { + const pid_t pid = fork(); + if (pid == 0) { + try { + setenv("TRX_BENCH_THREADS", "1", 1); + setenv("TRX_BENCH_BATCH", "1000", 1); + setenv("TRX_BENCH_LOG", "0", 1); + setenv("TRX_BENCH_SHARD_INDEX", std::to_string(i).c_str(), 1); + if (parse_env_bool("TRX_BENCH_LOG", false)) { + std::cerr << "[trx-bench] shard_child_start path=" << shard_paths[i] << std::endl; + } + { + std::ofstream status(status_paths[i], std::ios::out | std::ios::trunc); + if (status.is_open()) { + status << "started pid=" << getpid() << "\n"; + status.flush(); + } + } + build_trx_shard(shard_paths[i], counts[i], scenario, add_dps, add_dpv, profile, compression); + { + std::ofstream status(status_paths[i], std::ios::out | std::ios::trunc); + if (status.is_open()) { + status << "ok\n"; + status.flush(); + } + } + _exit(0); + } catch (const std::exception &ex) { + std::ofstream out(status_paths[i], std::ios::out | std::ios::trunc); + if (out.is_open()) { + out << ex.what() << "\n"; + out.flush(); + out.close(); + } + _exit(1); + } catch (...) { + std::ofstream out(status_paths[i], std::ios::out | std::ios::trunc); + if (out.is_open()) { + out << "Unknown error\n"; + out.flush(); + out.close(); + } + _exit(1); + } + } + if (pid < 0) { + throw std::runtime_error("Failed to fork shard process"); + } + pids.push_back(pid); + } + + for (size_t i = 0; i < pids.size(); ++i) { + const auto pid = pids[i]; + int status = 0; + waitpid(pid, &status, 0); + if (!WIFEXITED(status) || WEXITSTATUS(status) != 0) { + std::string detail; + std::ifstream in(status_paths[i]); + if (in.is_open()) { + std::getline(in, detail); + } + if (detail.empty()) { + detail = "No status file content"; + } + throw std::runtime_error("Shard process failed: " + detail); + } + } + + const size_t shard_wait_ms = parse_env_size("TRX_BENCH_SHARD_WAIT_MS", 10000); + wait_for_shard_ok(shard_paths, status_paths, shard_wait_ms); + + size_t total_vertices = 0; + size_t total_streamlines = 0; + std::vector shard_vertices(processes, 0); + std::vector shard_streamlines(processes, 0); + for (size_t i = 0; i < processes; ++i) { + const auto ok_path = trx::fs::path(shard_paths[i]) / "SHARD_OK"; + std::error_code ok_ec; + if (!trx::fs::exists(ok_path, ok_ec)) { + std::string detail; + std::ifstream in(status_paths[i]); + if (in.is_open()) { + std::getline(in, detail); + } + if (detail.empty()) { + detail = "SHARD_OK missing for " + shard_paths[i]; + } + throw std::runtime_error("Shard process failed: " + detail); + } + std::error_code ec; + if (!trx::fs::exists(shard_paths[i], ec) || !trx::fs::is_directory(shard_paths[i], ec)) { + const auto root_files = list_files(shard_root); + std::string detail = "Shard output directory missing: " + shard_paths[i]; + if (!root_files.empty()) { + detail += " root_files=["; + for (size_t j = 0; j < root_files.size(); ++j) { + if (j > 0) { + detail += ","; + } + detail += root_files[j]; + } + detail += "]"; + } + throw std::runtime_error(detail); + } + const auto header_path = trx::fs::path(shard_paths[i]) / "header.json"; + if (!trx::fs::exists(header_path, ec)) { + const auto files = list_files(shard_paths[i]); + std::string detail = "Shard missing header.json: " + header_path.string(); + if (!files.empty()) { + detail += " files=["; + for (size_t j = 0; j < files.size(); ++j) { + if (j > 0) { + detail += ","; + } + detail += files[j]; + } + detail += "]"; + } + const auto root_files = list_files(shard_root); + if (!root_files.empty()) { + detail += " root_files=["; + for (size_t j = 0; j < root_files.size(); ++j) { + if (j > 0) { + detail += ","; + } + detail += root_files[j]; + } + detail += "]"; + } + throw std::runtime_error(detail); + } + const auto counts = read_header_counts(shard_paths[i]); + shard_streamlines[i] = counts.first; + shard_vertices[i] = counts.second; + total_streamlines += counts.first; + total_vertices += counts.second; + } + + const auto merge_start = std::chrono::steady_clock::now(); + const auto group_count = group_count_for(scenario); + const auto &group_names = group_names_for(scenario); + + const std::string merge_dir = make_temp_dir_path("trx_merge"); + const auto shard_positions0 = find_file_by_prefix(shard_paths[0], "positions."); + const auto shard_offsets0 = find_file_by_prefix(shard_paths[0], "offsets."); + if (shard_positions0.empty()) { + throw std::runtime_error("Missing positions file in first shard: " + shard_paths[0]); + } + if (shard_offsets0.empty()) { + throw std::runtime_error("Missing offsets file in first shard: " + shard_paths[0]); + } + const auto positions_filename = trx::fs::path(shard_positions0).filename().string(); + const auto offsets_filename = trx::fs::path(shard_offsets0).filename().string(); + const auto positions_path = trx::fs::path(merge_dir) / positions_filename; + const auto offsets_path = trx::fs::path(merge_dir) / offsets_filename; + + { + std::ofstream out_pos(positions_path, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out_pos.is_open()) { + throw std::runtime_error("Failed to open output positions file: " + positions_path.string()); + } + } + { + std::ofstream out_off(offsets_path, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out_off.is_open()) { + throw std::runtime_error("Failed to open output offsets file: " + offsets_path.string()); + } + } + + std::vector dps_files; + std::vector dpv_files; + if (add_dps) { + dps_files = list_files((trx::fs::path(shard_paths[0]) / "dps").string()); + if (dps_files.empty()) { + throw std::runtime_error("No DPS files found in shard: " + shard_paths[0]); + } + } + if (add_dpv) { + dpv_files = list_files((trx::fs::path(shard_paths[0]) / "dpv").string()); + if (dpv_files.empty()) { + throw std::runtime_error("No DPV files found in shard: " + shard_paths[0]); + } + } + std::vector group_files; + if (group_count > 0) { + group_files = list_files((trx::fs::path(shard_paths[0]) / "groups").string()); + if (group_files.empty()) { + throw std::runtime_error("No group files found in shard: " + shard_paths[0]); + } + } + + if (add_dps) { + trx::fs::create_directories(trx::fs::path(merge_dir) / "dps"); + for (const auto &name : dps_files) { + const auto dst = trx::fs::path(merge_dir) / "dps" / name; + std::ofstream out(dst, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to create DPS file: " + dst.string()); + } + } + } + if (add_dpv) { + trx::fs::create_directories(trx::fs::path(merge_dir) / "dpv"); + for (const auto &name : dpv_files) { + const auto dst = trx::fs::path(merge_dir) / "dpv" / name; + std::ofstream out(dst, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to create DPV file: " + dst.string()); + } + } + } + if (group_count > 0) { + trx::fs::create_directories(trx::fs::path(merge_dir) / "groups"); + for (const auto &name : group_files) { + const auto dst = trx::fs::path(merge_dir) / "groups" / name; + std::ofstream out(dst, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to create group file: " + dst.string()); + } + } + } + + size_t vertex_offset = 0; + size_t streamline_offset = 0; + for (size_t i = 0; i < processes; ++i) { + const auto shard_dir = shard_paths[i]; + const auto shard_positions = find_file_by_prefix(shard_dir, "positions."); + const auto shard_offsets = find_file_by_prefix(shard_dir, "offsets."); + if (shard_positions.empty()) { + throw std::runtime_error("Missing positions file in shard: " + shard_dir); + } + if (shard_offsets.empty()) { + throw std::runtime_error("Missing offsets file in shard: " + shard_dir); + } + + copy_file_append(shard_positions, positions_path.string()); + + { + const bool offsets_u32 = offsets_filename.find("uint32") != std::string::npos; + std::ifstream in(shard_offsets, std::ios::binary); + if (!in.is_open()) { + throw std::runtime_error("Failed to open shard offsets: " + shard_offsets); + } + std::ofstream out(offsets_path, std::ios::binary | std::ios::out | std::ios::app); + if (!out.is_open()) { + throw std::runtime_error("Failed to open output offsets file: " + offsets_path.string()); + } + constexpr size_t kBatch = 1 << 14; + const bool skip_first_value = (i != 0); + bool skipped_first = false; + if (offsets_u32) { + std::vector buffer(kBatch); + while (in) { + in.read(reinterpret_cast(buffer.data()), + static_cast(buffer.size() * sizeof(uint32_t))); + const std::streamsize count = in.gcount(); + if (count <= 0) { + break; + } + const size_t elems = static_cast(count) / sizeof(uint32_t); + size_t start = 0; + if (skip_first_value && !skipped_first) { + start = 1; + skipped_first = true; + } + for (size_t j = start; j < elems; ++j) { + const uint64_t value = static_cast(buffer[j]) + static_cast(vertex_offset); + if (value > std::numeric_limits::max()) { + throw std::runtime_error("Offsets overflow uint32 during merge."); + } + buffer[j] = static_cast(value); + } + if (elems > start) { + out.write(reinterpret_cast(buffer.data() + start), + static_cast((elems - start) * sizeof(uint32_t))); + } + } + } else { + std::vector buffer(kBatch); + while (in) { + in.read(reinterpret_cast(buffer.data()), + static_cast(buffer.size() * sizeof(uint64_t))); + const std::streamsize count = in.gcount(); + if (count <= 0) { + break; + } + const size_t elems = static_cast(count) / sizeof(uint64_t); + size_t start = 0; + if (skip_first_value && !skipped_first) { + start = 1; + skipped_first = true; + } + for (size_t j = start; j < elems; ++j) { + buffer[j] += static_cast(vertex_offset); + } + if (elems > start) { + out.write(reinterpret_cast(buffer.data() + start), + static_cast((elems - start) * sizeof(uint64_t))); + } + } + } + } + + if (add_dps) { + const auto shard_dps = trx::fs::path(shard_dir) / "dps"; + for (const auto &name : dps_files) { + const auto src = shard_dps / name; + const auto dst = trx::fs::path(merge_dir) / "dps" / name; + if (!trx::fs::exists(src)) { + throw std::runtime_error("Missing DPS file in shard: " + src.string()); + } + copy_file_append(src.string(), dst.string()); + } + } + + if (add_dpv) { + const auto shard_dpv = trx::fs::path(shard_dir) / "dpv"; + for (const auto &name : dpv_files) { + const auto src = shard_dpv / name; + const auto dst = trx::fs::path(merge_dir) / "dpv" / name; + if (!trx::fs::exists(src)) { + throw std::runtime_error("Missing DPV file in shard: " + src.string()); + } + copy_file_append(src.string(), dst.string()); + } + } + + if (group_count > 0) { + const auto shard_groups = trx::fs::path(shard_dir) / "groups"; + for (const auto &name : group_files) { + const auto src = shard_groups / name; + const auto dst = trx::fs::path(merge_dir) / "groups" / name; + if (!trx::fs::exists(src)) { + throw std::runtime_error("Missing group file in shard: " + src.string()); + } + std::ifstream in(src, std::ios::binary); + if (!in.is_open()) { + throw std::runtime_error("Failed to open shard group: " + src.string()); + } + std::ofstream out(dst, std::ios::binary | std::ios::out | std::ios::app); + if (!out.is_open()) { + throw std::runtime_error("Failed to open output group file: " + dst.string()); + } + constexpr size_t kBatch = 1 << 14; + std::vector buffer(kBatch); + while (in) { + in.read(reinterpret_cast(buffer.data()), + static_cast(buffer.size() * sizeof(uint32_t))); + const std::streamsize count = in.gcount(); + if (count <= 0) { + break; + } + const size_t elems = static_cast(count) / sizeof(uint32_t); + for (size_t j = 0; j < elems; ++j) { + buffer[j] += static_cast(streamline_offset); + } + out.write(reinterpret_cast(buffer.data()), + static_cast(elems * sizeof(uint32_t))); + } + } + } + + vertex_offset += shard_vertices[i]; + streamline_offset += shard_streamlines[i]; + } + + // Read header before cleanup to avoid accessing deleted files + const json header_json = read_header_json(shard_paths[0]); + json::object header_obj = header_json.object_items(); + header_obj["NB_VERTICES"] = json(static_cast(total_vertices)); + header_obj["NB_STREAMLINES"] = json(static_cast(total_streamlines)); + const json header = header_obj; + { + const auto header_path = trx::fs::path(merge_dir) / "header.json"; + std::ofstream out(header_path, std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to write header.json: " + header_path.string()); + } + out << header.dump(); + } + + const std::string zip_path = make_temp_path("trx_input"); + int errorp; + zip_t *zf = zip_open(zip_path.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp); + if (zf == nullptr) { + throw std::runtime_error("Could not open archive " + zip_path + ": " + strerror(errorp)); + } + const std::string header_payload = header.dump() + "\n"; + zip_source_t *header_source = + zip_source_buffer(zf, header_payload.data(), header_payload.size(), 0 /* do not free */); + if (header_source == nullptr) { + zip_close(zf); + throw std::runtime_error("Failed to create zip source for header.json: " + std::string(zip_strerror(zf))); + } + const zip_int64_t header_idx = zip_file_add(zf, "header.json", header_source, ZIP_FL_ENC_UTF_8 | ZIP_FL_OVERWRITE); + if (header_idx < 0) { + zip_source_free(header_source); + zip_close(zf); + throw std::runtime_error("Failed to add header.json to archive: " + std::string(zip_strerror(zf))); + } + const zip_int32_t compression_mode = static_cast(compression); + if (zip_set_file_compression(zf, header_idx, compression_mode, 0) < 0) { + zip_close(zf); + throw std::runtime_error("Failed to set compression for header.json: " + std::string(zip_strerror(zf))); + } + const std::unordered_set skip = {"header.json"}; + trx::zip_from_folder(zf, merge_dir, merge_dir, compression, &skip); + if (zip_close(zf) != 0) { + throw std::runtime_error("Unable to close archive " + zip_path + ": " + zip_strerror(zf)); + } + trx::fs::remove_all(merge_dir); + const std::string out_path = zip_path; + + register_cleanup(out_path); + const auto merge_end = std::chrono::steady_clock::now(); + const std::chrono::duration merge_elapsed = merge_end - merge_start; + + // Final cleanup of shard directories after merge is complete + if (!parse_env_bool("TRX_BENCH_KEEP_SHARDS", false)) { + std::error_code ec; + trx::fs::remove_all(shard_root, ec); + } + return {out_path, streamlines, total_vertices, merge_elapsed.count(), processes}; +#else + (void)processes; + return build_trx_file_on_disk_single(streamlines, scenario, add_dps, add_dpv, profile, compression); +#endif +} + +struct QueryDataset { + std::unique_ptr> trx; + std::vector> aabbs; + std::vector> slab_mins; + std::vector> slab_maxs; +}; + +void build_slabs(std::vector> &mins, std::vector> &maxs) { + mins.clear(); + maxs.clear(); + mins.reserve(kSlabCount); + maxs.reserve(kSlabCount); + const float z_range = kFov.max_z - kFov.min_z; + for (size_t i = 0; i < kSlabCount; ++i) { + const float t = (kSlabCount == 1) ? 0.5f : static_cast(i) / static_cast(kSlabCount - 1); + const float center_z = kFov.min_z + t * z_range; + const float min_z = std::max(kFov.min_z, center_z - kSlabThicknessMm * 0.5f); + const float max_z = std::min(kFov.max_z, center_z + kSlabThicknessMm * 0.5f); + mins.push_back({kFov.min_x, kFov.min_y, min_z}); + maxs.push_back({kFov.max_x, kFov.max_y, max_z}); + } +} + +struct ScenarioParams { + size_t streamlines = 0; + GroupScenario scenario = GroupScenario::None; + bool add_dps = false; + bool add_dpv = false; + LengthProfile profile = LengthProfile::Mixed; +}; + +struct KeyHash { + using Key = std::tuple; + size_t operator()(const Key &key) const { + size_t h = 0; + auto hash_combine = [&](size_t v) { + h ^= v + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2); + }; + hash_combine(std::hash{}(std::get<0>(key))); + hash_combine(std::hash{}(std::get<1>(key))); + hash_combine(std::hash{}(std::get<2>(key))); + hash_combine(std::hash{}(std::get<3>(key))); + return h; + } +}; + +void maybe_write_query_timings(const ScenarioParams &scenario, const std::vector &timings_ms) { + static std::mutex mutex; + static std::unordered_set seen; + const KeyHash::Key key{scenario.streamlines, + static_cast(scenario.scenario), + scenario.add_dps ? 1 : 0, + scenario.add_dpv ? 1 : 0}; + + std::lock_guard lock(mutex); + if (!seen.insert(key).second) { + return; + } + + const char *env_path = std::getenv("TRX_QUERY_TIMINGS_PATH"); + const std::filesystem::path out_path = env_path ? env_path : "bench/query_timings.jsonl"; + std::error_code ec; + if (!out_path.parent_path().empty()) { + std::filesystem::create_directories(out_path.parent_path(), ec); + } + std::ofstream out(out_path, std::ios::app); + if (!out.is_open()) { + return; + } + + out << "{" + << "\"streamlines\":" << scenario.streamlines << "," + << "\"group_case\":" << static_cast(scenario.scenario) << "," + << "\"group_count\":" << group_count_for(scenario.scenario) << "," + << "\"dps\":" << (scenario.add_dps ? 1 : 0) << "," + << "\"dpv\":" << (scenario.add_dpv ? 1 : 0) << "," + << "\"slab_thickness_mm\":" << kSlabThicknessMm << "," + << "\"timings_ms\":["; + for (size_t i = 0; i < timings_ms.size(); ++i) { + if (i > 0) { + out << ","; + } + out << timings_ms[i]; + } + out << "]}\n"; +} +} // namespace + +static void BM_TrxFileSize_Float16(benchmark::State &state) { + const size_t streamlines = static_cast(state.range(0)); + const auto profile = static_cast(state.range(1)); + const bool add_dps = state.range(2) != 0; + const bool add_dpv = state.range(3) != 0; + const bool use_zip = state.range(4) != 0; + const auto compression = use_zip ? ZIP_CM_DEFLATE : ZIP_CM_STORE; + const size_t skip_zip_at = parse_env_size("TRX_BENCH_SKIP_ZIP_AT", 5000000); + if (use_zip && streamlines >= skip_zip_at) { + state.SkipWithMessage("zip compression skipped for large streamlines"); + return; + } + log_bench_start("BM_TrxFileSize_Float16", + "streamlines=" + std::to_string(streamlines) + " profile=" + std::to_string(state.range(1)) + + " dps=" + std::to_string(static_cast(add_dps)) + + " dpv=" + std::to_string(static_cast(add_dpv)) + + " compression=" + std::to_string(static_cast(use_zip))); + + double total_write_ms = 0.0; + double total_file_bytes = 0.0; + double total_merge_ms = 0.0; + double total_build_ms = 0.0; + double total_merge_processes = 0.0; + for (auto _ : state) { + const auto start = std::chrono::steady_clock::now(); + const auto on_disk = + build_trx_file_on_disk(streamlines, GroupScenario::None, add_dps, add_dpv, profile, compression); + const auto end = std::chrono::steady_clock::now(); + const std::chrono::duration elapsed = end - start; + total_build_ms += elapsed.count(); + total_merge_ms += on_disk.shard_merge_ms; + total_merge_processes += static_cast(on_disk.shard_processes); + total_write_ms += elapsed.count(); + total_file_bytes += static_cast(file_size_bytes(on_disk.path)); + } + + state.counters["streamlines"] = static_cast(streamlines); + state.counters["length_profile"] = static_cast(state.range(1)); + state.counters["dps"] = add_dps ? 1.0 : 0.0; + state.counters["dpv"] = add_dpv ? 1.0 : 0.0; + state.counters["compression"] = use_zip ? 1.0 : 0.0; + state.counters["positions_dtype"] = 16.0; + state.counters["write_ms"] = total_write_ms / static_cast(state.iterations()); + state.counters["build_ms"] = total_build_ms / static_cast(state.iterations()); + if (total_merge_ms > 0.0) { + state.counters["shard_merge_ms"] = total_merge_ms / static_cast(state.iterations()); + state.counters["shard_processes"] = total_merge_processes / static_cast(state.iterations()); + } + state.counters["file_bytes"] = total_file_bytes / static_cast(state.iterations()); + state.counters["max_rss_kb"] = get_max_rss_kb(); + + log_bench_end("BM_TrxFileSize_Float16", + "streamlines=" + std::to_string(streamlines) + " profile=" + std::to_string(state.range(1))); +} + +static void BM_TrxStream_TranslateWrite(benchmark::State &state) { + const size_t streamlines = static_cast(state.range(0)); + const auto scenario = static_cast(state.range(1)); + const bool add_dps = state.range(2) != 0; + const bool add_dpv = state.range(3) != 0; + const size_t progress_every = parse_env_size("TRX_BENCH_LOG_PROGRESS_EVERY", 0); + log_bench_config("translate_write", bench_threads(), std::max(1, bench_batch_size())); + log_bench_start("BM_TrxStream_TranslateWrite", + "streamlines=" + std::to_string(streamlines) + " group_case=" + std::to_string(state.range(1)) + + " dps=" + std::to_string(static_cast(add_dps)) + + " dpv=" + std::to_string(static_cast(add_dpv))); + + using Key = KeyHash::Key; + static std::unordered_map cache; + + const Key key{streamlines, static_cast(scenario), add_dps ? 1 : 0, add_dpv ? 1 : 0}; + if (cache.find(key) == cache.end()) { + state.PauseTiming(); + cache.emplace(key, + build_trx_file_on_disk(streamlines, scenario, add_dps, add_dpv, LengthProfile::Mixed, ZIP_CM_STORE)); + state.ResumeTiming(); + } + + const auto &dataset = cache.at(key); + if (dataset.shard_processes > 1 && dataset.shard_merge_ms > 0.0) { + state.counters["shard_merge_ms"] = dataset.shard_merge_ms; + state.counters["shard_processes"] = static_cast(dataset.shard_processes); + } + for (auto _ : state) { + const auto start = std::chrono::steady_clock::now(); + auto trx = trx::load_any(dataset.path); + const size_t chunk_bytes = parse_env_size("TRX_BENCH_CHUNK_BYTES", 1024ULL * 1024ULL * 1024ULL); + const std::string out_dir = make_work_dir_name("trx_translate_chunk"); + const auto out_info = trx::prepare_positions_output(trx, out_dir); + + std::ofstream out_positions(out_info.positions_path, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out_positions.is_open()) { + throw std::runtime_error("Failed to open output positions file: " + out_info.positions_path); + } + + trx.for_each_positions_chunk(chunk_bytes, + [&](trx::TrxScalarType dtype, const void *data, size_t offset, size_t count) { + (void)offset; + if (progress_every > 0 && ((offset + count) % progress_every == 0)) { + std::cerr << "[trx-bench] progress translate points=" << (offset + count) + << " / " << out_info.points << std::endl; + } + const size_t total_vals = count * 3; + if (dtype == trx::TrxScalarType::Float16) { + const auto *src = reinterpret_cast(data); + std::vector tmp(total_vals); + for (size_t i = 0; i < total_vals; ++i) { + tmp[i] = static_cast(static_cast(src[i]) + 1.0f); + } + out_positions.write(reinterpret_cast(tmp.data()), + static_cast(tmp.size() * sizeof(Eigen::half))); + } else if (dtype == trx::TrxScalarType::Float32) { + const auto *src = reinterpret_cast(data); + std::vector tmp(total_vals); + for (size_t i = 0; i < total_vals; ++i) { + tmp[i] = src[i] + 1.0f; + } + out_positions.write(reinterpret_cast(tmp.data()), + static_cast(tmp.size() * sizeof(float))); + } else { + const auto *src = reinterpret_cast(data); + std::vector tmp(total_vals); + for (size_t i = 0; i < total_vals; ++i) { + tmp[i] = src[i] + 1.0; + } + out_positions.write(reinterpret_cast(tmp.data()), + static_cast(tmp.size() * sizeof(double))); + } + }); + out_positions.flush(); + out_positions.close(); + + const std::string out_path = make_temp_path("trx_translate"); + int errorp; + zip_t *zf = zip_open(out_path.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp); + if (zf == nullptr) { + trx::rm_dir(out_dir); + throw std::runtime_error("Could not open archive " + out_path + ": " + strerror(errorp)); + } + trx::zip_from_folder(zf, out_dir, out_dir, ZIP_CM_STORE, nullptr); + if (zip_close(zf) != 0) { + trx::rm_dir(out_dir); + throw std::runtime_error("Unable to close archive " + out_path + ": " + zip_strerror(zf)); + } + trx::rm_dir(out_dir); + const auto end = std::chrono::steady_clock::now(); + const std::chrono::duration elapsed = end - start; + state.SetIterationTime(elapsed.count()); + + std::error_code ec; + std::filesystem::remove(out_path, ec); + benchmark::DoNotOptimize(trx); + } + + state.counters["streamlines"] = static_cast(streamlines); + state.counters["group_case"] = static_cast(state.range(1)); + state.counters["group_count"] = static_cast(group_count_for(scenario)); + state.counters["dps"] = add_dps ? 1.0 : 0.0; + state.counters["dpv"] = add_dpv ? 1.0 : 0.0; + state.counters["length_profile"] = static_cast(static_cast(LengthProfile::Mixed)); + state.counters["positions_dtype"] = 16.0; + state.counters["max_rss_kb"] = get_max_rss_kb(); + + log_bench_end("BM_TrxStream_TranslateWrite", + "streamlines=" + std::to_string(streamlines) + " group_case=" + std::to_string(state.range(1))); +} + +static void BM_TrxQueryAabb_Slabs(benchmark::State &state) { + const size_t streamlines = static_cast(state.range(0)); + const auto scenario = static_cast(state.range(1)); + const bool add_dps = state.range(2) != 0; + const bool add_dpv = state.range(3) != 0; + log_bench_start("BM_TrxQueryAabb_Slabs", + "streamlines=" + std::to_string(streamlines) + " group_case=" + std::to_string(state.range(1)) + + " dps=" + std::to_string(static_cast(add_dps)) + + " dpv=" + std::to_string(static_cast(add_dpv))); + + using Key = KeyHash::Key; + static std::unordered_map cache; + + const Key key{streamlines, static_cast(scenario), add_dps ? 1 : 0, add_dpv ? 1 : 0}; + if (cache.find(key) == cache.end()) { + state.PauseTiming(); + QueryDataset dataset; + auto on_disk = build_trx_file_on_disk(streamlines, scenario, add_dps, add_dpv, LengthProfile::Mixed, ZIP_CM_STORE); + dataset.trx = trx::load(on_disk.path); + dataset.aabbs = dataset.trx->build_streamline_aabbs(); + build_slabs(dataset.slab_mins, dataset.slab_maxs); + cache.emplace(key, std::move(dataset)); + state.ResumeTiming(); + } + + auto &dataset = cache.at(key); + for (auto _ : state) { + std::vector slab_times_ms; + slab_times_ms.reserve(kSlabCount); + + const auto start = std::chrono::steady_clock::now(); + size_t total = 0; + for (size_t i = 0; i < kSlabCount; ++i) { + const auto &min_corner = dataset.slab_mins[i]; + const auto &max_corner = dataset.slab_maxs[i]; + const auto q_start = std::chrono::steady_clock::now(); + auto subset = dataset.trx->query_aabb(min_corner, max_corner, &dataset.aabbs); + const auto q_end = std::chrono::steady_clock::now(); + const std::chrono::duration q_elapsed = q_end - q_start; + slab_times_ms.push_back(q_elapsed.count()); + total += subset->num_streamlines(); + subset->close(); + } + const auto end = std::chrono::steady_clock::now(); + const std::chrono::duration elapsed = end - start; + state.SetIterationTime(elapsed.count()); + benchmark::DoNotOptimize(total); + + auto sorted = slab_times_ms; + std::sort(sorted.begin(), sorted.end()); + const auto p50 = sorted[sorted.size() / 2]; + const auto p95_idx = static_cast(std::ceil(0.95 * sorted.size())) - 1; + const auto p95 = sorted[std::min(p95_idx, sorted.size() - 1)]; + state.counters["query_p50_ms"] = p50; + state.counters["query_p95_ms"] = p95; + + ScenarioParams params; + params.streamlines = streamlines; + params.scenario = scenario; + params.add_dps = add_dps; + params.add_dpv = add_dpv; + params.profile = LengthProfile::Mixed; + maybe_write_query_timings(params, slab_times_ms); + } + + state.counters["streamlines"] = static_cast(streamlines); + state.counters["group_case"] = static_cast(state.range(1)); + state.counters["group_count"] = static_cast(group_count_for(scenario)); + state.counters["dps"] = add_dps ? 1.0 : 0.0; + state.counters["dpv"] = add_dpv ? 1.0 : 0.0; + state.counters["query_count"] = static_cast(kSlabCount); + state.counters["slab_thickness_mm"] = kSlabThicknessMm; + state.counters["positions_dtype"] = 16.0; + state.counters["max_rss_kb"] = get_max_rss_kb(); + + log_bench_end("BM_TrxQueryAabb_Slabs", + "streamlines=" + std::to_string(streamlines) + " group_case=" + std::to_string(state.range(1))); +} + +static void ApplySizeArgs(benchmark::internal::Benchmark *bench) { + const std::array profiles = {static_cast(LengthProfile::Short), + static_cast(LengthProfile::Medium), + static_cast(LengthProfile::Long)}; + const std::array flags = {0, 1}; + const auto counts_desc = streamlines_for_benchmarks(); + for (const auto count : counts_desc) { + for (const auto profile : profiles) { + for (const auto dps : flags) { + for (const auto dpv : flags) { + for (const auto compression : flags) { + bench->Args({static_cast(count), profile, dps, dpv, compression}); + } + } + } + } + } +} + +static void ApplyStreamArgs(benchmark::internal::Benchmark *bench) { + const std::array groups = {static_cast(GroupScenario::None), + static_cast(GroupScenario::Bundles), + static_cast(GroupScenario::Connectome)}; + const std::array flags = {0, 1}; + const auto counts_desc = streamlines_for_benchmarks(); + for (const auto count : counts_desc) { + for (const auto group_case : groups) { + for (const auto dps : flags) { + for (const auto dpv : flags) { + bench->Args({static_cast(count), group_case, dps, dpv}); + } + } + } + } +} + +static void ApplyQueryArgs(benchmark::internal::Benchmark *bench) { + const std::array groups = {static_cast(GroupScenario::None), + static_cast(GroupScenario::Bundles), + static_cast(GroupScenario::Connectome)}; + const std::array flags = {0, 1}; + const auto counts_desc = streamlines_for_benchmarks(); + for (const auto count : counts_desc) { + for (const auto group_case : groups) { + for (const auto dps : flags) { + for (const auto dpv : flags) { + bench->Args({static_cast(count), group_case, dps, dpv}); + } + } + } + } + bench->Iterations(1); +} + +BENCHMARK(BM_TrxFileSize_Float16) + ->Apply(ApplySizeArgs) + ->Unit(benchmark::kMillisecond); + +BENCHMARK(BM_TrxStream_TranslateWrite) + ->Apply(ApplyStreamArgs) + ->UseManualTime() + ->Unit(benchmark::kMillisecond); + +BENCHMARK(BM_TrxQueryAabb_Slabs) + ->Apply(ApplyQueryArgs) + ->UseManualTime() + ->Unit(benchmark::kMillisecond); + +int main(int argc, char **argv) { + ::benchmark::Initialize(&argc, argv); + if (::benchmark::ReportUnrecognizedArguments(argc, argv)) { + return 1; + } + try { + ::benchmark::RunSpecifiedBenchmarks(); + g_run_success = true; + } catch (const std::exception &ex) { + std::cerr << "Benchmark failed: " << ex.what() << std::endl; + return 1; + } catch (...) { + std::cerr << "Benchmark failed with unknown exception." << std::endl; + return 1; + } + return 0; +} diff --git a/bench/plot_bench.py b/bench/plot_bench.py new file mode 100644 index 0000000..6b0aa3d --- /dev/null +++ b/bench/plot_bench.py @@ -0,0 +1,214 @@ +import argparse +import json +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +LENGTH_LABELS = { + 0: "mixed", + 1: "short (20-120mm)", + 2: "medium (80-260mm)", + 3: "long (200-500mm)", +} +GROUP_LABELS = { + 0: "no groups", + 1: "bundle groups (80)", + 2: "connectome groups (4950)", +} +COMPRESSION_LABELS = {0: "store (no zip)", 1: "zip deflate"} + + +def _parse_base_name(name: str) -> str: + return name.split("/")[0] + + +def _time_to_ms(bench: dict) -> float: + value = bench.get("real_time", 0.0) + unit = bench.get("time_unit", "ns") + if unit == "ns": + return value / 1e6 + if unit == "us": + return value / 1e3 + if unit == "ms": + return value + if unit == "s": + return value * 1e3 + return value / 1e6 + + +def load_benchmarks(path: Path) -> pd.DataFrame: + with path.open() as f: + data = json.load(f) + + rows = [] + for bench in data.get("benchmarks", []): + name = bench.get("name", "") + if not name.startswith("BM_"): + continue + rows.append( + { + "name": name, + "base": _parse_base_name(name), + "real_time_ms": _time_to_ms(bench), + "streamlines": bench.get("streamlines"), + "length_profile": bench.get("length_profile"), + "compression": bench.get("compression"), + "group_case": bench.get("group_case"), + "group_count": bench.get("group_count"), + "dps": bench.get("dps"), + "dpv": bench.get("dpv"), + "write_ms": bench.get("write_ms"), + "file_bytes": bench.get("file_bytes"), + "max_rss_kb": bench.get("max_rss_kb"), + "query_p50_ms": bench.get("query_p50_ms"), + "query_p95_ms": bench.get("query_p95_ms"), + } + ) + + return pd.DataFrame(rows) + + +def plot_file_sizes(df: pd.DataFrame, output_dir: Path) -> None: + sub = df[df["base"] == "BM_TrxFileSize_Float16"].copy() + if sub.empty: + return + sub["file_mb"] = sub["file_bytes"] / 1e6 + sub["length_label"] = sub["length_profile"].map(LENGTH_LABELS) + sub["dp_label"] = "dpv=" + sub["dpv"].astype(int).astype(str) + ", dps=" + sub["dps"].astype(int).astype(str) + + fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True) + for compression, ax in zip([0, 1], axes): + scomp = sub[sub["compression"] == compression] + for (length_label, dp_label), series in scomp.groupby(["length_label", "dp_label"]): + series = series.sort_values("streamlines") + ax.plot( + series["streamlines"], + series["file_mb"], + marker="o", + label=f"{length_label}, {dp_label}", + ) + ax.set_title(COMPRESSION_LABELS.get(compression, str(compression))) + ax.set_xlabel("streamlines") + ax.grid(True) + ax.legend(loc="best", fontsize="x-small") + + axes[0].set_ylabel("file size (MB)") + fig.suptitle("TRX file size vs streamlines (float16 positions)") + output_dir.mkdir(parents=True, exist_ok=True) + fig.savefig(output_dir / "trx_size_vs_streamlines.png", dpi=160, bbox_inches="tight") + plt.close(fig) + + +def _plot_translate_series(df: pd.DataFrame, output_dir: Path, metric: str, ylabel: str, filename: str) -> None: + sub = df[df["base"] == "BM_TrxStream_TranslateWrite"].copy() + if sub.empty: + return + sub["group_label"] = sub["group_case"].map(GROUP_LABELS) + sub["dp_label"] = "dpv=" + sub["dpv"].astype(int).astype(str) + ", dps=" + sub["dps"].astype(int).astype(str) + + fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharey=True) + for ax, (group_label, gsub) in zip(axes, sub.groupby("group_label")): + for dp_label, series in gsub.groupby("dp_label"): + series = series.sort_values("streamlines") + ax.plot(series["streamlines"], series[metric], marker="o", label=dp_label) + ax.set_title(group_label) + ax.set_xlabel("streamlines") + ax.grid(True) + ax.legend(loc="best", fontsize="x-small") + axes[0].set_ylabel(ylabel) + fig.suptitle("Translate + stream write throughput") + output_dir.mkdir(parents=True, exist_ok=True) + fig.savefig(output_dir / filename, dpi=160, bbox_inches="tight") + plt.close(fig) + + +def plot_translate_write(df: pd.DataFrame, output_dir: Path) -> None: + sub = df[df["base"] == "BM_TrxStream_TranslateWrite"].copy() + if sub.empty: + return + sub["rss_mb"] = sub["max_rss_kb"] / 1024.0 + _plot_translate_series( + sub, + output_dir, + metric="real_time_ms", + ylabel="time (ms)", + filename="trx_translate_write_time.png", + ) + _plot_translate_series( + sub, + output_dir, + metric="rss_mb", + ylabel="max RSS (MB)", + filename="trx_translate_write_rss.png", + ) + + +def load_query_timings(path: Path) -> list[dict]: + if not path.exists(): + return [] + rows = [] + with path.open() as f: + for line in f: + line = line.strip() + if not line: + continue + rows.append(json.loads(line)) + return rows + + +def plot_query_timings(path: Path, output_dir: Path, group_case: int, dpv: int, dps: int) -> None: + rows = load_query_timings(path) + if not rows: + return + rows = [ + r + for r in rows + if r.get("group_case") == group_case and r.get("dpv") == dpv and r.get("dps") == dps + ] + if not rows: + return + rows.sort(key=lambda r: r["streamlines"]) + data = [r["timings_ms"] for r in rows] + labels = [str(r["streamlines"]) for r in rows] + + fig, ax = plt.subplots(figsize=(8, 4)) + ax.boxplot(data, labels=labels, showfliers=False) + ax.set_title( + f"Slab query timings ({GROUP_LABELS.get(group_case, group_case)}, dpv={dpv}, dps={dps})" + ) + ax.set_xlabel("streamlines") + ax.set_ylabel("per-slab query time (ms)") + ax.grid(True, axis="y") + output_dir.mkdir(parents=True, exist_ok=True) + fig.savefig(output_dir / "trx_query_slab_timings.png", dpi=160, bbox_inches="tight") + plt.close(fig) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Plot trx-cpp benchmark results.") + parser.add_argument("bench_json", type=Path, help="Path to benchmark JSON output.") + parser.add_argument("--query-json", type=Path, help="Path to slab timing JSONL file.") + parser.add_argument( + "--out-dir", + type=Path, + default=Path("docs/_static/benchmarks"), + help="Directory to save PNGs.", + ) + parser.add_argument("--group-case", type=int, default=0, help="Group case filter for query plot.") + parser.add_argument("--dpv", type=int, default=0, help="DPV filter for query plot.") + parser.add_argument("--dps", type=int, default=0, help="DPS filter for query plot.") + args = parser.parse_args() + + df = load_benchmarks(args.bench_json) + if df.empty: + raise SystemExit("No benchmarks found in JSON file.") + + plot_file_sizes(df, args.out_dir) + plot_translate_write(df, args.out_dir) + if args.query_json: + plot_query_timings(args.query_json, args.out_dir, args.group_case, args.dpv, args.dps) + + +if __name__ == "__main__": + main() diff --git a/docs/_static/benchmarks/.gitkeep b/docs/_static/benchmarks/.gitkeep new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/docs/_static/benchmarks/.gitkeep @@ -0,0 +1 @@ + diff --git a/docs/benchmarks.rst b/docs/benchmarks.rst new file mode 100644 index 0000000..3c5f468 --- /dev/null +++ b/docs/benchmarks.rst @@ -0,0 +1,118 @@ +Benchmarks +========== + +This page documents the benchmarking suite and how to interpret the results. +The benchmarks are designed for realistic tractography workloads (HPC scale), +not for CI. They focus on file size, throughput, and interactive spatial queries. + +Data model +---------- + +All benchmarks synthesize smooth, slightly curved streamlines in a realistic +field of view: + +- **Lengths:** random between 20 and 500 mm (profiles skew short/medium/long) +- **Field of view:** x = [-70, 70], y = [-108, 79], z = [-60, 75] (mm, RAS+) +- **Streamline counts:** 100k, 500k, 1M, 5M, 10M +- **Groups:** none, 80 bundle groups, or 4950 connectome groups (100 regions) +- **DPV/DPS:** either present (1 value) or absent + +Positions are stored as float16 to highlight storage efficiency. + +TRX size vs streamline count +---------------------------- + +This benchmark writes TRX files with float16 positions and measures the final +on-disk size for different streamline counts. It compares short/medium/long +length profiles, DPV/DPS presence, and zip compression (store vs deflate). + +.. figure:: _static/benchmarks/trx_size_vs_streamlines.png + :alt: TRX file size vs streamlines + :align: center + + File size (MB) as a function of streamline count. + +Translate + stream write throughput +----------------------------------- + +This benchmark loads a TRX file, iterates through every streamline, translates +each point by +1 mm in x/y/z, and streams the result into a new TRX file. It +reports total wall time and max RSS so researchers can understand throughput +and memory pressure on both clusters and laptops. + +.. figure:: _static/benchmarks/trx_translate_write_time.png + :alt: Translate + stream write time + :align: center + + End-to-end time for translating and rewriting streamlines. + +.. figure:: _static/benchmarks/trx_translate_write_rss.png + :alt: Translate + stream write RSS + :align: center + + Max RSS during translate + stream write. + +Spatial slab query latency +-------------------------- + +This benchmark precomputes per-streamline AABBs and then issues 100 spatial +queries using 5 mm slabs that sweep through the tractogram volume. Each slab +query mimics a GUI slice update and records its timing so distributions can be +visualized. + +.. figure:: _static/benchmarks/trx_query_slab_timings.png + :alt: Slab query timings + :align: center + + Distribution of per-slab query latency. + +Running the benchmarks +---------------------- + +Build and run the benchmarks, then plot results with matplotlib: + +.. code-block:: bash + + cmake -S . -B build -DTRX_BUILD_BENCHMARKS=ON + cmake --build build --target bench_trx_stream + + # Run benchmarks (this can be long for large datasets). + ./build/bench/bench_trx_stream \ + --benchmark_out=bench/results.json \ + --benchmark_out_format=json + + # Capture per-slab timings for query distributions. + TRX_QUERY_TIMINGS_PATH=bench/query_timings.jsonl \ + ./build/bench/bench_trx_stream \ + --benchmark_filter=BM_TrxQueryAabb_Slabs \ + --benchmark_out=bench/results.json \ + --benchmark_out_format=json + + # Optional: record RSS samples for file-size runs. + TRX_RSS_SAMPLES_PATH=bench/rss_samples.jsonl \ + TRX_RSS_SAMPLE_EVERY=50000 \ + TRX_RSS_SAMPLE_MS=500 \ + ./build/bench/bench_trx_stream \ + --benchmark_filter=BM_TrxFileSize_Float16 \ + --benchmark_out=bench/results.json \ + --benchmark_out_format=json + + # Generate plots into docs/_static/benchmarks. + python bench/plot_bench.py bench/results.json \ + --query-json bench/query_timings.jsonl \ + --out-dir docs/_static/benchmarks + +The query plot defaults to the "no groups, no DPV/DPS" case. Use +``--group-case``, ``--dpv``, and ``--dps`` in ``plot_bench.py`` to select other +scenarios. + +If zip compression is too slow or unstable for the largest datasets, set +``TRX_BENCH_SKIP_ZIP_AT`` (default 5000000) to skip compression for large +streamline counts. + +When running with multiprocessing, the benchmark uses +``finalize_directory_persistent()`` to write shard outputs without removing +pre-created directories, avoiding race conditions in the parallel workflow. You +can keep shard outputs for debugging by setting ``TRX_BENCH_KEEP_SHARDS=1``. The +merge step waits for each shard to finish (via ``SHARD_OK`` files); adjust the +timeout with ``TRX_BENCH_SHARD_WAIT_MS`` if needed. diff --git a/docs/index.rst b/docs/index.rst index 58f0d3f..229a769 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -16,6 +16,7 @@ tractography file format. building usage downstream_usage + benchmarks linting .. toctree:: diff --git a/docs/usage.rst b/docs/usage.rst index 16860ee..04b2619 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -51,6 +51,174 @@ Write a TRX file trx.save("tracks_copy.trx", ZIP_CM_STORE); trx.close(); +Thread-safe streaming pattern +----------------------------- + +``TrxStream`` is **not** thread-safe for concurrent writes. A common pattern for +multi-core streamline generation is to use worker threads for generation and a +single writer thread (or the main thread) to append to ``TrxStream``. + +.. code-block:: cpp + + #include + #include + #include + #include + #include + + struct Batch { + std::vector>> streamlines; + }; + + std::mutex mutex; + std::condition_variable cv; + std::queue queue; + bool done = false; + + // Worker threads: generate streamlines and push batches. + auto producer = [&]() { + Batch batch; + batch.streamlines.reserve(1000); + for (int i = 0; i < 1000; ++i) { + std::vector> points = {/* ... generate ... */}; + batch.streamlines.push_back(std::move(points)); + } + { + std::lock_guard lock(mutex); + queue.push(std::move(batch)); + } + cv.notify_one(); + }; + + // Writer thread (single): pop batches and push into TrxStream. + trx::TrxStream stream("float16"); + auto consumer = [&]() { + for (;;) { + std::unique_lock lock(mutex); + cv.wait(lock, [&]() { return done || !queue.empty(); }); + if (queue.empty() && done) { + return; + } + Batch batch = std::move(queue.front()); + queue.pop(); + lock.unlock(); + + for (const auto &points : batch.streamlines) { + stream.push_streamline(points); + } + } + }; + + std::thread writer(consumer); + std::thread t1(producer); + std::thread t2(producer); + t1.join(); + t2.join(); + { + std::lock_guard lock(mutex); + done = true; + } + cv.notify_all(); + writer.join(); + + stream.finalize("tracks.trx", ZIP_CM_STORE); + +Process-based sharding and merge +-------------------------------- + +For large tractograms it is common to generate streamlines in separate +processes, write shard outputs, and merge them later. ``TrxStream`` provides two +finalization methods for directory output: + +- ``finalize_directory()`` — Single-process variant that removes any existing + directory before writing. Use when you control the entire lifecycle. + +- ``finalize_directory_persistent()`` — Multiprocess-safe variant that does NOT + remove existing directories. Use when coordinating parallel writes where a + parent process may pre-create output directories. + +Recommended multiprocess pattern: + +1. **Parent** pre-creates shard directories to validate filesystem writability. +2. Each **child process** writes a directory shard using + ``finalize_directory_persistent()``. +3. After finalization completes, child writes a sentinel file (e.g., ``SHARD_OK``) + to signal completion. +4. **Parent** waits for all ``SHARD_OK`` markers before merging shards. + +This pattern avoids race conditions where the parent checks for directory +existence while children are still writing. + +.. code-block:: cpp + + // Parent process: pre-create shard directories + for (size_t i = 0; i < num_shards; ++i) { + const std::string shard_path = "shards/shard_" + std::to_string(i); + std::filesystem::create_directories(shard_path); + } + + // Fork child processes... + +.. code-block:: cpp + + // Child process: write to pre-created directory + trx::TrxStream stream("float16"); + // ... push streamlines, dpv, dps, groups ... + stream.finalize_directory_persistent("/path/to/shards/shard_0"); + + // Signal completion to parent + std::ofstream ok("/path/to/shards/shard_0/SHARD_OK"); + ok << "ok\n"; + ok.close(); + +.. code-block:: cpp + + // Parent process (after waiting for all SHARD_OK markers) + // Merge by concatenating positions/DPV/DPS, adjusting offsets/groups. + // See bench/bench_trx_stream.cpp for a reference merge implementation. + +.. note:: + Use ``finalize_directory()`` for single-process writes where you want to + ensure a clean output state. Use ``finalize_directory_persistent()`` for + multiprocess workflows to avoid removing directories that may be checked + for existence by other processes. + +MRtrix-style write kernel (single-writer) +----------------------------------------- + +MRtrix uses a multi-threaded producer stage and a single-writer kernel to +serialize streamlines to disk. The same pattern works for TRX by letting the +writer own the ``TrxStream`` and accepting batches from the thread queue. + +.. code-block:: cpp + + #include + #include + #include + + struct TrxWriteKernel { + explicit TrxWriteKernel(const std::string &path) + : stream("float16"), out_path(path) {} + + void operator()(const std::vector>> &batch) { + for (const auto &points : batch) { + stream.push_streamline(points); + } + } + + void finalize() { + stream.finalize(out_path, ZIP_CM_STORE); + } + + private: + trx::TrxStream stream; + std::string out_path; + }; + +This kernel can be used as the final stage of a producer pipeline. The key rule +is: **only the writer thread touches ``TrxStream``**, while worker threads only +generate streamlines. + Optional NIfTI header support ----------------------------- diff --git a/include/trx/trx.h b/include/trx/trx.h index 03ab146..0bc1504 100644 --- a/include/trx/trx.h +++ b/include/trx/trx.h @@ -9,18 +9,24 @@ #include #include #include +#include +#include #include #include #include +#include #include #include #include #include #include #include +#include #include +#include #include #include +#include #include #include @@ -110,55 +116,68 @@ inline zip_t *open_zip_for_read(const std::string &path, int &errorp) { } template struct DTypeName { - static constexpr std::string_view value() { return "float16"; } + static constexpr bool supported = false; + static constexpr std::string_view value() { return ""; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "float16"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "float32"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "float64"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "int8"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "int16"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "int32"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "int64"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "uint8"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "uint16"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "uint32"; } }; template <> struct DTypeName { + static constexpr bool supported = true; static constexpr std::string_view value() { return "uint64"; } }; template inline std::string dtype_from_scalar() { - typedef typename std::remove_cv::type>::type CleanT; + using CleanT = std::remove_cv_t>; + static_assert(DTypeName::supported, "Unsupported dtype for TRX scalar."); return std::string(DTypeName::value()); } @@ -236,6 +255,7 @@ template class TrxFile { std::string root = ""); template friend class TrxReader; + template friend std::unique_ptr> load(const std::string &path); /** * @brief Create a deepcopy of the TrxFile @@ -550,6 +570,8 @@ struct TypedArray { } }; +enum class TrxScalarType; + class AnyTrxFile { public: AnyTrxFile() = default; @@ -576,6 +598,14 @@ class AnyTrxFile { void close(); void save(const std::string &filename, zip_uint32_t compression_standard = ZIP_CM_STORE); + using PositionsChunkCallback = + std::function; + using PositionsChunkMutableCallback = + std::function; + + void for_each_positions_chunk(size_t chunk_bytes, const PositionsChunkCallback &fn) const; + void for_each_positions_chunk_mutable(size_t chunk_bytes, const PositionsChunkMutableCallback &fn); + static AnyTrxFile load(const std::string &path); static AnyTrxFile load_from_zip(const std::string &path); static AnyTrxFile load_from_directory(const std::string &path); @@ -628,6 +658,42 @@ class TrxStream { */ void push_streamline(const std::vector> &points); + /** + * @brief Set max in-memory position buffer size (bytes). + * + * When set to a non-zero value, positions are buffered in memory and flushed + * to the temp file once the buffer reaches this size. Useful for reducing + * small I/O writes on slow disks. + */ + void set_positions_buffer_max_bytes(std::size_t max_bytes); + + enum class MetadataMode { InMemory, OnDisk }; + + /** + * @brief Control how DPS/DPV/groups are stored during streaming. + * + * InMemory keeps metadata in RAM until finalize (default). + * OnDisk writes metadata to temp files and copies them at finalize. + */ + void set_metadata_mode(MetadataMode mode); + + /** + * @brief Set max in-memory buffer size for metadata writes (bytes). + * + * Applies when MetadataMode::OnDisk. Larger buffers reduce write calls. + */ + void set_metadata_buffer_max_bytes(std::size_t max_bytes); + + /** + * @brief Set the VOXEL_TO_RASMM affine matrix in the header. + */ + void set_voxel_to_rasmm(const Eigen::Matrix4f &affine); + + /** + * @brief Set DIMENSIONS in the header. + */ + void set_dimensions(const std::array &dims); + /** * @brief Add per-streamline values (DPS) from an in-memory vector. */ @@ -648,6 +714,54 @@ class TrxStream { */ template void finalize(const std::string &filename, zip_uint32_t compression_standard = ZIP_CM_STORE); + /** + * @brief Finalize and write a TRX directory (no zip). + * + * This method removes any existing directory at the output path before + * writing. Use this for single-process writes or when you control the + * entire output location lifecycle. + * + * @param directory Path where the uncompressed TRX directory will be created. + * + * @throws std::runtime_error if already finalized or if I/O fails. + * + * @see finalize_directory_persistent for multiprocess-safe variant. + */ + void finalize_directory(const std::string &directory); + + /** + * @brief Finalize and write a TRX directory without removing existing files. + * + * This variant is designed for multiprocess workflows where the output + * directory is pre-created by a parent process. Unlike finalize_directory(), + * this method does NOT remove the output directory if it exists, making it + * safe for coordinated parallel writes where multiple processes may check + * for the directory's existence. + * + * @param directory Path where the uncompressed TRX directory will be created. + * If the directory exists, its contents will be overwritten + * but the directory itself will not be removed and recreated. + * + * @throws std::runtime_error if already finalized or if I/O fails. + * + * @note Typical usage pattern: + * @code + * // Parent process creates shard directories + * fs::create_directories("shards/shard_0"); + * + * // Child process writes without removing directory + * trx::TrxStream stream("float16"); + * // ... push streamlines ... + * stream.finalize_directory_persistent("shards/shard_0"); + * std::ofstream("shards/shard_0/SHARD_OK") << "ok\n"; + * + * // Parent waits for SHARD_OK before reading results + * @endcode + * + * @see finalize_directory for single-process variant that ensures clean slate. + */ + void finalize_directory_persistent(const std::string &directory); + size_t num_streamlines() const { return lengths_.size(); } size_t num_vertices() const { return total_vertices_; } @@ -659,13 +773,24 @@ class TrxStream { std::vector values; }; + struct MetadataFile { + std::string relative_path; + std::string absolute_path; + }; + void ensure_positions_stream(); + void flush_positions_buffer(); void cleanup_tmp(); + void ensure_metadata_dir(const std::string &subdir); + void finalize_directory_impl(const std::string &directory, bool remove_existing); std::string positions_dtype_; std::string tmp_dir_; std::string positions_path_; std::ofstream positions_out_; + std::vector positions_buffer_float_; + std::vector positions_buffer_half_; + std::size_t positions_buffer_max_entries_ = 0; std::vector lengths_; size_t total_vertices_ = 0; bool finalized_ = false; @@ -673,6 +798,9 @@ class TrxStream { std::map> groups_; std::map dps_; std::map dpv_; + MetadataMode metadata_mode_ = MetadataMode::InMemory; + std::vector metadata_files_; + std::size_t metadata_buffer_max_bytes_ = 8 * 1024 * 1024; }; /** @@ -740,6 +868,22 @@ inline std::string scalar_type_name(TrxScalarType dtype) { } } +struct PositionsOutputInfo { + std::string directory; + std::string positions_path; + std::string dtype; + size_t points = 0; +}; + +/** + * @brief Prepare an output directory with copied metadata and offsets. + * + * Creates a new TRX directory (no zip) that contains header, offsets, and + * metadata (groups, dps, dpv, dpg), and returns where the positions file + * should be written. + */ +PositionsOutputInfo prepare_positions_output(const AnyTrxFile &input, const std::string &output_directory); + /** * @brief Detect the positions scalar type for a TRX path. * @@ -877,7 +1021,8 @@ void ediff1d(Eigen::Matrix &lengths, void zip_from_folder(zip_t *zf, const std::string &root, const std::string &directory, - zip_uint32_t compression_standard = ZIP_CM_STORE); + zip_uint32_t compression_standard = ZIP_CM_STORE, + const std::unordered_set *skip = nullptr); std::string get_base(const std::string &delimiter, const std::string &str); std::string get_ext(const std::string &str); diff --git a/include/trx/trx.tpp b/include/trx/trx.tpp index f0c8079..07da736 100644 --- a/include/trx/trx.tpp +++ b/include/trx/trx.tpp @@ -169,7 +169,7 @@ std::unique_ptr> _initialize_empty_trx(int nb_streamlines, int nb_ve offsets_dtype = dtype_from_scalar(); lengths_dtype = dtype_from_scalar(); } else { - positions_dtype = dtype_from_scalar(); + positions_dtype = dtype_from_scalar
(); offsets_dtype = dtype_from_scalar(); lengths_dtype = dtype_from_scalar(); } @@ -181,8 +181,7 @@ std::unique_ptr> _initialize_empty_trx(int nb_streamlines, int nb_ve trx->streamlines = std::make_unique>(); trx->streamlines->mmap_pos = trx::_create_memmap(positions_filename, shape, "w+", positions_dtype); - // TODO: find a better way to get the dtype than using all these switch cases. Also refactor - // into function as per specifications, positions can only be floats + // TODO: find a better way to get the dtype than using all these switch cases. if (positions_dtype.compare("float16") == 0) { new (&(trx->streamlines->_data)) Map>( reinterpret_cast(trx->streamlines->mmap_pos.data()), std::get<0>(shape), std::get<1>(shape)); @@ -356,9 +355,12 @@ TrxFile
::_create_trx_from_pointer(json header, long long size = std::get<1>(x->second); if (base.compare("positions") == 0 && (folder.compare("") == 0 || folder.compare(".") == 0)) { - if (size != static_cast(trx->header["NB_VERTICES"].int_value()) * 3 || dim != 3) { - - throw std::invalid_argument("Wrong data size/dimensionality"); + const auto nb_vertices = static_cast(trx->header["NB_VERTICES"].int_value()); + const auto expected = nb_vertices * 3; + if (size != expected || dim != 3) { + throw std::invalid_argument("Wrong data size/dimensionality: size=" + std::to_string(size) + + " expected=" + std::to_string(expected) + " dim=" + std::to_string(dim) + + " filename=" + elem_filename); } std::tuple shape = std::make_tuple(static_cast(trx->header["NB_VERTICES"].int_value()), 3); @@ -380,11 +382,12 @@ TrxFile
::_create_trx_from_pointer(json header, } else if (base.compare("offsets") == 0 && (folder.compare("") == 0 || folder.compare(".") == 0)) { - if (size != static_cast(trx->header["NB_STREAMLINES"].int_value()) + 1 || dim != 1) { - throw std::invalid_argument( - "Wrong offsets size/dimensionality: size=" + std::to_string(size) + - " nb_streamlines=" + std::to_string(static_cast(trx->header["NB_STREAMLINES"].int_value())) + - " dim=" + std::to_string(dim) + " filename=" + elem_filename); + const auto nb_streamlines = static_cast(trx->header["NB_STREAMLINES"].int_value()); + const auto expected = nb_streamlines + 1; + if (size != expected || dim != 1) { + throw std::invalid_argument("Wrong offsets size/dimensionality: size=" + std::to_string(size) + + " expected=" + std::to_string(expected) + " dim=" + std::to_string(dim) + + " filename=" + elem_filename); } const int nb_str = static_cast(trx->header["NB_STREAMLINES"].int_value()); @@ -965,9 +968,42 @@ template std::unique_ptr> TrxFile
::load_from_direc std::string header_name = directory + SEPARATOR + "header.json"; // TODO: add check to verify that it's open - std::ifstream header_file(header_name); + std::ifstream header_file; + for (int attempt = 0; attempt < 5; ++attempt) { + header_file.open(header_name); + if (header_file.is_open()) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } if (!header_file.is_open()) { - throw std::runtime_error("Failed to open header.json at: " + header_name); + std::error_code ec; + const bool exists = trx::fs::exists(directory, ec); + const int open_err = errno; + std::string detail = "Failed to open header.json at: " + header_name; + detail += " exists=" + std::string(exists ? "true" : "false"); + detail += " errno=" + std::to_string(open_err) + " msg=" + std::string(std::strerror(open_err)); + if (exists) { + std::vector files; + for (const auto &entry : trx::fs::directory_iterator(directory, ec)) { + if (ec) { + break; + } + files.push_back(entry.path().filename().string()); + } + if (!files.empty()) { + std::sort(files.begin(), files.end()); + detail += " files=["; + for (size_t i = 0; i < files.size(); ++i) { + if (i > 0) { + detail += ","; + } + detail += files[i]; + } + detail += "]"; + } + } + throw std::runtime_error(detail); } std::string jstream((std::istreambuf_iterator(header_file)), std::istreambuf_iterator()); header_file.close(); @@ -995,6 +1031,10 @@ template std::unique_ptr> TrxFile
::load(const std: return TrxFile
::load_from_zip(path); } +template std::unique_ptr> load(const std::string &path) { + return TrxFile
::load(path); +} + template TrxReader
::TrxReader(const std::string &path) { trx_ = TrxFile
::load(path); } template TrxReader
::TrxReader(TrxReader &&other) noexcept : trx_(std::move(other.trx_)) {} @@ -1067,13 +1107,58 @@ template void TrxFile
::save(const std::string &filename, zip_u } std::string tmp_dir_name = copy_trx->_uncompressed_folder_handle; + if (!tmp_dir_name.empty()) { + const std::string header_path = tmp_dir_name + SEPARATOR + "header.json"; + std::ofstream out_json(header_path, std::ios::out | std::ios::trunc); + if (!out_json.is_open()) { + throw std::runtime_error("Failed to write header.json to: " + header_path); + } + out_json << copy_trx->header.dump() << std::endl; + out_json.close(); + } + if (ext.size() > 0 && (ext == "zip" || ext == "trx")) { + auto sync_unmap_seq = [](auto &seq) { + if (!seq) { + return; + } + std::error_code ec; + seq->mmap_pos.sync(ec); + seq->mmap_pos.unmap(); + seq->mmap_off.sync(ec); + seq->mmap_off.unmap(); + }; + auto sync_unmap_mat = [](auto &mat) { + if (!mat) { + return; + } + std::error_code ec; + mat->mmap.sync(ec); + mat->mmap.unmap(); + }; + + sync_unmap_seq(copy_trx->streamlines); + for (auto &kv : copy_trx->groups) { + sync_unmap_mat(kv.second); + } + for (auto &kv : copy_trx->data_per_streamline) { + sync_unmap_mat(kv.second); + } + for (auto &kv : copy_trx->data_per_vertex) { + sync_unmap_seq(kv.second); + } + for (auto &group_kv : copy_trx->data_per_group) { + for (auto &kv : group_kv.second) { + sync_unmap_mat(kv.second); + } + } + int errorp; zip_t *zf; if ((zf = zip_open(filename.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp)) == nullptr) { throw std::runtime_error("Could not open archive " + filename + ": " + strerror(errorp)); } else { - zip_from_folder(zf, tmp_dir_name, tmp_dir_name, compression_standard); + zip_from_folder(zf, tmp_dir_name, tmp_dir_name, compression_standard, nullptr); if (zip_close(zf) != 0) { throw std::runtime_error("Unable to close archive " + filename + ": " + zip_strerror(zf)); } @@ -1388,8 +1473,8 @@ inline TrxStream::TrxStream(std::string positions_dtype) : positions_dtype_(std: std::transform(positions_dtype_.begin(), positions_dtype_.end(), positions_dtype_.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); - if (positions_dtype_ != "float32") { - throw std::invalid_argument("TrxStream only supports float32 positions for now"); + if (positions_dtype_ != "float32" && positions_dtype_ != "float16") { + throw std::invalid_argument("TrxStream only supports float16/float32 positions for now"); } tmp_dir_ = make_temp_dir("trx_proto"); positions_path_ = tmp_dir_ + SEPARATOR + "positions.tmp"; @@ -1398,6 +1483,20 @@ inline TrxStream::TrxStream(std::string positions_dtype) : positions_dtype_(std: inline TrxStream::~TrxStream() { cleanup_tmp(); } +inline void TrxStream::set_metadata_mode(MetadataMode mode) { + if (finalized_) { + throw std::runtime_error("Cannot adjust metadata mode after finalize"); + } + metadata_mode_ = mode; +} + +inline void TrxStream::set_metadata_buffer_max_bytes(std::size_t max_bytes) { + if (finalized_) { + throw std::runtime_error("Cannot adjust metadata buffer after finalize"); + } + metadata_buffer_max_bytes_ = max_bytes; +} + inline void TrxStream::ensure_positions_stream() { if (!positions_out_.is_open()) { positions_out_.open(positions_path_, std::ios::binary | std::ios::out | std::ios::trunc); @@ -1407,7 +1506,50 @@ inline void TrxStream::ensure_positions_stream() { } } +inline void TrxStream::ensure_metadata_dir(const std::string &subdir) { + if (tmp_dir_.empty()) { + throw std::runtime_error("TrxStream temp directory not initialized"); + } + const std::string dir = tmp_dir_ + SEPARATOR + subdir + SEPARATOR; + std::error_code ec; + trx::fs::create_directories(dir, ec); + if (ec) { + throw std::runtime_error("Could not create directory " + dir); + } +} + +inline void TrxStream::flush_positions_buffer() { + if (positions_dtype_ == "float16") { + if (positions_buffer_half_.empty()) { + return; + } + ensure_positions_stream(); + const size_t byte_count = positions_buffer_half_.size() * sizeof(half); + positions_out_.write(reinterpret_cast(positions_buffer_half_.data()), + static_cast(byte_count)); + if (!positions_out_) { + throw std::runtime_error("Failed to write TrxStream positions buffer"); + } + positions_buffer_half_.clear(); + return; + } + + if (positions_buffer_float_.empty()) { + return; + } + ensure_positions_stream(); + const size_t byte_count = positions_buffer_float_.size() * sizeof(float); + positions_out_.write(reinterpret_cast(positions_buffer_float_.data()), + static_cast(byte_count)); + if (!positions_out_) { + throw std::runtime_error("Failed to write TrxStream positions buffer"); + } + positions_buffer_float_.clear(); +} + inline void TrxStream::cleanup_tmp() { + positions_buffer_float_.clear(); + positions_buffer_half_.clear(); if (positions_out_.is_open()) { positions_out_.close(); } @@ -1425,11 +1567,42 @@ inline void TrxStream::push_streamline(const float *xyz, size_t point_count) { lengths_.push_back(0); return; } - ensure_positions_stream(); - const size_t byte_count = point_count * 3 * sizeof(float); - positions_out_.write(reinterpret_cast(xyz), static_cast(byte_count)); - if (!positions_out_) { - throw std::runtime_error("Failed to write TrxStream positions"); + if (positions_buffer_max_entries_ == 0) { + ensure_positions_stream(); + if (positions_dtype_ == "float16") { + std::vector tmp; + tmp.reserve(point_count * 3); + for (size_t i = 0; i < point_count * 3; ++i) { + tmp.push_back(static_cast(xyz[i])); + } + const size_t byte_count = tmp.size() * sizeof(half); + positions_out_.write(reinterpret_cast(tmp.data()), static_cast(byte_count)); + if (!positions_out_) { + throw std::runtime_error("Failed to write TrxStream positions"); + } + } else { + const size_t byte_count = point_count * 3 * sizeof(float); + positions_out_.write(reinterpret_cast(xyz), static_cast(byte_count)); + if (!positions_out_) { + throw std::runtime_error("Failed to write TrxStream positions"); + } + } + } else { + const size_t floats_count = point_count * 3; + if (positions_dtype_ == "float16") { + positions_buffer_half_.reserve(positions_buffer_half_.size() + floats_count); + for (size_t i = 0; i < floats_count; ++i) { + positions_buffer_half_.push_back(static_cast(xyz[i])); + } + if (positions_buffer_half_.size() >= positions_buffer_max_entries_) { + flush_positions_buffer(); + } + } else { + positions_buffer_float_.insert(positions_buffer_float_.end(), xyz, xyz + floats_count); + if (positions_buffer_float_.size() >= positions_buffer_max_entries_) { + flush_positions_buffer(); + } + } } total_vertices_ += point_count; lengths_.push_back(static_cast(point_count)); @@ -1443,7 +1616,32 @@ inline void TrxStream::push_streamline(const std::vector &xyz_flat) { } inline void TrxStream::push_streamline(const std::vector> &points) { - push_streamline(reinterpret_cast(points.data()), points.size()); + if (points.empty()) { + push_streamline(static_cast(nullptr), 0); + return; + } + std::vector xyz_flat; + xyz_flat.reserve(points.size() * 3); + for (const auto &point : points) { + xyz_flat.push_back(point[0]); + xyz_flat.push_back(point[1]); + xyz_flat.push_back(point[2]); + } + push_streamline(xyz_flat); +} + +inline void TrxStream::set_voxel_to_rasmm(const Eigen::Matrix4f &affine) { + std::vector> matrix(4, std::vector(4, 0.0f)); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + matrix[static_cast(i)][static_cast(j)] = affine(i, j); + } + } + header = _json_set(header, "VOXEL_TO_RASMM", matrix); +} + +inline void TrxStream::set_dimensions(const std::array &dims) { + header = _json_set(header, "DIMENSIONS", std::vector{dims[0], dims[1], dims[2]}); } template @@ -1462,13 +1660,67 @@ TrxStream::push_dps_from_vector(const std::string &name, const std::string &dtyp if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { throw std::invalid_argument("Unsupported DPS dtype: " + dtype); } - FieldValues field; - field.dtype = dtype_norm; - field.values.reserve(values.size()); - for (const auto &v : values) { - field.values.push_back(static_cast(v)); + if (metadata_mode_ == MetadataMode::OnDisk) { + ensure_metadata_dir("dps"); + const std::string filename = tmp_dir_ + SEPARATOR + "dps" + SEPARATOR + name + "." + dtype_norm; + std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open DPS file: " + filename); + } + if (dtype_norm == "float16") { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(half)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(half))); + offset += count; + } + } else if (dtype_norm == "float32") { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(float)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(float))); + offset += count; + } + } else { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(double)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(double))); + offset += count; + } + } + out.close(); + metadata_files_.push_back({std::string("dps") + SEPARATOR + name + "." + dtype_norm, filename}); + } else { + FieldValues field; + field.dtype = dtype_norm; + field.values.reserve(values.size()); + for (const auto &v : values) { + field.values.push_back(static_cast(v)); + } + dps_[name] = std::move(field); } - dps_[name] = std::move(field); } template @@ -1487,20 +1739,113 @@ TrxStream::push_dpv_from_vector(const std::string &name, const std::string &dtyp if (dtype_norm != "float16" && dtype_norm != "float32" && dtype_norm != "float64") { throw std::invalid_argument("Unsupported DPV dtype: " + dtype); } - FieldValues field; - field.dtype = dtype_norm; - field.values.reserve(values.size()); - for (const auto &v : values) { - field.values.push_back(static_cast(v)); + if (metadata_mode_ == MetadataMode::OnDisk) { + ensure_metadata_dir("dpv"); + const std::string filename = tmp_dir_ + SEPARATOR + "dpv" + SEPARATOR + name + "." + dtype_norm; + std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open DPV file: " + filename); + } + if (dtype_norm == "float16") { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(half)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(half))); + offset += count; + } + } else if (dtype_norm == "float32") { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(float)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(float))); + offset += count; + } + } else { + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(double)); + std::vector tmp; + tmp.reserve(chunk_elems); + size_t offset = 0; + while (offset < values.size()) { + const size_t count = std::min(chunk_elems, values.size() - offset); + tmp.clear(); + for (size_t i = 0; i < count; ++i) { + tmp.push_back(static_cast(values[offset + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(count * sizeof(double))); + offset += count; + } + } + out.close(); + metadata_files_.push_back({std::string("dpv") + SEPARATOR + name + "." + dtype_norm, filename}); + } else { + FieldValues field; + field.dtype = dtype_norm; + field.values.reserve(values.size()); + for (const auto &v : values) { + field.values.push_back(static_cast(v)); + } + dpv_[name] = std::move(field); + } +} + +inline void TrxStream::set_positions_buffer_max_bytes(std::size_t max_bytes) { + if (finalized_) { + throw std::runtime_error("Cannot adjust buffer after finalize"); + } + if (max_bytes == 0) { + positions_buffer_max_entries_ = 0; + positions_buffer_float_.clear(); + positions_buffer_half_.clear(); + return; + } + const std::size_t element_size = positions_dtype_ == "float16" ? sizeof(half) : sizeof(float); + const std::size_t entries = max_bytes / element_size; + const std::size_t aligned = (entries / 3) * 3; + positions_buffer_max_entries_ = aligned; + if (positions_buffer_max_entries_ == 0) { + positions_buffer_float_.clear(); + positions_buffer_half_.clear(); } - dpv_[name] = std::move(field); } inline void TrxStream::push_group_from_indices(const std::string &name, const std::vector &indices) { if (name.empty()) { throw std::invalid_argument("Group name cannot be empty"); } - groups_[name] = indices; + if (metadata_mode_ == MetadataMode::OnDisk) { + ensure_metadata_dir("groups"); + const std::string filename = tmp_dir_ + SEPARATOR + "groups" + SEPARATOR + name + ".uint32"; + std::ofstream out(filename, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open group file: " + filename); + } + const size_t chunk_elems = std::max(1, metadata_buffer_max_bytes_ / sizeof(uint32_t)); + size_t offset = 0; + while (offset < indices.size()) { + const size_t count = std::min(chunk_elems, indices.size() - offset); + out.write(reinterpret_cast(indices.data() + offset), + static_cast(count * sizeof(uint32_t))); + offset += count; + } + out.close(); + metadata_files_.push_back({std::string("groups") + SEPARATOR + name + ".uint32", filename}); + } else { + groups_[name] = indices; + } } template void TrxStream::finalize(const std::string &filename, zip_uint32_t compression_standard) { @@ -1509,6 +1854,7 @@ template void TrxStream::finalize(const std::string &filename, zip } finalized_ = true; + flush_positions_buffer(); if (positions_out_.is_open()) { positions_out_.flush(); positions_out_.close(); @@ -1539,14 +1885,25 @@ template void TrxStream::finalize(const std::string &filename, zip throw std::runtime_error("Failed to open TrxStream temp positions file for read: " + positions_path_); } for (size_t i = 0; i < nb_vertices; ++i) { - float xyz[3]; - in.read(reinterpret_cast(xyz), sizeof(xyz)); - if (!in) { - throw std::runtime_error("Failed to read TrxStream positions"); + if (positions_dtype_ == "float16") { + half xyz[3]; + in.read(reinterpret_cast(xyz), sizeof(xyz)); + if (!in) { + throw std::runtime_error("Failed to read TrxStream positions"); + } + positions(static_cast(i), 0) = static_cast
(xyz[0]); + positions(static_cast(i), 1) = static_cast
(xyz[1]); + positions(static_cast(i), 2) = static_cast
(xyz[2]); + } else { + float xyz[3]; + in.read(reinterpret_cast(xyz), sizeof(xyz)); + if (!in) { + throw std::runtime_error("Failed to read TrxStream positions"); + } + positions(static_cast(i), 0) = static_cast
(xyz[0]); + positions(static_cast(i), 1) = static_cast
(xyz[1]); + positions(static_cast(i), 2) = static_cast
(xyz[2]); } - positions(static_cast(i), 0) = static_cast
(xyz[0]); - positions(static_cast(i), 1) = static_cast
(xyz[1]); - positions(static_cast(i), 2) = static_cast
(xyz[2]); } for (const auto &kv : dps_) { @@ -1559,12 +1916,205 @@ template void TrxStream::finalize(const std::string &filename, zip trx.add_group_from_indices(kv.first, kv.second); } + if (metadata_mode_ == MetadataMode::OnDisk) { + for (const auto &meta : metadata_files_) { + const std::string dest = trx._uncompressed_folder_handle + SEPARATOR + meta.relative_path; + const trx::fs::path dest_path(dest); + if (dest_path.has_parent_path()) { + std::error_code parent_ec; + trx::fs::create_directories(dest_path.parent_path(), parent_ec); + } + std::error_code copy_ec; + trx::fs::copy_file(meta.absolute_path, dest, trx::fs::copy_options::overwrite_existing, copy_ec); + if (copy_ec) { + throw std::runtime_error("Failed to copy metadata file: " + meta.absolute_path + " -> " + dest); + } + } + } + trx.save(filename, compression_standard); trx.close(); cleanup_tmp(); } +inline void TrxStream::finalize_directory_impl(const std::string &directory, bool remove_existing) { + if (finalized_) { + throw std::runtime_error("TrxStream already finalized"); + } + finalized_ = true; + + flush_positions_buffer(); + if (positions_out_.is_open()) { + positions_out_.flush(); + positions_out_.close(); + } + + const size_t nb_streamlines = lengths_.size(); + const size_t nb_vertices = total_vertices_; + + std::error_code ec; + if (remove_existing && trx::fs::exists(directory, ec)) { + trx::fs::remove_all(directory, ec); + ec.clear(); + } + + // Create directory if it doesn't exist + if (!trx::fs::exists(directory, ec)) { + trx::fs::create_directories(directory, ec); + if (ec) { + throw std::runtime_error("Failed to create output directory: " + directory); + } + } + ec.clear(); + + json header_out = header; + header_out = _json_set(header_out, "NB_VERTICES", static_cast(nb_vertices)); + header_out = _json_set(header_out, "NB_STREAMLINES", static_cast(nb_streamlines)); + const std::string header_path = directory + SEPARATOR + "header.json"; + std::ofstream out_header(header_path, std::ios::out | std::ios::trunc); + if (!out_header.is_open()) { + throw std::runtime_error("Failed to write header.json to: " + header_path); + } + out_header << header_out.dump() << std::endl; + out_header.close(); + + const std::string positions_name = "positions.3." + positions_dtype_; + const std::string positions_dst = directory + SEPARATOR + positions_name; + trx::fs::rename(positions_path_, positions_dst, ec); + if (ec) { + ec.clear(); + trx::fs::copy_file(positions_path_, positions_dst, trx::fs::copy_options::overwrite_existing, ec); + if (ec) { + throw std::runtime_error("Failed to copy positions file to: " + positions_dst); + } + } + + const std::string offsets_dst = directory + SEPARATOR + "offsets.uint64"; + std::ofstream offsets_out(offsets_dst, std::ios::binary | std::ios::out | std::ios::trunc); + if (!offsets_out.is_open()) { + throw std::runtime_error("Failed to open offsets file for write: " + offsets_dst); + } + uint64_t offset = 0; + offsets_out.write(reinterpret_cast(&offset), sizeof(offset)); + for (const auto length : lengths_) { + offset += static_cast(length); + offsets_out.write(reinterpret_cast(&offset), sizeof(offset)); + } + offsets_out.flush(); + offsets_out.close(); + + auto write_field_values = [&](const std::string &path, const FieldValues &values) { + std::ofstream out(path, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open metadata file: " + path); + } + const size_t count = values.values.size(); + if (values.dtype == "float16") { + const size_t chunk = std::max(1, metadata_buffer_max_bytes_ / sizeof(half)); + std::vector tmp; + tmp.reserve(chunk); + size_t idx = 0; + while (idx < count) { + const size_t n = std::min(chunk, count - idx); + tmp.clear(); + for (size_t i = 0; i < n; ++i) { + tmp.push_back(static_cast(values.values[idx + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(n * sizeof(half))); + idx += n; + } + } else if (values.dtype == "float32") { + const size_t chunk = std::max(1, metadata_buffer_max_bytes_ / sizeof(float)); + std::vector tmp; + tmp.reserve(chunk); + size_t idx = 0; + while (idx < count) { + const size_t n = std::min(chunk, count - idx); + tmp.clear(); + for (size_t i = 0; i < n; ++i) { + tmp.push_back(static_cast(values.values[idx + i])); + } + out.write(reinterpret_cast(tmp.data()), static_cast(n * sizeof(float))); + idx += n; + } + } else if (values.dtype == "float64") { + const size_t chunk = std::max(1, metadata_buffer_max_bytes_ / sizeof(double)); + std::vector tmp; + tmp.reserve(chunk); + size_t idx = 0; + while (idx < count) { + const size_t n = std::min(chunk, count - idx); + tmp.clear(); + for (size_t i = 0; i < n; ++i) { + tmp.push_back(values.values[idx + i]); + } + out.write(reinterpret_cast(tmp.data()), static_cast(n * sizeof(double))); + idx += n; + } + } else { + throw std::runtime_error("Unsupported metadata dtype: " + values.dtype); + } + out.close(); + }; + + if (metadata_mode_ == MetadataMode::OnDisk) { + for (const auto &meta : metadata_files_) { + const std::string dest = directory + SEPARATOR + meta.relative_path; + const trx::fs::path dest_path(dest); + if (dest_path.has_parent_path()) { + std::error_code parent_ec; + trx::fs::create_directories(dest_path.parent_path(), parent_ec); + } + std::error_code copy_ec; + trx::fs::copy_file(meta.absolute_path, dest, trx::fs::copy_options::overwrite_existing, copy_ec); + if (copy_ec) { + throw std::runtime_error("Failed to copy metadata file: " + meta.absolute_path + " -> " + dest); + } + } + } else { + if (!dps_.empty()) { + trx::fs::create_directories(directory + SEPARATOR + "dps", ec); + for (const auto &kv : dps_) { + const std::string path = directory + SEPARATOR + "dps" + SEPARATOR + kv.first + "." + kv.second.dtype; + write_field_values(path, kv.second); + } + } + if (!dpv_.empty()) { + trx::fs::create_directories(directory + SEPARATOR + "dpv", ec); + for (const auto &kv : dpv_) { + const std::string path = directory + SEPARATOR + "dpv" + SEPARATOR + kv.first + "." + kv.second.dtype; + write_field_values(path, kv.second); + } + } + if (!groups_.empty()) { + trx::fs::create_directories(directory + SEPARATOR + "groups", ec); + for (const auto &kv : groups_) { + const std::string path = directory + SEPARATOR + "groups" + SEPARATOR + kv.first + ".uint32"; + std::ofstream out(path, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open group file: " + path); + } + if (!kv.second.empty()) { + out.write(reinterpret_cast(kv.second.data()), + static_cast(kv.second.size() * sizeof(uint32_t))); + } + out.close(); + } + } + } + + cleanup_tmp(); +} + +inline void TrxStream::finalize_directory(const std::string &directory) { + finalize_directory_impl(directory, true); +} + +inline void TrxStream::finalize_directory_persistent(const std::string &directory) { + finalize_directory_impl(directory, false); +} + template void TrxFile
::add_dpv_from_tsf(const std::string &name, const std::string &dtype, const std::string &path) { if (name.empty()) { diff --git a/src/trx.cpp b/src/trx.cpp index 5762b34..d3f81fe 100644 --- a/src/trx.cpp +++ b/src/trx.cpp @@ -16,12 +16,18 @@ #include #include #include +#include #include #include #include #include #include #include +#if defined(_WIN32) || defined(_WIN64) +#include +#else +#include +#endif #include #include @@ -268,9 +274,42 @@ AnyTrxFile AnyTrxFile::load_from_directory(const std::string &path) { } std::string header_name = directory + SEPARATOR + "header.json"; - std::ifstream header_file(header_name); + std::ifstream header_file; + for (int attempt = 0; attempt < 5; ++attempt) { + header_file.open(header_name); + if (header_file.is_open()) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } if (!header_file.is_open()) { - throw std::runtime_error("Failed to open header.json at: " + header_name); + std::error_code ec; + const bool exists = trx::fs::exists(directory, ec); + const int open_err = errno; + std::string detail = "Failed to open header.json at: " + header_name; + detail += " exists=" + std::string(exists ? "true" : "false"); + detail += " errno=" + std::to_string(open_err) + " msg=" + std::string(std::strerror(open_err)); + if (exists) { + std::vector files; + for (const auto &entry : trx::fs::directory_iterator(directory, ec)) { + if (ec) { + break; + } + files.push_back(entry.path().filename().string()); + } + if (!files.empty()) { + std::sort(files.begin(), files.end()); + detail += " files=["; + for (size_t i = 0; i < files.size(); ++i) { + if (i > 0) { + detail += ","; + } + detail += files[i]; + } + detail += "]"; + } + } + throw std::runtime_error(detail); } std::string jstream((std::istreambuf_iterator(header_file)), std::istreambuf_iterator()); header_file.close(); @@ -461,35 +500,41 @@ void AnyTrxFile::save(const std::string &filename, zip_uint32_t compression_stan throw std::runtime_error("TRX file has no backing directory to save from"); } - std::string tmp_dir = make_temp_dir("trx_runtime"); - copy_dir(source_dir, tmp_dir); - - { - const trx::fs::path header_path = trx::fs::path(tmp_dir) / "header.json"; - std::ofstream out_json(header_path); - if (!out_json.is_open()) { - throw std::runtime_error("Failed to write header.json to: " + header_path.string()); - } - out_json << header.dump() << std::endl; - } - if (ext.size() > 0 && (ext == "zip" || ext == "trx")) { int errorp; zip_t *zf; if ((zf = zip_open(filename.c_str(), ZIP_CREATE + ZIP_TRUNCATE, &errorp)) == nullptr) { - rm_dir(tmp_dir); throw std::runtime_error("Could not open archive " + filename + ": " + strerror(errorp)); } - zip_from_folder(zf, tmp_dir, tmp_dir, compression_standard); + + const std::string header_payload = header.dump() + "\n"; + zip_source_t *header_source = + zip_source_buffer(zf, header_payload.data(), header_payload.size(), 0 /* do not free */); + if (header_source == nullptr) { + zip_close(zf); + throw std::runtime_error("Failed to create zip source for header.json: " + std::string(zip_strerror(zf))); + } + const zip_int64_t header_idx = zip_file_add(zf, "header.json", header_source, ZIP_FL_ENC_UTF_8 | ZIP_FL_OVERWRITE); + if (header_idx < 0) { + zip_source_free(header_source); + zip_close(zf); + throw std::runtime_error("Failed to add header.json to archive: " + std::string(zip_strerror(zf))); + } + const zip_int32_t compression = static_cast(compression_standard); + if (zip_set_file_compression(zf, header_idx, compression, 0) < 0) { + zip_close(zf); + throw std::runtime_error("Failed to set compression for header.json: " + std::string(zip_strerror(zf))); + } + + const std::unordered_set skip = {"header.json"}; + zip_from_folder(zf, source_dir, source_dir, compression_standard, &skip); if (zip_close(zf) != 0) { - rm_dir(tmp_dir); throw std::runtime_error("Unable to close archive " + filename + ": " + zip_strerror(zf)); } } else { std::error_code ec; if (trx::fs::exists(filename, ec) && trx::fs::is_directory(filename, ec)) { if (rm_dir(filename) != 0) { - rm_dir(tmp_dir); throw std::runtime_error("Could not remove existing directory " + filename); } } @@ -498,24 +543,29 @@ void AnyTrxFile::save(const std::string &filename, zip_uint32_t compression_stan std::error_code parent_ec; trx::fs::create_directories(dest_path.parent_path(), parent_ec); if (parent_ec) { - rm_dir(tmp_dir); throw std::runtime_error("Could not create output parent directory: " + dest_path.parent_path().string()); } } + std::string tmp_dir = make_temp_dir("trx_runtime"); + copy_dir(source_dir, tmp_dir); + const trx::fs::path tmp_header_path = trx::fs::path(tmp_dir) / "header.json"; + std::ofstream out_json(tmp_header_path); + if (!out_json.is_open()) { + rm_dir(tmp_dir); + throw std::runtime_error("Failed to write header.json to: " + tmp_header_path.string()); + } + out_json << header.dump() << std::endl; copy_dir(tmp_dir, filename); + rm_dir(tmp_dir); ec.clear(); if (!trx::fs::exists(filename, ec) || !trx::fs::is_directory(filename, ec)) { - rm_dir(tmp_dir); throw std::runtime_error("Failed to create output directory: " + filename); } - const trx::fs::path header_path = dest_path / "header.json"; - if (!trx::fs::exists(header_path)) { - rm_dir(tmp_dir); - throw std::runtime_error("Missing header.json in output directory: " + header_path.string()); + const trx::fs::path final_header_path = dest_path / "header.json"; + if (!trx::fs::exists(final_header_path)) { + throw std::runtime_error("Missing header.json in output directory: " + final_header_path.string()); } } - - rm_dir(tmp_dir); } void populate_fps(const string &name, std::map> &files_pointer_size) { @@ -806,8 +856,15 @@ std::string make_temp_dir(const std::string &prefix) { static std::mt19937_64 rng(std::random_device{}()); std::uniform_int_distribution dist; + const uint64_t pid = +#if defined(_WIN32) || defined(_WIN64) + static_cast(_getpid()); +#else + static_cast(getpid()); +#endif for (int attempt = 0; attempt < 100; ++attempt) { - const trx::fs::path candidate = base_path / (prefix + "_" + std::to_string(dist(rng))); + const trx::fs::path candidate = + base_path / (prefix + "_" + std::to_string(pid) + "_" + std::to_string(dist(rng))); ec.clear(); if (trx::fs::create_directory(candidate, ec)) { return candidate.string(); @@ -906,7 +963,8 @@ std::string extract_zip_to_directory(zip_t *zfolder) { void zip_from_folder(zip_t *zf, const std::string &root, const std::string &directory, - zip_uint32_t compression_standard) { + zip_uint32_t compression_standard, + const std::unordered_set *skip) { std::error_code ec; for (trx::fs::recursive_directory_iterator it(directory, ec), end; it != end; it.increment(ec)) { if (ec) { @@ -928,6 +986,10 @@ void zip_from_folder(zip_t *zf, if (source == nullptr) { throw std::runtime_error(std::string("Error adding file ") + zip_fname + ": " + zip_strerror(zf)); } + if (skip && skip->find(zip_fname) != skip->end()) { + zip_source_free(source); + continue; + } const zip_int64_t file_idx = zip_file_add(zf, zip_fname.c_str(), source, ZIP_FL_ENC_UTF_8); if (file_idx < 0) { zip_source_free(source); @@ -949,4 +1011,170 @@ std::string rm_root(const std::string &root, const std::string &path) { } return stripped; } + +namespace { +TrxScalarType scalar_type_from_dtype(const std::string &dtype) { + if (dtype == "float16") { + return TrxScalarType::Float16; + } + if (dtype == "float32") { + return TrxScalarType::Float32; + } + if (dtype == "float64") { + return TrxScalarType::Float64; + } + return TrxScalarType::Float32; +} + +std::string typed_array_filename(const std::string &base, const TypedArray &arr) { + if (arr.cols <= 1) { + return base + "." + arr.dtype; + } + return base + "." + std::to_string(arr.cols) + "." + arr.dtype; +} + +void write_typed_array_file(const std::string &path, const TypedArray &arr) { + const auto bytes = arr.to_bytes(); + std::ofstream out(path, std::ios::binary | std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to open output file: " + path); + } + if (bytes.data && bytes.size > 0) { + out.write(reinterpret_cast(bytes.data), static_cast(bytes.size)); + } + out.flush(); + out.close(); +} +} // namespace + +void AnyTrxFile::for_each_positions_chunk(size_t chunk_bytes, const PositionsChunkCallback &fn) const { + if (positions.empty()) { + throw std::runtime_error("TRX positions are empty."); + } + if (positions.cols != 3) { + throw std::runtime_error("Positions must have 3 columns."); + } + if (!fn) { + return; + } + const size_t elem_size = static_cast(detail::_sizeof_dtype(positions.dtype)); + const size_t bytes_per_point = elem_size * 3; + const size_t total_points = static_cast(positions.rows); + size_t points_per_chunk = 0; + if (chunk_bytes == 0) { + points_per_chunk = total_points; + } else { + points_per_chunk = std::max(1, chunk_bytes / bytes_per_point); + } + const auto bytes = positions.to_bytes(); + const auto *base = bytes.data; + const auto dtype = scalar_type_from_dtype(positions.dtype); + for (size_t offset = 0; offset < total_points; offset += points_per_chunk) { + const size_t count = std::min(points_per_chunk, total_points - offset); + const void *ptr = base + offset * bytes_per_point; + fn(dtype, ptr, offset, count); + } +} + +void AnyTrxFile::for_each_positions_chunk_mutable(size_t chunk_bytes, const PositionsChunkMutableCallback &fn) { + if (positions.empty()) { + throw std::runtime_error("TRX positions are empty."); + } + if (positions.cols != 3) { + throw std::runtime_error("Positions must have 3 columns."); + } + if (!fn) { + return; + } + const size_t elem_size = static_cast(detail::_sizeof_dtype(positions.dtype)); + const size_t bytes_per_point = elem_size * 3; + const size_t total_points = static_cast(positions.rows); + size_t points_per_chunk = 0; + if (chunk_bytes == 0) { + points_per_chunk = total_points; + } else { + points_per_chunk = std::max(1, chunk_bytes / bytes_per_point); + } + const auto bytes = positions.to_bytes_mutable(); + auto *base = bytes.data; + const auto dtype = scalar_type_from_dtype(positions.dtype); + for (size_t offset = 0; offset < total_points; offset += points_per_chunk) { + const size_t count = std::min(points_per_chunk, total_points - offset); + void *ptr = base + offset * bytes_per_point; + fn(dtype, ptr, offset, count); + } +} + +PositionsOutputInfo prepare_positions_output(const AnyTrxFile &input, const std::string &output_directory) { + if (input.positions.empty() || input.offsets.empty()) { + throw std::runtime_error("Input TRX missing positions/offsets."); + } + if (input.positions.cols != 3) { + throw std::runtime_error("Positions must have 3 columns."); + } + + std::error_code ec; + if (trx::fs::exists(output_directory, ec)) { + trx::fs::remove_all(output_directory, ec); + } + ec.clear(); + trx::fs::create_directories(output_directory, ec); + if (ec) { + throw std::runtime_error("Failed to create output directory: " + output_directory); + } + + const std::string header_path = output_directory + SEPARATOR + "header.json"; + { + std::ofstream out(header_path, std::ios::out | std::ios::trunc); + if (!out.is_open()) { + throw std::runtime_error("Failed to write header.json to: " + header_path); + } + out << input.header.dump() << std::endl; + } + + write_typed_array_file(output_directory + SEPARATOR + typed_array_filename("offsets", input.offsets), input.offsets); + + if (!input.groups.empty()) { + const std::string groups_dir = output_directory + SEPARATOR + "groups"; + trx::fs::create_directories(groups_dir, ec); + for (const auto &kv : input.groups) { + write_typed_array_file(groups_dir + SEPARATOR + typed_array_filename(kv.first, kv.second), kv.second); + } + } + + if (!input.data_per_streamline.empty()) { + const std::string dps_dir = output_directory + SEPARATOR + "dps"; + trx::fs::create_directories(dps_dir, ec); + for (const auto &kv : input.data_per_streamline) { + write_typed_array_file(dps_dir + SEPARATOR + typed_array_filename(kv.first, kv.second), kv.second); + } + } + + if (!input.data_per_vertex.empty()) { + const std::string dpv_dir = output_directory + SEPARATOR + "dpv"; + trx::fs::create_directories(dpv_dir, ec); + for (const auto &kv : input.data_per_vertex) { + write_typed_array_file(dpv_dir + SEPARATOR + typed_array_filename(kv.first, kv.second), kv.second); + } + } + + if (!input.data_per_group.empty()) { + const std::string dpg_dir = output_directory + SEPARATOR + "dpg"; + trx::fs::create_directories(dpg_dir, ec); + for (const auto &group_kv : input.data_per_group) { + const std::string group_dir = dpg_dir + SEPARATOR + group_kv.first; + trx::fs::create_directories(group_dir, ec); + for (const auto &kv : group_kv.second) { + write_typed_array_file(group_dir + SEPARATOR + typed_array_filename(kv.first, kv.second), kv.second); + } + } + } + + PositionsOutputInfo info; + info.directory = output_directory; + info.dtype = input.positions.dtype; + info.points = static_cast(input.positions.rows); + info.positions_path = output_directory + SEPARATOR + typed_array_filename("positions", input.positions); + return info; +} }; // namespace trx \ No newline at end of file diff --git a/tests/test_trx_mmap.cpp b/tests/test_trx_mmap.cpp index d4fb5ef..7eb443b 100644 --- a/tests/test_trx_mmap.cpp +++ b/tests/test_trx_mmap.cpp @@ -181,7 +181,7 @@ TestTrxFixture create_fixture() { if (zf == nullptr) { throw std::runtime_error("Failed to create trx zip file"); } - trx::zip_from_folder(zf, trx_dir.string(), trx_dir.string(), ZIP_CM_STORE); + trx::zip_from_folder(zf, trx_dir.string(), trx_dir.string(), ZIP_CM_STORE, nullptr); if (zip_close(zf) != 0) { throw std::runtime_error("Failed to close trx zip file"); } diff --git a/tests/test_trx_trxfile.cpp b/tests/test_trx_trxfile.cpp index 9bb96d9..43edd7b 100644 --- a/tests/test_trx_trxfile.cpp +++ b/tests/test_trx_trxfile.cpp @@ -292,6 +292,45 @@ TEST(TrxFileTpp, TrxStreamFinalize) { fs::remove_all(tmp_dir, ec); } +TEST(TrxFileTpp, QueryAabbCounts) { + constexpr int kStreamlineCount = 1000; + constexpr int kInsideCount = 250; + constexpr int kPointsPerStreamline = 5; + + const int nb_vertices = kStreamlineCount * kPointsPerStreamline; + trx::TrxFile trx(nb_vertices, kStreamlineCount); + + trx.streamlines->_offsets(0, 0) = 0; + for (int i = 0; i < kStreamlineCount; ++i) { + trx.streamlines->_lengths(i) = kPointsPerStreamline; + trx.streamlines->_offsets(i + 1, 0) = (i + 1) * kPointsPerStreamline; + } + + int cursor = 0; + for (int i = 0; i < kStreamlineCount; ++i) { + const bool inside = i < kInsideCount; + for (int p = 0; p < kPointsPerStreamline; ++p, ++cursor) { + if (inside) { + trx.streamlines->_data(cursor, 0) = -0.8f + 0.05f * static_cast(p); + trx.streamlines->_data(cursor, 1) = 0.3f + 0.1f * static_cast(p); + trx.streamlines->_data(cursor, 2) = 0.1f + 0.05f * static_cast(p); + } else { + trx.streamlines->_data(cursor, 0) = 0.0f; + trx.streamlines->_data(cursor, 1) = 0.0f; + trx.streamlines->_data(cursor, 2) = -1000.0f - static_cast(i); + } + } + } + + const std::array min_corner{ -0.9f, 0.2f, 0.05f }; + const std::array max_corner{ -0.1f, 1.1f, 0.55f }; + + auto subset = trx.query_aabb(min_corner, max_corner); + EXPECT_EQ(subset->num_streamlines(), static_cast(kInsideCount)); + EXPECT_EQ(subset->num_vertices(), static_cast(kInsideCount * kPointsPerStreamline)); + subset->close(); +} + // resize() with default arguments is a no-op when sizes already match. TEST(TrxFileTpp, ResizeNoChange) { const fs::path data_dir = create_float_trx_dir(); From 9c094dd96638f92166f5d52080c7465c4319b912 Mon Sep 17 00:00:00 2001 From: mattcieslak Date: Tue, 10 Feb 2026 11:01:21 -0500 Subject: [PATCH 2/3] update rst --- bench/bench_trx_stream.cpp | 29 ++++++++++--- docs/benchmarks.rst | 87 +++++++++++++++++++++++++++++++++++--- 2 files changed, 102 insertions(+), 14 deletions(-) diff --git a/bench/bench_trx_stream.cpp b/bench/bench_trx_stream.cpp index e664481..c507121 100644 --- a/bench/bench_trx_stream.cpp +++ b/bench/bench_trx_stream.cpp @@ -363,14 +363,23 @@ size_t group_count_for(GroupScenario scenario) { } } +// Compute position buffer size based on streamline count. +// For slow storage (spinning disks, network filesystems), set TRX_BENCH_BUFFER_MULTIPLIER +// to 2-8 to reduce I/O frequency at the cost of higher memory usage. +// Example: multiplier=4 scales 256 MB → 1 GB for 1M streamlines. std::size_t buffer_bytes_for_streamlines(std::size_t streamlines) { + std::size_t base_bytes; if (streamlines >= 5000000) { - return 2ULL * 1024ULL * 1024ULL * 1024ULL; - } - if (streamlines >= 1000000) { - return 256ULL * 1024ULL * 1024ULL; + base_bytes = 2ULL * 1024ULL * 1024ULL * 1024ULL; // 2 GB + } else if (streamlines >= 1000000) { + base_bytes = 256ULL * 1024ULL * 1024ULL; // 256 MB + } else { + base_bytes = 16ULL * 1024ULL * 1024ULL; // 16 MB } - return 16ULL * 1024ULL * 1024ULL; + + // Allow scaling buffer sizes for slower storage (HDD, NFS) to amortize I/O latency + const size_t multiplier = std::max(1, parse_env_size("TRX_BENCH_BUFFER_MULTIPLIER", 1)); + return base_bytes * multiplier; } std::vector streamlines_for_benchmarks() { @@ -693,7 +702,10 @@ TrxWriteStats run_trx_file_size(size_t streamlines, zip_uint32_t compression) { trx::TrxStream stream("float16"); stream.set_metadata_mode(trx::TrxStream::MetadataMode::OnDisk); - stream.set_metadata_buffer_max_bytes(64ULL * 1024ULL * 1024ULL); + + // Scale metadata buffer with TRX_BENCH_BUFFER_MULTIPLIER for slow storage + const size_t buffer_multiplier = std::max(1, parse_env_size("TRX_BENCH_BUFFER_MULTIPLIER", 1)); + stream.set_metadata_buffer_max_bytes(64ULL * 1024ULL * 1024ULL * buffer_multiplier); stream.set_positions_buffer_max_bytes(buffer_bytes_for_streamlines(streamlines)); const size_t threads = bench_threads(); @@ -830,7 +842,10 @@ TrxOnDisk build_trx_file_on_disk_single(size_t streamlines, bool finalize_to_directory = false) { trx::TrxStream stream("float16"); stream.set_metadata_mode(trx::TrxStream::MetadataMode::OnDisk); - stream.set_metadata_buffer_max_bytes(64ULL * 1024ULL * 1024ULL); + + // Scale buffers with TRX_BENCH_BUFFER_MULTIPLIER for slow storage + const size_t buffer_multiplier = std::max(1, parse_env_size("TRX_BENCH_BUFFER_MULTIPLIER", 1)); + stream.set_metadata_buffer_max_bytes(64ULL * 1024ULL * 1024ULL * buffer_multiplier); stream.set_positions_buffer_max_bytes(buffer_bytes_for_streamlines(streamlines)); const size_t progress_every = parse_env_size("TRX_BENCH_LOG_PROGRESS_EVERY", 0); diff --git a/docs/benchmarks.rst b/docs/benchmarks.rst index 3c5f468..e10ddfa 100644 --- a/docs/benchmarks.rst +++ b/docs/benchmarks.rst @@ -66,6 +66,30 @@ visualized. Distribution of per-slab query latency. +Performance characteristics +--------------------------- + +Benchmark results vary significantly based on storage performance: + +**SSD (solid-state drives):** +- **CPU-bound**: Disk writes complete faster than streamline generation +- High CPU utilization (~100%) +- Results reflect pure computational throughput + +**HDD (spinning disks):** +- **I/O-bound**: Disk writes are the bottleneck +- Low CPU utilization (~5-10%) +- Results reflect realistic workstation performance with storage latency + +Both scenarios are valuable. SSD results show the library's maximum throughput, +while HDD results show real-world performance on cost-effective storage. On +Linux, monitor I/O wait time with ``iostat -x 1`` to identify the bottleneck. + +For spinning disks or network filesystems, you may want to increase buffer sizes +to amortize I/O latency. Set ``TRX_BENCH_BUFFER_MULTIPLIER`` to use larger +buffers (e.g., ``TRX_BENCH_BUFFER_MULTIPLIER=4`` uses 4× the default buffer +sizes). + Running the benchmarks ---------------------- @@ -80,6 +104,12 @@ Build and run the benchmarks, then plot results with matplotlib: ./build/bench/bench_trx_stream \ --benchmark_out=bench/results.json \ --benchmark_out_format=json + + # For slower storage (HDD, NFS), use larger buffers: + TRX_BENCH_BUFFER_MULTIPLIER=4 \ + ./build/bench/bench_trx_stream \ + --benchmark_out=bench/results_hdd.json \ + --benchmark_out_format=json # Capture per-slab timings for query distributions. TRX_QUERY_TIMINGS_PATH=bench/query_timings.jsonl \ @@ -106,13 +136,56 @@ The query plot defaults to the "no groups, no DPV/DPS" case. Use ``--group-case``, ``--dpv``, and ``--dps`` in ``plot_bench.py`` to select other scenarios. -If zip compression is too slow or unstable for the largest datasets, set -``TRX_BENCH_SKIP_ZIP_AT`` (default 5000000) to skip compression for large -streamline counts. +Environment variables +--------------------- + +The benchmark suite supports several environment variables for customization: + +**Multiprocessing:** + +- ``TRX_BENCH_PROCESSES`` (default: 1): Number of processes for parallel shard + generation. Recommended: number of physical cores. +- ``TRX_BENCH_MP_MIN_STREAMLINES`` (default: 1000000): Minimum streamline count + to enable multiprocessing. Below this threshold, single-process mode is used. +- ``TRX_BENCH_KEEP_SHARDS`` (default: 0): Set to 1 to preserve shard directories + after merging for debugging. +- ``TRX_BENCH_SHARD_WAIT_MS`` (default: 10000): Timeout in milliseconds for + waiting for shard completion markers. + +**Buffering (for slow storage):** + +- ``TRX_BENCH_BUFFER_MULTIPLIER`` (default: 1): Scales position and metadata + buffer sizes. Use larger values (2-8) for spinning disks or network + filesystems to reduce I/O latency. Example: multiplier=4 uses 64 MB → 256 MB + for small datasets, 256 MB → 1 GB for 1M streamlines, 2 GB → 8 GB for 5M+ + streamlines. + +**Performance tuning:** + +- ``TRX_BENCH_THREADS`` (default: hardware_concurrency): Worker threads for + streamline generation within each process. +- ``TRX_BENCH_BATCH`` (default: 1000): Streamlines per batch in the producer- + consumer queue. +- ``TRX_BENCH_QUEUE_MAX`` (default: 8): Maximum batches in flight between + producers and consumer. + +**Dataset control:** + +- ``TRX_BENCH_ONLY_STREAMLINES`` (default: 0): If nonzero, benchmark only this + streamline count instead of the full range. +- ``TRX_BENCH_MAX_STREAMLINES`` (default: 10000000): Maximum streamline count + to benchmark. Use smaller values for faster iteration. +- ``TRX_BENCH_SKIP_ZIP_AT`` (default: 5000000): Skip zip compression for + streamline counts at or above this threshold. + +**Logging and diagnostics:** + +- ``TRX_BENCH_LOG`` (default: 0): Enable benchmark progress logging to stderr. +- ``TRX_BENCH_CHILD_LOG`` (default: 0): Enable logging from child processes in + multiprocess mode. +- ``TRX_BENCH_LOG_PROGRESS_EVERY`` (default: 0): Log progress every N + streamlines. When running with multiprocessing, the benchmark uses ``finalize_directory_persistent()`` to write shard outputs without removing -pre-created directories, avoiding race conditions in the parallel workflow. You -can keep shard outputs for debugging by setting ``TRX_BENCH_KEEP_SHARDS=1``. The -merge step waits for each shard to finish (via ``SHARD_OK`` files); adjust the -timeout with ``TRX_BENCH_SHARD_WAIT_MS`` if needed. +pre-created directories, avoiding race conditions in the parallel workflow. From 5ce0fb18934a5791c8230038b3f208a4343f47ae Mon Sep 17 00:00:00 2001 From: mattcieslak Date: Tue, 10 Feb 2026 20:50:33 -0500 Subject: [PATCH 3/3] add --verbose --- CMakeLists.txt | 6 +- bench/bench_trx_stream.cpp | 62 +++++- bench/plot_bench.R | 412 +++++++++++++++++++++++++++++++++++++ bench/plot_bench.py | 1 - 4 files changed, 471 insertions(+), 10 deletions(-) create mode 100755 bench/plot_bench.R diff --git a/CMakeLists.txt b/CMakeLists.txt index bf07eac..cf0f267 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,7 +69,11 @@ elseif(TARGET zip::zip) else() message(FATAL_ERROR "No suitable libzip target (expected libzip::libzip or zip::zip)") endif() -find_package(Eigen3 CONFIG QUIET) +# Prefer Eigen3_ROOT so -DEigen3_ROOT=/path/to/eigen-3.4 is used over system Eigen +if(Eigen3_ROOT) + list(PREPEND CMAKE_PREFIX_PATH "${Eigen3_ROOT}") +endif() +find_package(Eigen3 3.4 CONFIG QUIET) if (NOT Eigen3_FOUND) find_package(Eigen3 REQUIRED) # try module mode endif() diff --git a/bench/bench_trx_stream.cpp b/bench/bench_trx_stream.cpp index c507121..8029aed 100644 --- a/bench/bench_trx_stream.cpp +++ b/bench/bench_trx_stream.cpp @@ -1828,9 +1828,8 @@ static void ApplySizeArgs(benchmark::internal::Benchmark *bench) { } static void ApplyStreamArgs(benchmark::internal::Benchmark *bench) { - const std::array groups = {static_cast(GroupScenario::None), - static_cast(GroupScenario::Bundles), - static_cast(GroupScenario::Connectome)}; + const std::array groups = {static_cast(GroupScenario::None), + static_cast(GroupScenario::Bundles)}; const std::array flags = {0, 1}; const auto counts_desc = streamlines_for_benchmarks(); for (const auto count : counts_desc) { @@ -1845,9 +1844,8 @@ static void ApplyStreamArgs(benchmark::internal::Benchmark *bench) { } static void ApplyQueryArgs(benchmark::internal::Benchmark *bench) { - const std::array groups = {static_cast(GroupScenario::None), - static_cast(GroupScenario::Bundles), - static_cast(GroupScenario::Connectome)}; + const std::array groups = {static_cast(GroupScenario::None), + static_cast(GroupScenario::Bundles)}; const std::array flags = {0, 1}; const auto counts_desc = streamlines_for_benchmarks(); for (const auto count : counts_desc) { @@ -1877,8 +1875,56 @@ BENCHMARK(BM_TrxQueryAabb_Slabs) ->Unit(benchmark::kMillisecond); int main(int argc, char **argv) { - ::benchmark::Initialize(&argc, argv); - if (::benchmark::ReportUnrecognizedArguments(argc, argv)) { + // Parse custom flags before benchmark::Initialize + bool verbose = false; + bool show_help = false; + + // First pass: detect custom flags + for (int i = 1; i < argc; ++i) { + const std::string arg = argv[i]; + if (arg == "--verbose" || arg == "-v") { + verbose = true; + } else if (arg == "--help-custom") { + show_help = true; + } + } + + if (show_help) { + std::cout << "\nCustom benchmark options:\n" + << " --verbose, -v Enable verbose progress logging (prints every 50k streamlines)\n" + << " Equivalent to: TRX_BENCH_LOG=1 TRX_BENCH_CHILD_LOG=1 \n" + << " TRX_BENCH_LOG_PROGRESS_EVERY=50000\n" + << " --help-custom Show this help message\n" + << "\nFor standard benchmark options, use --help\n" + << std::endl; + return 0; + } + + // Enable verbose logging if requested + if (verbose) { + setenv("TRX_BENCH_LOG", "1", 0); // Don't override if already set + setenv("TRX_BENCH_CHILD_LOG", "1", 0); + if (std::getenv("TRX_BENCH_LOG_PROGRESS_EVERY") == nullptr) { + setenv("TRX_BENCH_LOG_PROGRESS_EVERY", "50000", 1); + } + std::cerr << "[trx-bench] Verbose mode enabled (progress every " + << parse_env_size("TRX_BENCH_LOG_PROGRESS_EVERY", 50000) + << " streamlines)\n" << std::endl; + } + + // Second pass: remove custom flags from argv before passing to benchmark::Initialize + std::vector filtered_argv; + filtered_argv.push_back(argv[0]); // Keep program name + for (int i = 1; i < argc; ++i) { + const std::string arg = argv[i]; + if (arg != "--verbose" && arg != "-v" && arg != "--help-custom") { + filtered_argv.push_back(argv[i]); + } + } + int filtered_argc = static_cast(filtered_argv.size()); + + ::benchmark::Initialize(&filtered_argc, filtered_argv.data()); + if (::benchmark::ReportUnrecognizedArguments(filtered_argc, filtered_argv.data())) { return 1; } try { diff --git a/bench/plot_bench.R b/bench/plot_bench.R new file mode 100755 index 0000000..d0218db --- /dev/null +++ b/bench/plot_bench.R @@ -0,0 +1,412 @@ +#!/usr/bin/env Rscript +# +# plot_bench.R - Plot trx-cpp benchmark results with ggplot2 +# +# Usage: +# Rscript bench/plot_bench.R [--bench-dir DIR] [--out-dir DIR] [--help] +# +# This script automatically detects benchmark result files in the bench/ +# directory and generates plots for: +# - File sizes (BM_TrxFileSize_Float16) +# - Translate/write throughput (BM_TrxStream_TranslateWrite) +# - Query performance (BM_TrxQueryAabb_Slabs) +# +# Expected input files (searched in bench-dir): +# - results*.json: Main benchmark results (Google Benchmark JSON format) +# - query_timings.jsonl: Per-query timing distributions (JSONL format) +# - rss_samples.jsonl: Memory samples over time (JSONL format, optional) +# + +suppressPackageStartupMessages({ + library(jsonlite) + library(ggplot2) + library(dplyr) + library(tidyr) + library(scales) +}) + +# Constants +LENGTH_LABELS <- c( + "0" = "mixed", + "1" = "short (20-120mm)", + "2" = "medium (80-260mm)", + "3" = "long (200-500mm)" +) + +GROUP_LABELS <- c( + "0" = "no groups", + "1" = "bundle groups (80)" +) + +COMPRESSION_LABELS <- c( + "0" = "store (no zip)", + "1" = "zip deflate" +) + +#' Parse command line arguments +parse_args <- function() { + args <- commandArgs(trailingOnly = TRUE) + + bench_dir <- "bench" + out_dir <- "docs/_static/benchmarks" + + i <- 1 + while (i <= length(args)) { + if (args[i] == "--bench-dir") { + bench_dir <- args[i + 1] + i <- i + 2 + } else if (args[i] == "--out-dir") { + out_dir <- args[i + 1] + i <- i + 2 + } else if (args[i] == "--help" || args[i] == "-h") { + cat("Usage: Rscript plot_bench.R [--bench-dir DIR] [--out-dir DIR]\n") + cat("\n") + cat("Options:\n") + cat(" --bench-dir DIR Directory containing benchmark JSON files (default: bench)\n") + cat(" --out-dir DIR Output directory for plots (default: docs/_static/benchmarks)\n") + cat(" --help, -h Show this help message\n") + quit(status = 0) + } else { + i <- i + 1 + } + } + + list(bench_dir = bench_dir, out_dir = out_dir) +} + +#' Convert benchmark time to milliseconds +time_to_ms <- function(bench) { + value <- bench$real_time + unit <- bench$time_unit + + multiplier <- switch(unit, + "ns" = 1e-6, + "us" = 1e-3, + "ms" = 1, + "s" = 1e3, + 1e-6 # default to nanoseconds + ) + + value * multiplier +} + +#' Extract base benchmark name +parse_base_name <- function(name) { + sub("/.*", "", name) +} + +#' Load all benchmark result JSON files from a directory +load_benchmarks <- function(bench_dir) { + json_files <- list.files(bench_dir, pattern = "^results.*\\.json$", full.names = TRUE) + + if (length(json_files) == 0) { + stop("No results*.json files found in ", bench_dir) + } + + cat("Found", length(json_files), "benchmark result file(s):\n") + for (f in json_files) { + cat(" -", basename(f), "\n") + } + + all_rows <- list() + + for (json_file in json_files) { + data <- tryCatch({ + fromJSON(json_file, simplifyDataFrame = FALSE) + }, error = function(e) { + warning("Failed to parse ", json_file, ": ", e$message) + return(NULL) + }) + + if (is.null(data)) { + next + } + + benchmarks <- data$benchmarks + + if (is.null(benchmarks) || length(benchmarks) == 0) { + warning("No benchmarks found in ", json_file) + next + } + + for (bench in benchmarks) { + name <- bench$name %||% "" + if (!grepl("^BM_", name)) next + + row <- list( + name = name, + base = parse_base_name(name), + real_time_ms = time_to_ms(bench), + streamlines = bench$streamlines %||% NA, + length_profile = bench$length_profile %||% NA, + compression = bench$compression %||% NA, + group_case = bench$group_case %||% NA, + group_count = bench$group_count %||% NA, + dps = bench$dps %||% NA, + dpv = bench$dpv %||% NA, + write_ms = bench$write_ms %||% NA, + build_ms = bench$build_ms %||% NA, + file_bytes = bench$file_bytes %||% NA, + max_rss_kb = bench$max_rss_kb %||% NA, + query_p50_ms = bench$query_p50_ms %||% NA, + query_p95_ms = bench$query_p95_ms %||% NA, + shard_merge_ms = bench$shard_merge_ms %||% NA, + shard_processes = bench$shard_processes %||% NA, + source_file = basename(json_file) + ) + + all_rows[[length(all_rows) + 1]] <- row + } + } + + if (length(all_rows) == 0) { + stop("No valid benchmarks found in any JSON file") + } + + df <- bind_rows(all_rows) + + cat("\nLoaded", nrow(df), "benchmark results\n") + cat("Benchmark types found:\n") + for (base in unique(df$base)) { + count <- sum(df$base == base) + cat(" -", base, ":", count, "results\n") + } + + df +} + +#' Plot file sizes +plot_file_sizes <- function(df, out_dir) { + sub_df <- df %>% + filter(base == "BM_TrxFileSize_Float16") %>% + filter(!is.na(file_bytes), !is.na(streamlines)) + + if (nrow(sub_df) == 0) { + cat("No BM_TrxFileSize_Float16 results found, skipping file size plot\n") + return(invisible(NULL)) + } + + sub_df <- sub_df %>% + mutate( + file_mb = file_bytes / 1e6, + length_label = recode(as.character(length_profile), !!!LENGTH_LABELS), + compression_label = recode(as.character(compression), !!!COMPRESSION_LABELS), + dp_label = sprintf("dpv=%d, dps=%d", as.integer(dpv), as.integer(dps)) + ) + + p <- ggplot(sub_df, aes(x = streamlines, y = file_mb, + color = length_label, linetype = dp_label)) + + geom_line(linewidth = 0.8) + + geom_point(size = 2) + + facet_wrap(~compression_label, ncol = 2) + + scale_x_continuous(labels = label_number(scale = 1e-6, suffix = "M")) + + scale_y_continuous(labels = label_number()) + + labs( + title = "TRX file size vs streamlines (float16 positions)", + x = "Streamlines", + y = "File size (MB)", + color = "Length profile", + linetype = "Data per point" + ) + + theme_bw() + + theme( + legend.position = "bottom", + legend.box = "vertical", + strip.background = element_rect(fill = "grey90") + ) + + out_path <- file.path(out_dir, "trx_size_vs_streamlines.png") + ggsave(out_path, p, width = 12, height = 7, dpi = 160) + cat("Saved:", out_path, "\n") +} + +#' Plot translate/write performance +plot_translate_write <- function(df, out_dir) { + sub_df <- df %>% + filter(base == "BM_TrxStream_TranslateWrite") %>% + filter(!is.na(real_time_ms), !is.na(streamlines)) + + if (nrow(sub_df) == 0) { + cat("No BM_TrxStream_TranslateWrite results found, skipping translate plots\n") + return(invisible(NULL)) + } + + sub_df <- sub_df %>% + mutate( + group_label = recode(as.character(group_case), !!!GROUP_LABELS), + dp_label = sprintf("dpv=%d, dps=%d", as.integer(dpv), as.integer(dps)), + rss_mb = max_rss_kb / 1024 + ) + + # Time plot + p_time <- ggplot(sub_df, aes(x = streamlines, y = real_time_ms, + color = dp_label)) + + geom_line(linewidth = 0.8) + + geom_point(size = 2) + + facet_wrap(~group_label, ncol = 2) + + scale_x_continuous(labels = label_number(scale = 1e-6, suffix = "M")) + + scale_y_continuous(labels = label_number()) + + labs( + title = "Translate + stream write throughput", + x = "Streamlines", + y = "Time (ms)", + color = "Data per point" + ) + + theme_bw() + + theme( + legend.position = "bottom", + strip.background = element_rect(fill = "grey90") + ) + + out_path <- file.path(out_dir, "trx_translate_write_time.png") + ggsave(out_path, p_time, width = 12, height = 5, dpi = 160) + cat("Saved:", out_path, "\n") + + # RSS plot + p_rss <- ggplot(sub_df, aes(x = streamlines, y = rss_mb, + color = dp_label)) + + geom_line(linewidth = 0.8) + + geom_point(size = 2) + + facet_wrap(~group_label, ncol = 2) + + scale_x_continuous(labels = label_number(scale = 1e-6, suffix = "M")) + + scale_y_continuous(labels = label_number()) + + labs( + title = "Translate + stream write memory usage", + x = "Streamlines", + y = "Max RSS (MB)", + color = "Data per point" + ) + + theme_bw() + + theme( + legend.position = "bottom", + strip.background = element_rect(fill = "grey90") + ) + + out_path <- file.path(out_dir, "trx_translate_write_rss.png") + ggsave(out_path, p_rss, width = 12, height = 5, dpi = 160) + cat("Saved:", out_path, "\n") +} + +#' Load query timings from JSONL file +load_query_timings <- function(jsonl_path) { + if (!file.exists(jsonl_path)) { + return(NULL) + } + + lines <- readLines(jsonl_path, warn = FALSE) + lines <- lines[nzchar(lines)] + + if (length(lines) == 0) { + return(NULL) + } + + rows <- lapply(lines, function(line) { + tryCatch({ + obj <- fromJSON(line, simplifyDataFrame = FALSE) + list( + streamlines = obj$streamlines %||% NA, + group_case = obj$group_case %||% NA, + group_count = obj$group_count %||% NA, + dps = obj$dps %||% NA, + dpv = obj$dpv %||% NA, + slab_thickness_mm = obj$slab_thickness_mm %||% NA, + timings_ms = I(list(unlist(obj$timings_ms))) + ) + }, error = function(e) NULL) + }) + + rows <- rows[!sapply(rows, is.null)] + + if (length(rows) == 0) { + return(NULL) + } + + bind_rows(rows) +} + +#' Plot query timing distributions +plot_query_timings <- function(bench_dir, out_dir, group_case = 0, dpv = 0, dps = 0) { + jsonl_path <- file.path(bench_dir, "query_timings.jsonl") + + df <- load_query_timings(jsonl_path) + + if (is.null(df) || nrow(df) == 0) { + cat("No query_timings.jsonl found or empty, skipping query timing plot\n") + return(invisible(NULL)) + } + + # Filter by specified conditions + df_filtered <- df %>% + filter( + group_case == !!group_case, + dpv == !!dpv, + dps == !!dps + ) + + if (nrow(df_filtered) == 0) { + cat("No query timings matching filters (group_case=", group_case, + ", dpv=", dpv, ", dps=", dps, "), skipping plot\n", sep = "") + return(invisible(NULL)) + } + + # Expand timings into long format + timing_data <- df_filtered %>% + mutate(streamlines_label = format(streamlines, big.mark = ",")) %>% + select(streamlines, streamlines_label, timings_ms) %>% + unnest(timings_ms) %>% + group_by(streamlines, streamlines_label) %>% + mutate(query_id = row_number()) %>% + ungroup() + + # Create boxplot + group_label <- GROUP_LABELS[as.character(group_case)] + + p <- ggplot(timing_data, aes(x = streamlines_label, y = timings_ms)) + + geom_boxplot(fill = "steelblue", alpha = 0.7, outlier.size = 0.5) + + labs( + title = sprintf("Slab query timings (%s, dpv=%d, dps=%d)", + group_label, dpv, dps), + x = "Streamlines", + y = "Per-slab query time (ms)" + ) + + theme_bw() + + theme( + axis.text.x = element_text(angle = 45, hjust = 1) + ) + + out_path <- file.path(out_dir, "trx_query_slab_timings.png") + ggsave(out_path, p, width = 10, height = 6, dpi = 160) + cat("Saved:", out_path, "\n") +} + +#' Main function +main <- function() { + args <- parse_args() + + # Create output directory + dir.create(args$out_dir, recursive = TRUE, showWarnings = FALSE) + + cat("\n=== TRX-CPP Benchmark Plotting ===\n\n") + cat("Benchmark directory:", args$bench_dir, "\n") + cat("Output directory:", args$out_dir, "\n\n") + + # Load benchmark results + df <- load_benchmarks(args$bench_dir) + + cat("\n--- Generating plots ---\n\n") + + # Generate plots + plot_file_sizes(df, args$out_dir) + plot_translate_write(df, args$out_dir) + plot_query_timings(args$bench_dir, args$out_dir, group_case = 0, dpv = 0, dps = 0) + + cat("\nDone! Plots saved to:", args$out_dir, "\n") +} + +# Define null-coalescing operator +`%||%` <- function(x, y) if (is.null(x)) y else x + +# Run main if executed as script +if (!interactive()) { + main() +} diff --git a/bench/plot_bench.py b/bench/plot_bench.py index 6b0aa3d..b346035 100644 --- a/bench/plot_bench.py +++ b/bench/plot_bench.py @@ -14,7 +14,6 @@ GROUP_LABELS = { 0: "no groups", 1: "bundle groups (80)", - 2: "connectome groups (4950)", } COMPRESSION_LABELS = {0: "store (no zip)", 1: "zip deflate"}