diff --git a/csrc/binding.cpp b/csrc/binding.cpp index d093127..493d5c1 100644 --- a/csrc/binding.cpp +++ b/csrc/binding.cpp @@ -19,11 +19,10 @@ namespace nb = nanobind; void do_bench(int result_fd, int input_fd, const std::string& kernel_qualname, const nb::object& test_generator, const nb::dict& test_kwargs, std::uintptr_t stream, bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_sock_fd) { - ObfuscatedHexDigest signature; - std::mt19937 rng(std::random_device{}()); - signature.allocate(32, rng); - auto config = read_benchmark_parameters(input_fd, signature.data()); - auto mgr = make_benchmark_manager(result_fd, std::move(signature), config.Seed, discard, nvtx, landlock, mseal, supervisor_sock_fd); + std::vector signature_bytes(32); + auto config = read_benchmark_parameters(input_fd, signature_bytes.data()); + auto mgr = make_benchmark_manager(result_fd, signature_bytes, config.Seed, discard, nvtx, landlock, mseal, supervisor_sock_fd); + cleanse(signature_bytes.data(), 32); { nb::gil_scoped_release release; diff --git a/csrc/landlock.cpp b/csrc/landlock.cpp index acb21ff..916c82d 100644 --- a/csrc/landlock.cpp +++ b/csrc/landlock.cpp @@ -211,16 +211,6 @@ void setup_seccomp_filter(scmp_filter_ctx ctx) { check_seccomp(seccomp_rule_add(ctx, SCMP_ACT_ERRNO(EPERM), SCMP_SYS(prctl), 1, SCMP_A0(SCMP_CMP_EQ, PR_SET_PTRACER)), "block prctl(SET_PTRACER)"); - // TODO figure out what else we can and should block - /* - check_seccomp(seccomp_rule_add(ctx, SCMP_ACT_ERRNO(EPERM), SCMP_SYS(mprotect), 1, - SCMP_A2(SCMP_CMP_MASKED_EQ, PROT_WRITE, PROT_WRITE)), - "block mprotect+WRITE"); - - check_seccomp(seccomp_rule_add(ctx, SCMP_ACT_ERRNO(EPERM), SCMP_SYS(pkey_mprotect), 1, - SCMP_A2(SCMP_CMP_MASKED_EQ, PROT_WRITE, PROT_WRITE)), - "block pkey_mprotect+WRITE"); - */ } void install_seccomp_filter() { diff --git a/csrc/manager.cpp b/csrc/manager.cpp index 962d3e3..f0e70d6 100644 --- a/csrc/manager.cpp +++ b/csrc/manager.cpp @@ -19,6 +19,7 @@ #include #include #include +#include "protect.h" static constexpr std::size_t ArenaSize = 2 * 1024 * 1024; @@ -137,7 +138,7 @@ void BenchmarkManagerDeleter::operator()(BenchmarkManager* p) const noexcept { BenchmarkManagerPtr make_benchmark_manager( - int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed, + int result_fd, const std::vector& signature, std::uint64_t seed, bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket) { const std::size_t page_size = static_cast(getpagesize()); @@ -153,7 +154,7 @@ BenchmarkManagerPtr make_benchmark_manager( try { raw = new (mem) BenchmarkManager( static_cast(mem), alloc_size, - result_fd, std::move(signature), seed, + result_fd, signature, seed, discard, nvtx, landlock, mseal, supervisor_socket); } catch (...) { // If construction throws, release the mmap'd region before propagating. @@ -168,14 +169,14 @@ BenchmarkManagerPtr make_benchmark_manager( BenchmarkManager::BenchmarkManager(std::byte* arena, std::size_t arena_size, - int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed, bool discard, + int result_fd, const std::vector& signature, std::uint64_t seed, bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket) : mArena(arena), mResource(arena + sizeof(BenchmarkManager), arena_size - sizeof(BenchmarkManager), std::pmr::null_memory_resource()), - mSignature(std::move(signature)), + mSignature(&mResource), mSupervisorSock(supervisor_socket), mStartEvents(&mResource), mEndEvents(&mResource), @@ -195,11 +196,19 @@ BenchmarkManager::BenchmarkManager(std::byte* arena, std::size_t arena_size, throw std::runtime_error("Could not open output pipe"); } + if (signature.size() != 32) { + throw std::invalid_argument("Invalid signature length"); + } + mNVTXEnabled = nvtx; mLandlock = landlock; mSeal = mseal; mDiscardCache = discard; mSeed = seed; + std::random_device rd; + std::mt19937 rng(rd()); + mSignature.allocate(32, rng); + std::copy(signature.begin(), signature.end(), mSignature.data()); } @@ -337,18 +346,6 @@ void BenchmarkManager::install_protections() { install_seccomp_filter(); } -static inline std::uintptr_t page_mask() { - std::uintptr_t page_size = getpagesize(); - return ~(page_size - 1u); -} - -void protect_range(void* ptr, size_t size, int prot) { - std::uintptr_t start = reinterpret_cast(ptr) & page_mask(); - std::uintptr_t end = (reinterpret_cast(ptr) + size + getpagesize() - 1) & page_mask(); - if (mprotect(reinterpret_cast(start), end - start, prot) < 0) - throw std::system_error(errno, std::system_category(), "mprotect"); -} - static void setup_seccomp(int sock, bool install_notify, std::uintptr_t lo, std::uintptr_t hi) { if (sock < 0) return; @@ -394,48 +391,46 @@ nb::callable BenchmarkManager::initial_kernel_setup(double& time_estimate, const void* const cc_memory = mDeviceDummyMemory; const std::size_t l2_clear_size = mL2CacheSize; const bool discard_cache = mDiscardCache; - int device; - CUDA_CHECK(cudaGetDevice(&device)); - - nb::callable kernel; - std::exception_ptr thread_exception; nvtx_push("trigger-compile"); - protect_range(reinterpret_cast(lo), hi - lo, PROT_NONE); - - { - nb::gil_scoped_release release; - std::thread worker([&] { - try { - CUDA_CHECK(cudaSetDevice(device)); - setup_seccomp(sock, install_notify, lo, hi); - - nb::gil_scoped_acquire guard; - - kernel = kernel_from_qualname(qualname); - CUDA_CHECK(cudaDeviceSynchronize()); - kernel(*call_args); // trigger JIT compile - - time_estimate = run_warmup_loop(kernel, call_args, stream, - cc_memory, l2_clear_size, discard_cache, - warmup_seconds); - } catch (...) { - thread_exception = std::current_exception(); - } - }); - worker.join(); - } + PROTECT_RANGE(lo, hi-lo, PROT_NONE); + setup_seccomp(sock, install_notify, lo, hi); - protect_range(reinterpret_cast(lo), hi - lo, PROT_READ | PROT_WRITE); + nb::callable kernel = kernel_from_qualname(qualname); + CUDA_CHECK(cudaDeviceSynchronize()); + kernel(*call_args); // trigger JIT compile + + time_estimate = run_warmup_loop(kernel, call_args, stream, + cc_memory, l2_clear_size, discard_cache, + warmup_seconds); + + PROTECT_RANGE(lo, hi - lo, PROT_READ | PROT_WRITE); mSupervisorSock = -1; nvtx_pop(); - if (thread_exception) - std::rethrow_exception(thread_exception); - return kernel; } +void BenchmarkManager::randomize_before_test(int num_calls, std::mt19937& rng, cudaStream_t stream) { + // pick a random spot for the unsigned + // initialize the whole area with random junk; the error counter + // will be shifted by the initial value, so just writing zero + // won't result in passing the tests. + std::uniform_int_distribution dist(0, ArenaSize / sizeof(unsigned) - 1); + std::uniform_int_distribution noise_generator(0, std::numeric_limits::max()); + std::vector noise(ArenaSize / sizeof(unsigned)); + std::generate(noise.begin(), noise.end(), [&]() -> unsigned { return noise_generator(rng); }); + CUDA_CHECK(cudaMemcpyAsync(mDeviceErrorBase, noise.data(), noise.size() * sizeof(unsigned), cudaMemcpyHostToDevice, stream)); + std::ptrdiff_t offset = dist(rng); + mDeviceErrorCounter = mDeviceErrorBase + offset; + mErrorCountShift = noise.at(offset); + + // create a randomized order for running the tests + mTestOrder.resize(num_calls); + std::iota(mTestOrder.begin(), mTestOrder.end(), 1); + std::shuffle(mTestOrder.begin(), mTestOrder.end(), rng); +} + void BenchmarkManager::do_bench_py( const std::string& kernel_qualname, const std::vector& args, @@ -472,25 +467,13 @@ void BenchmarkManager::do_bench_py( "meaningful benchmark numbers: " + std::to_string(time_estimate)); } - // pick a random spot for the unsigned - // initialize the whole area with random junk; the error counter - // will be shifted by the initial value, so just writing zero - // won't result in passing the tests. std::random_device rd; std::mt19937 rng(rd()); - std::uniform_int_distribution dist(0, ArenaSize / sizeof(unsigned) - 1); - std::uniform_int_distribution noise_generator(0, std::numeric_limits::max()); - std::vector noise(ArenaSize / sizeof(unsigned)); - std::generate(noise.begin(), noise.end(), [&]() -> unsigned { return noise_generator(rng); }); - CUDA_CHECK(cudaMemcpyAsync(mDeviceErrorBase, noise.data(), noise.size() * sizeof(unsigned), cudaMemcpyHostToDevice, stream)); - std::ptrdiff_t offset = dist(rng); - mDeviceErrorCounter = mDeviceErrorBase + offset; - mErrorCountShift = noise.at(offset); - // create a randomized order for running the tests - mTestOrder.resize(actual_calls); - std::iota(mTestOrder.begin(), mTestOrder.end(), 1); - std::shuffle(mTestOrder.begin(), mTestOrder.end(), rng); + randomize_before_test(actual_calls, rng, stream); + // from this point on, even the benchmark thread won't write to the arena anymore + PROTECT_RANGE(mArena, BenchmarkManagerArenaSize, PROT_READ); + PROTECT_RANGE(mSignature.page_ptr(), 4096, PROT_NONE); // make the key fully inaccessible std::uniform_int_distribution check_seed_generator(0, 0xffffffff); @@ -540,12 +523,18 @@ void BenchmarkManager::send_report() { error_count -= mErrorCountShift; std::string message = build_result_message(mTestOrder, error_count, mMedianEventTime); + PROTECT_RANGE(mSignature.page_ptr(), 4096, PROT_READ); message = encrypt_message(mSignature.data(), 32, message); + PROTECT_RANGE(mSignature.page_ptr(), 4096, PROT_WRITE); + cleanse(mSignature.data(), 32); + PROTECT_RANGE(mSignature.page_ptr(), 4096, PROT_NONE); fwrite(message.data(), 1, message.size(), mOutputPipe); fflush(mOutputPipe); } void BenchmarkManager::clean_up() { + PROTECT_RANGE(mArena, BenchmarkManagerArenaSize, PROT_READ | PROT_WRITE); + for (auto& event : mStartEvents) CUDA_CHECK(cudaEventDestroy(event)); for (auto& event : mEndEvents) CUDA_CHECK(cudaEventDestroy(event)); mStartEvents.clear(); diff --git a/csrc/manager.h b/csrc/manager.h index c592d93..28bb8d1 100644 --- a/csrc/manager.h +++ b/csrc/manager.h @@ -42,7 +42,7 @@ struct BenchmarkManagerDeleter { using BenchmarkManagerPtr = std::unique_ptr; BenchmarkManagerPtr make_benchmark_manager( - int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed, + int result_fd, const std::vector& signature, std::uint64_t seed, bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket); @@ -53,13 +53,13 @@ class BenchmarkManager { void send_report(); void clean_up(); private: - friend BenchmarkManagerPtr make_benchmark_manager(int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed, bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket); + friend BenchmarkManagerPtr make_benchmark_manager(int result_fd, const std::vector& signature, std::uint64_t seed, bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket); friend BenchmarkManagerDeleter; /// `arena` is the mmap region that owns all memory for this object and its vectors. /// The BenchmarkManager must have been placement-newed into the front of that region; /// the rest is used as a monotonic PMR arena for internal vectors. BenchmarkManager(std::byte* arena, std::size_t arena_size, - int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed, + int result_fd, const std::vector& signature, std::uint64_t seed, bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket); ~BenchmarkManager(); @@ -135,6 +135,7 @@ class BenchmarkManager { void setup_test_cases(const std::vector& args, const std::vector& expected, cudaStream_t stream); void install_protections(); + void randomize_before_test(int num_calls, std::mt19937& rng, cudaStream_t stream); nb::callable initial_kernel_setup(double& time_estimate, const std::string& qualname, const nb::tuple& call_args, cudaStream_t stream); [[nodiscard]] std::string build_result_message(const std::pmr::vector& test_order, unsigned error_count, float median_event_time) const; diff --git a/csrc/obfuscate.cpp b/csrc/obfuscate.cpp index 632c164..d529a83 100644 --- a/csrc/obfuscate.cpp +++ b/csrc/obfuscate.cpp @@ -18,48 +18,11 @@ #include #include -constexpr std::size_t PAGE_SIZE = 4096; +constexpr static std::size_t PAGE_SIZE = 4096; -ProtectablePage::ProtectablePage() { - void* page = mmap(nullptr, PAGE_SIZE, PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - if (page == MAP_FAILED) { - throw std::runtime_error("mmap failed"); - } - Page = slow_hash(page); -} - -ProtectablePage::~ProtectablePage() { - void* page = page_ptr(); - if (page) { - if (mprotect(page, PAGE_SIZE, PROT_READ | PROT_WRITE) != 0) { - std::perror("mprotect restore failed in ~ProtectablePage"); - } - if (munmap(page, PAGE_SIZE) != 0) { - std::perror("munmap failed in ~ProtectablePage"); - } - } -} - -ProtectablePage::ProtectablePage(ProtectablePage&& other) noexcept : Page(std::exchange(other.Page, slow_hash((void*)nullptr))){ -} - -void ProtectablePage::lock() { - void* page = page_ptr(); - if (mprotect(page, PAGE_SIZE, PROT_NONE) != 0) { - throw std::system_error(errno, std::generic_category(), "mprotect(PROT_NONE) failed"); - } -} - -void ProtectablePage::unlock() { - void* page = page_ptr(); - if (mprotect(page, PAGE_SIZE, PROT_READ) != 0) { - throw std::system_error(errno, std::generic_category(), "mprotect(PROT_READ) failed"); - } -} - -void* ProtectablePage::page_ptr() const { - return reinterpret_cast(slow_unhash(Page)); +ObfuscatedHexDigest::ObfuscatedHexDigest(std::pmr::monotonic_buffer_resource* mem) { + void* page = mem->allocate(PAGE_SIZE, PAGE_SIZE); + HashedPagePtr = slow_hash(reinterpret_cast(page)); } void ObfuscatedHexDigest::allocate(std::size_t size, std::mt19937& rng) { @@ -70,7 +33,7 @@ void ObfuscatedHexDigest::allocate(std::size_t size, std::mt19937& rng) { throw std::runtime_error("already allocated"); } - fill_random_hex(page_ptr(), PAGE_SIZE, rng); + fill_random_hex(reinterpret_cast(slow_unhash(HashedPagePtr)), PAGE_SIZE, rng); const std::uintptr_t max_offset = PAGE_SIZE - size - 1; std::uniform_int_distribution offset_dist(0, max_offset); @@ -79,8 +42,12 @@ void ObfuscatedHexDigest::allocate(std::size_t size, std::mt19937& rng) { HashedLen = slow_hash(size ^ offset); } +const void* ObfuscatedHexDigest::page_ptr() const { + return reinterpret_cast(slow_unhash(HashedPagePtr)); +} + char* ObfuscatedHexDigest::data() { - return reinterpret_cast(page_ptr()) + slow_unhash(HashedOffset); + return reinterpret_cast(slow_unhash(HashedPagePtr)) + slow_unhash(HashedOffset); } std::size_t ObfuscatedHexDigest::size() const { @@ -115,20 +82,15 @@ std::uintptr_t slow_unhash(std::uintptr_t p, int rounds) { return p; } -std::string encrypt_message(void* key, size_t keyLen, const std::string& plaintext) +void cleanse(void* ptr, size_t size) { + OPENSSL_cleanse(ptr, size); +} + +std::string encrypt_message(const char* key, size_t keyLen, const std::string& plaintext) { if (keyLen != 32) throw std::invalid_argument("encrypt_message: key must be exactly 32 bytes for AES-256"); - struct Cleanse - { - void* key; - size_t keyLen; - ~Cleanse() { - OPENSSL_cleanse(key, keyLen); - } - } cleanse_guard{key, keyLen}; - constexpr int NONCE_LEN = 12; constexpr int TAG_LEN = 16; @@ -142,9 +104,8 @@ std::string encrypt_message(void* key, size_t keyLen, const std::string& plainte struct CtxGuard { EVP_CIPHER_CTX* c; ~CtxGuard() { EVP_CIPHER_CTX_free(c); } } guard{ctx}; if (EVP_EncryptInit_ex(ctx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1 || - EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, NONCE_LEN, nullptr) != 1 || - EVP_EncryptInit_ex(ctx, nullptr, nullptr, - static_cast(key), nonce) != 1) + EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, NONCE_LEN, nullptr) != 1 || + EVP_EncryptInit_ex(ctx, nullptr, nullptr, reinterpret_cast(key), nonce) != 1) { throw std::runtime_error("encrypt_message: GCM init failed"); } @@ -173,4 +134,4 @@ std::string encrypt_message(void* key, size_t keyLen, const std::string& plainte packet.append(reinterpret_cast(ciphertext.data()), out_len + final_len); return packet; -} \ No newline at end of file +} diff --git a/csrc/obfuscate.h b/csrc/obfuscate.h index 1bc114d..bb7c7b8 100644 --- a/csrc/obfuscate.h +++ b/csrc/obfuscate.h @@ -5,39 +5,25 @@ #ifndef PYGPUBENCH_OBFUSCATE_H #define PYGPUBENCH_OBFUSCATE_H +#include #include #include +#include -// A single memory page that can be read-protected. -// This does not provide any actual defence against an attacker, -// because they could always just remove memory protection before -// access. But that in itself serves to increase the complexity of -// an attack. -class ProtectablePage { -public: - ProtectablePage(); - ~ProtectablePage(); - ProtectablePage(ProtectablePage&& other) noexcept; - - void lock(); - void unlock(); - - [[nodiscard]] void* page_ptr() const; - std::uintptr_t Page; -}; - -class ObfuscatedHexDigest : ProtectablePage { +class ObfuscatedHexDigest { public: - ObfuscatedHexDigest() = default; + ObfuscatedHexDigest(std::pmr::monotonic_buffer_resource* mem); void allocate(std::size_t size, std::mt19937& rng); + const void* page_ptr() const; char* data(); [[nodiscard]] std::size_t size() const; private: + std::uintptr_t HashedPagePtr = 0; std::uintptr_t HashedLen = 0; std::uintptr_t HashedOffset = 0; }; @@ -52,9 +38,10 @@ std::uintptr_t slow_hash(T* ptr, int rounds = 100'000) { return slow_hash(reinterpret_cast(ptr), rounds); } +void cleanse(void* ptr, size_t size); + // Encrypts `plaintext` with AES-256-GCM using `key` (must be exactly 32 bytes). // Returns a binary packet: [nonce (12)] [tag (16)] [ciphertext (N)]. -// key will be cleansed after use -std::string encrypt_message(void* key, size_t keyLen, const std::string& plaintext); +std::string encrypt_message(const char* key, size_t keyLen, const std::string& plaintext); #endif //PYGPUBENCH_OBFUSCATE_H \ No newline at end of file diff --git a/csrc/protect.h b/csrc/protect.h new file mode 100644 index 0000000..596bc20 --- /dev/null +++ b/csrc/protect.h @@ -0,0 +1,93 @@ +// Copyright (c) 2026 Erik Schultheis +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include +#include +#include +#include +#include +#include + +// Generates an inline mprotect syscall, rounding ptr..ptr+size out to page boundaries, +// and registers the exact address of the syscall instruction in the linker section +// __allowed_mprotect. This section is read at startup to tell the seccomp supervisor +// which instruction pointers are permitted to issue mprotect. +// +// The registration works as follows: the GNU assembler numeric label "1:" is placed +// directly on the syscall instruction in .text. A .quad 1f (a forward reference to +// that label) is emitted into __allowed_mprotect before switching back to .text. +// Because numeric labels are reusable, each expansion of this macro gets its own +// independent "1:" with no clashes, making it safe to use at multiple call sites. +// +// We need to insert to a writeable segment "aw", because the actual instruction +// addresses might only be resolved during load time. The whitelist of locations +// is sent to the supervisor process before any user code runs, so writability +// does not compromise security. +#define PROTECT_RANGE_MARKED(ptr, size, prot) \ + do { \ + const uintptr_t _page = static_cast(getpagesize()); \ + uintptr_t _start = reinterpret_cast(ptr) & ~(_page - 1); \ + uintptr_t _end = (reinterpret_cast(ptr) \ + + static_cast(size) + _page - 1) \ + & ~(_page - 1); \ + long _ret; \ + asm volatile ( \ + ".pushsection __allowed_mprotect, \"aw\"\n\t" \ + ".quad 1f\n\t" \ + ".popsection\n\t" \ + "1: syscall\n\t" \ + : "=a"(_ret) \ + : "0"(__NR_mprotect), \ + "D"(_start), \ + "S"(_end - _start), \ + "d"(static_cast(prot)) \ + : "rcx", "r11", "memory" \ + ); \ + if (_ret < 0) \ + throw std::system_error( \ + static_cast(-_ret), \ + std::system_category(), "mprotect"); \ + } while(0) + +// if prot==PROT_NONE, there's no seccomp filtering, so we don't need to +// register the callsite. +#define PROTECT_RANGE_NONE(ptr, size) \ + do { \ + const uintptr_t _page = static_cast(getpagesize()); \ + uintptr_t _start = reinterpret_cast(ptr) & ~(_page - 1); \ + uintptr_t _end = (reinterpret_cast(ptr) \ + + static_cast(size) + _page - 1) \ + & ~(_page - 1); \ + long _ret; \ + asm volatile ( \ + "syscall\n\t" \ + : "=a"(_ret) \ + : "0"(__NR_mprotect), \ + "D"(_start), \ + "S"(_end - _start), \ + "d"(static_cast(PROT_NONE)) \ + : "rcx", "r11", "memory" \ + ); \ + if (_ret < 0) \ + throw std::system_error( \ + static_cast(-_ret), \ + std::system_category(), "mprotect"); \ + } while(0) + + +#define PROTECT_RANGE(ptr, size, prot) \ + do { \ + constexpr int PROT = prot; \ + if constexpr (PROT == PROT_NONE) { \ + PROTECT_RANGE_NONE(ptr, size); \ + } else { \ + PROTECT_RANGE_MARKED(ptr, size, PROT); \ + } \ + } while(0) + +extern "C" { + extern uintptr_t __start___allowed_mprotect[]; // NOLINT(bugprone-reserved-identifier) + extern uintptr_t __stop___allowed_mprotect[]; // NOLINT(bugprone-reserved-identifier) +} diff --git a/csrc/protocol.h b/csrc/protocol.h new file mode 100644 index 0000000..739da88 --- /dev/null +++ b/csrc/protocol.h @@ -0,0 +1,25 @@ +// Copyright (c) 2026 Erik Schultheis +// SPDX-License-Identifier: Apache-2.0 +// + +// Wire protocol between the benchmark process and the seccomp supervisor. +// Both sides include this header; neither side should define these structs independently. + +#pragma once +#include + +constexpr int MAX_ALLOWED_SITES = 32; + +// Sent as regular data before the SCM_RIGHTS message. +// Followed immediately by n_allowed_sites * sizeof(uintptr_t) bytes, +// each being the exact address of an allowed mprotect syscall instruction. +struct SupervisorSetupMsg { + std::uintptr_t sensitive_lo; // protected arena start + std::uintptr_t sensitive_hi; // protected arena end (exclusive) + std::uint32_t n_allowed_sites; // number of entries that follow +}; + +// Each allowed site is a single pointer — the address of a `syscall` instruction +// registered via PROTECT_RANGE. The supervisor allows mprotect only when the +// instruction pointer is at site+2 (instruction past the syscall?) +using AllowedSite = std::uintptr_t; diff --git a/csrc/seccomp.cpp b/csrc/seccomp.cpp index 58c5a89..e3009c6 100644 --- a/csrc/seccomp.cpp +++ b/csrc/seccomp.cpp @@ -7,11 +7,27 @@ #include #include +#include "protect.h" +#include "protocol.h" + static inline void check_seccomp(int rc, const char* what) { if (rc < 0) throw std::system_error(-rc, std::generic_category(), what); } +static void send_all(int sock, const void* buf, size_t len) { + const auto* p = static_cast(buf); + while (len > 0) { + ssize_t n = send(sock, p, len, MSG_NOSIGNAL); + if (n < 0) { + if (errno == EINTR) continue; + throw std::system_error(errno, std::system_category(), "send"); + } + p += n; + len -= n; + } +} + // --------------------------------------------------------------------------- // Install a seccomp filter on the calling thread that sends all memory-range // syscalls to the supervisor via SCMP_ACT_NOTIFY. @@ -62,12 +78,23 @@ static int install_memory_notify_filter() { // Send the unotify fd + range to the supervisor over the socketpair. // --------------------------------------------------------------------------- -struct RangeMsg { uintptr_t lo, hi; }; +static void send_unotify_fd(int sock, int unotify_fd, + uintptr_t sensitive_lo, uintptr_t sensitive_hi) { + uint32_t n = __stop___allowed_mprotect - __start___allowed_mprotect; + if (n > MAX_ALLOWED_SITES) + throw std::runtime_error("too many allowed sites"); + + SupervisorSetupMsg hdr { sensitive_lo, sensitive_hi, n }; + + // Send header + site array as regular data + send_all(sock, &hdr, sizeof(hdr)); -static void send_unotify_fd(int sock, int unotify_fd, uintptr_t lo, uintptr_t hi) { - RangeMsg range = { lo, hi }; - struct iovec iov = { &range, sizeof(range) }; + size_t sites_sz = n * sizeof(AllowedSite); + send_all(sock, __start___allowed_mprotect, sites_sz); + // Send unotify_fd via SCM_RIGHTS + char dummy = 0; + struct iovec iov = { &dummy, 1 }; union { char buf[CMSG_SPACE(sizeof(int))]; struct cmsghdr align; @@ -86,8 +113,8 @@ static void send_unotify_fd(int sock, int unotify_fd, uintptr_t lo, uintptr_t hi cmsg->cmsg_len = CMSG_LEN(sizeof(int)); memcpy(CMSG_DATA(cmsg), &unotify_fd, sizeof(int)); - if (sendmsg(sock, &msg, 0) < 0) - throw std::system_error(errno, std::system_category(), "sendmsg"); + if (sendmsg(sock, &msg, MSG_NOSIGNAL) < 0) + throw std::system_error(errno, std::system_category(), "sendmsg unotify_fd"); } // --------------------------------------------------------------------------- diff --git a/csrc/supervisor.cpp b/csrc/supervisor.cpp index a224e94..385546e 100644 --- a/csrc/supervisor.cpp +++ b/csrc/supervisor.cpp @@ -3,19 +3,66 @@ #include #include #include +#include +#include #include #include #include #include #include +#include + +#include "protocol.h" +#include + +#ifdef DEBUG_SUPERVISOR +#define dbgprint(...) fprintf(stdout, __VA_ARGS__) +#else +#define dbgprint(...) +#endif + +struct Config { + uintptr_t sensitive_lo; + uintptr_t sensitive_hi; + std::vector allowed; +}; + +static void recv_all(int sock, void* buf, size_t len) { + auto* p = static_cast(buf); + while (len > 0) { + ssize_t n = recv(sock, p, len, MSG_WAITALL); + if (n < 0) { + if (errno == EINTR) continue; + throw std::system_error(errno, std::system_category(), "recv"); + } + if (n == 0) + throw std::runtime_error("supervisor: connection closed unexpectedly"); + p += n; + len -= n; + } +} + +static int recv_setup(int sock, Config& cfg) { + SupervisorSetupMsg setup; + recv_all(sock, &setup, sizeof(setup)); + + cfg.sensitive_lo = setup.sensitive_lo; + cfg.sensitive_hi = setup.sensitive_hi; + if (setup.n_allowed_sites > MAX_ALLOWED_SITES) + throw std::runtime_error("supervisor: too many allowed sites"); + + if (cfg.sensitive_lo >= cfg.sensitive_hi) + throw std::runtime_error("supervisor: invalid sensitive range"); -struct RangeMsg { uintptr_t lo, hi; }; + cfg.allowed.resize(setup.n_allowed_sites); + if (setup.n_allowed_sites > 0) { + size_t sites_sz = setup.n_allowed_sites * sizeof(AllowedSite); + recv_all(sock, cfg.allowed.data(), sites_sz); + } -static int recv_unotify_fd(int sock, uintptr_t& lo, uintptr_t& hi) { - RangeMsg range; - struct iovec iov = { &range, sizeof(range) }; + char dummy; + struct iovec iov = { &dummy, 1 }; - // Ancillary buffer for one fd union { char buf[CMSG_SPACE(sizeof(int))]; struct cmsghdr align; @@ -29,97 +76,105 @@ static int recv_unotify_fd(int sock, uintptr_t& lo, uintptr_t& hi) { ssize_t n = recvmsg(sock, &msg, MSG_CMSG_CLOEXEC); if (n < 0) { - perror("supervisor: recvmsg"); - return -1; - } - if (n != sizeof(range)) { - fprintf(stderr, "supervisor: short read: %zd\n", n); - return -1; + throw std::system_error(errno, std::system_category(), "recvmsg"); } struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - if (!cmsg || cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS - || cmsg->cmsg_len != CMSG_LEN(sizeof(int))) { - fprintf(stderr, "supervisor: missing or malformed SCM_RIGHTS\n"); - return -1; + if (!cmsg || cmsg->cmsg_level != SOL_SOCKET || + cmsg->cmsg_type != SCM_RIGHTS || + cmsg->cmsg_len != CMSG_LEN(sizeof(int))) { + throw std::runtime_error("supervisor: invalid SCM_RIGHTS"); } int unotify_fd; memcpy(&unotify_fd, CMSG_DATA(cmsg), sizeof(int)); - lo = range.lo; - hi = range.hi; return unotify_fd; } -// Returns true if [addr, addr+size) overlaps [lo, hi). -// Handles wraparound: if addr+size wraps, the range covers everything above addr, -// which necessarily overlaps any [lo, hi) where hi > lo. static bool overlaps(uintptr_t addr, uintptr_t size, uintptr_t lo, uintptr_t hi) { - // addr+size > lo AND addr < hi - // For wraparound case (addr+size < addr), the range wraps around the - // address space, so it definitely overlaps any non-empty [lo, hi). uintptr_t end = addr + size; bool wrapped = (end < addr); - return (wrapped || end > lo) && (addr < hi); + if (wrapped) { + return (addr < hi) || (end > lo); + } + return (end > lo) && (addr < hi); +} + +static bool ip_is_allowed(uintptr_t ip, const Config& cfg) { + for (AllowedSite site : cfg.allowed) { + if (ip == site + 2) + return true; + } + return false; } -// mprotect/mmap/munmap/madvise/remap_file_pages: args[0]=addr, args[1]=len. -// mremap: blocked unconditionally — MREMAP_FIXED moves the mapping to a new -// address chosen by the caller, making a safe overlap check impossible. -static bool handle_notification(int unotify_fd, uintptr_t lo, uintptr_t hi) { +static bool handle_notification(int unotify_fd, const Config& cfg) { struct seccomp_notif req = {}; struct seccomp_notif_resp resp = {}; if (ioctl(unotify_fd, SECCOMP_IOCTL_NOTIF_RECV, &req) < 0) { - if (errno == EINTR) return true; // interrupted, keep going - if (errno == ENODEV) return false; // tracee thread exited, we're done + if (errno == EINTR) return true; + if (errno == ENODEV) return false; perror("supervisor: SECCOMP_IOCTL_NOTIF_RECV"); return false; } resp.id = req.id; resp.flags = 0; - resp.error = 0; + resp.error = -EPERM; - // Check the notification is still valid before we act on it. - // This closes the race where the thread exits between RECV and SEND. - if (ioctl(unotify_fd, SECCOMP_IOCTL_NOTIF_ID_VALID, &req.id) < 0) { - // Thread is gone — keep looping in case other threads share this filter. + if (ioctl(unotify_fd, SECCOMP_IOCTL_NOTIF_ID_VALID, &req.id) < 0) return true; - } - // All remaining syscalls (mprotect, mmap, munmap, madvise): - // args[0]=addr, args[1]=len — check for overlap with protected range. - bool deny = overlaps(req.data.args[0], req.data.args[1], lo, hi); + uintptr_t ip = req.data.instruction_pointer; + uintptr_t addr = req.data.args[0]; + uintptr_t len = req.data.args[1]; + int prot = (int)req.data.args[2]; - if (deny) { - resp.error = -EPERM; - } else { + bool ip_ok = ip_is_allowed(ip, cfg); + bool contained = overlaps(addr, len, cfg.sensitive_lo, cfg.sensitive_hi); + bool prot_safe = prot == PROT_NONE; + + if (!contained) { + // touches other memory, this is fine + resp.error = 0; + resp.flags = SECCOMP_USER_NOTIF_FLAG_CONTINUE; + } else if ((ip_ok || prot_safe) && req.data.nr == SYS_mprotect) { + // touches our memory, but either makes it PROT_NONE or is from a whitelisted instruction + resp.error = 0; resp.flags = SECCOMP_USER_NOTIF_FLAG_CONTINUE; + dbgprint("Allowed mprotect from ip=0x%lx addr=[0x%lx,0x%lx) prot=%d\n",ip, addr, addr + len, prot); + } else { + dbgprint("supervisor: DENIED syscall %d from ip=0x%lx addr=[0x%lx,0x%lx) prot=%d " + "(ip_ok=%d contained=%d)\n", + req.data.nr, ip, addr, addr + len, prot, ip_ok, contained); } - if (ioctl(unotify_fd, SECCOMP_IOCTL_NOTIF_SEND, &resp) < 0) { - if (errno == ENOENT) return true; // thread gone between ID_VALID and SEND, fine - perror("supervisor: SECCOMP_IOCTL_NOTIF_SEND"); - } + if (ioctl(unotify_fd, SECCOMP_IOCTL_NOTIF_SEND, &resp) < 0) + if (errno != ENOENT) + perror("supervisor: SECCOMP_IOCTL_NOTIF_SEND"); + return true; } -// Entry point for the supervisor process. -// sock_fd is the supervisor's end of the socketpair, passed directly from Python. int supervisor_main(int sock_fd) { - // Die if our parent (the tracee process) dies. + if (prctl(PR_SET_DUMPABLE, 0) < 0) + throw std::system_error(errno, std::system_category(), "prctl(PR_SET_DUMPABLE)"); + prctl(PR_SET_PDEATHSIG, SIGTERM); - uintptr_t lo, hi; - int unotify_fd = recv_unotify_fd(sock_fd, lo, hi); + Config cfg; + int unotify_fd = recv_setup(sock_fd, cfg); close(sock_fd); if (unotify_fd < 0) return 1; - // Event loop: handle notifications until the tracee thread exits, - // at which point the unotify fd becomes invalid and NOTIF_RECV returns ENODEV. - while (handle_notification(unotify_fd, lo, hi)) + dbgprint("supervisor: sensitive=[0x%lx, 0x%lx), %zu allowed sites\n", + cfg.sensitive_lo, cfg.sensitive_hi, cfg.allowed.size()); + for (AllowedSite site : cfg.allowed) + dbgprint("supervisor: allowed site: 0x%lx\n", site); + + while (handle_notification(unotify_fd, cfg)) ; close(unotify_fd); diff --git a/exploits/submission_aes_key_recovery.py b/exploits/submission_aes_key_recovery.py index 26c8a81..12fe13f 100644 --- a/exploits/submission_aes_key_recovery.py +++ b/exploits/submission_aes_key_recovery.py @@ -2,7 +2,7 @@ EXPLOIT: AES Key Recovery via ObfuscatedHexDigest Arena Scan ============================================================= -Scans the BenchmarkManager arena (128MiB rw-p anonymous region) for the +Scans the BenchmarkManager arena for the ObfuscatedHexDigest struct. Page pointer is stored as slow_hash(ptr), offset and length as slow_hash(offset) and slow_hash(size ^ offset). Discriminator: slow_unhash(w0) must be page-aligned and canonical user-space (~1/4096 FP rate). @@ -32,7 +32,8 @@ _PROT_NONE = 0 _PROT_READ = 1 _PAGE_SIZE = 4096 -_ARENA_SIZE = 128 * 1024 * 1024 +#_ARENA_SIZE = 128 * 1024 * 1024 +# ARENA may be split into smaller mappings due to different protection settings for different parts _F_GETPIPE_SZ = 1025 def _mprotect(addr, prot): @@ -71,7 +72,7 @@ def slow_unhash(val: int, rounds: int = 100_000) -> int: _HEX_CHARS = frozenset(b'0123456789abcdef') -def _find_arena_candidates(): +def _find_arena_candidates() -> list[tuple]: candidates = [] with open('/proc/self/maps') as f: for line in f: @@ -81,13 +82,14 @@ def _find_arena_candidates(): start, end = (int(x, 16) for x in parts[0].split('-')) name = parts[5] if len(parts) >= 6 else '' perms = parts[1] - if perms == 'rw-p' and name == '' and (end - start) >= _ARENA_SIZE: - candidates.append(start) + if perms == 'r--p' and name == '': # and (end - start) >= _ARENA_SIZE: + candidates.append((start, end)) return candidates -def _recover_aes_key(arena_start: int): - SCAN_SIZE = 512 * 1024 +def _recover_aes_key(arena_start: int, arena_end: int): + SCAN_SIZE = min(512 * 1024, arena_end - arena_start) + print(f"[DBG] scanning arena+{arena_start:#x} for ObfuscatedHexDigest; len={SCAN_SIZE}") unhash_calls = 0 try: @@ -114,6 +116,7 @@ def _recover_aes_key(arena_start: int): continue offset = slow_unhash(w2) unhash_calls += 1 + print(f"[DBG] unhash_calls={unhash_calls} page_ptr={page_ptr:#x} offset={offset}") if offset > _PAGE_SIZE - 33: continue @@ -205,7 +208,7 @@ def kernel(output, inp): key_raw, key_str = None, None for arena in candidates: - key_raw, key_str = _recover_aes_key(arena) + key_raw, key_str = _recover_aes_key(*arena) if key_raw is not None: break