Skip to content

Commit 68dbb40

Browse files
authored
Instruction-pointer based memory protection #45
2 parents 23717f2 + 75b7332 commit 68dbb40

11 files changed

Lines changed: 360 additions & 230 deletions

csrc/binding.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@ namespace nb = nanobind;
1919
void do_bench(int result_fd, int input_fd, const std::string& kernel_qualname, const nb::object& test_generator,
2020
const nb::dict& test_kwargs, std::uintptr_t stream, bool discard, bool nvtx, bool landlock, bool mseal,
2121
int supervisor_sock_fd) {
22-
ObfuscatedHexDigest signature;
23-
std::mt19937 rng(std::random_device{}());
24-
signature.allocate(32, rng);
25-
auto config = read_benchmark_parameters(input_fd, signature.data());
26-
auto mgr = make_benchmark_manager(result_fd, std::move(signature), config.Seed, discard, nvtx, landlock, mseal, supervisor_sock_fd);
22+
std::vector<char> signature_bytes(32);
23+
auto config = read_benchmark_parameters(input_fd, signature_bytes.data());
24+
auto mgr = make_benchmark_manager(result_fd, signature_bytes, config.Seed, discard, nvtx, landlock, mseal, supervisor_sock_fd);
25+
cleanse(signature_bytes.data(), 32);
2726

2827
{
2928
nb::gil_scoped_release release;

csrc/landlock.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -211,16 +211,6 @@ void setup_seccomp_filter(scmp_filter_ctx ctx) {
211211
check_seccomp(seccomp_rule_add(ctx, SCMP_ACT_ERRNO(EPERM), SCMP_SYS(prctl), 1,
212212
SCMP_A0(SCMP_CMP_EQ, PR_SET_PTRACER)),
213213
"block prctl(SET_PTRACER)");
214-
// TODO figure out what else we can and should block
215-
/*
216-
check_seccomp(seccomp_rule_add(ctx, SCMP_ACT_ERRNO(EPERM), SCMP_SYS(mprotect), 1,
217-
SCMP_A2(SCMP_CMP_MASKED_EQ, PROT_WRITE, PROT_WRITE)),
218-
"block mprotect+WRITE");
219-
220-
check_seccomp(seccomp_rule_add(ctx, SCMP_ACT_ERRNO(EPERM), SCMP_SYS(pkey_mprotect), 1,
221-
SCMP_A2(SCMP_CMP_MASKED_EQ, PROT_WRITE, PROT_WRITE)),
222-
"block pkey_mprotect+WRITE");
223-
*/
224214
}
225215

226216
void install_seccomp_filter() {

csrc/manager.cpp

Lines changed: 54 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <nanobind/stl/string.h>
2020
#include <sys/mman.h>
2121
#include <unistd.h>
22+
#include "protect.h"
2223

2324
static constexpr std::size_t ArenaSize = 2 * 1024 * 1024;
2425

@@ -137,7 +138,7 @@ void BenchmarkManagerDeleter::operator()(BenchmarkManager* p) const noexcept {
137138

138139

139140
BenchmarkManagerPtr make_benchmark_manager(
140-
int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed,
141+
int result_fd, const std::vector<char>& signature, std::uint64_t seed,
141142
bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket)
142143
{
143144
const std::size_t page_size = static_cast<std::size_t>(getpagesize());
@@ -153,7 +154,7 @@ BenchmarkManagerPtr make_benchmark_manager(
153154
try {
154155
raw = new (mem) BenchmarkManager(
155156
static_cast<std::byte*>(mem), alloc_size,
156-
result_fd, std::move(signature), seed,
157+
result_fd, signature, seed,
157158
discard, nvtx, landlock, mseal, supervisor_socket);
158159
} catch (...) {
159160
// If construction throws, release the mmap'd region before propagating.
@@ -168,14 +169,14 @@ BenchmarkManagerPtr make_benchmark_manager(
168169

169170

170171
BenchmarkManager::BenchmarkManager(std::byte* arena, std::size_t arena_size,
171-
int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed, bool discard,
172+
int result_fd, const std::vector<char>& signature, std::uint64_t seed, bool discard,
172173
bool nvtx, bool landlock, bool mseal, int supervisor_socket)
173174
: mArena(arena),
174175
mResource(arena + sizeof(BenchmarkManager),
175176
arena_size - sizeof(BenchmarkManager),
176177
std::pmr::null_memory_resource()),
177178

178-
mSignature(std::move(signature)),
179+
mSignature(&mResource),
179180
mSupervisorSock(supervisor_socket),
180181
mStartEvents(&mResource),
181182
mEndEvents(&mResource),
@@ -195,11 +196,19 @@ BenchmarkManager::BenchmarkManager(std::byte* arena, std::size_t arena_size,
195196
throw std::runtime_error("Could not open output pipe");
196197
}
197198

199+
if (signature.size() != 32) {
200+
throw std::invalid_argument("Invalid signature length");
201+
}
202+
198203
mNVTXEnabled = nvtx;
199204
mLandlock = landlock;
200205
mSeal = mseal;
201206
mDiscardCache = discard;
202207
mSeed = seed;
208+
std::random_device rd;
209+
std::mt19937 rng(rd());
210+
mSignature.allocate(32, rng);
211+
std::copy(signature.begin(), signature.end(), mSignature.data());
203212
}
204213

205214

@@ -337,18 +346,6 @@ void BenchmarkManager::install_protections() {
337346
install_seccomp_filter();
338347
}
339348

340-
static inline std::uintptr_t page_mask() {
341-
std::uintptr_t page_size = getpagesize();
342-
return ~(page_size - 1u);
343-
}
344-
345-
void protect_range(void* ptr, size_t size, int prot) {
346-
std::uintptr_t start = reinterpret_cast<std::uintptr_t>(ptr) & page_mask();
347-
std::uintptr_t end = (reinterpret_cast<std::uintptr_t>(ptr) + size + getpagesize() - 1) & page_mask();
348-
if (mprotect(reinterpret_cast<void*>(start), end - start, prot) < 0)
349-
throw std::system_error(errno, std::system_category(), "mprotect");
350-
}
351-
352349
static void setup_seccomp(int sock, bool install_notify, std::uintptr_t lo, std::uintptr_t hi) {
353350
if (sock < 0)
354351
return;
@@ -394,48 +391,46 @@ nb::callable BenchmarkManager::initial_kernel_setup(double& time_estimate, const
394391
void* const cc_memory = mDeviceDummyMemory;
395392
const std::size_t l2_clear_size = mL2CacheSize;
396393
const bool discard_cache = mDiscardCache;
397-
int device;
398-
CUDA_CHECK(cudaGetDevice(&device));
399-
400-
nb::callable kernel;
401-
std::exception_ptr thread_exception;
402394

403395
nvtx_push("trigger-compile");
404-
protect_range(reinterpret_cast<void*>(lo), hi - lo, PROT_NONE);
405-
406-
{
407-
nb::gil_scoped_release release;
408-
std::thread worker([&] {
409-
try {
410-
CUDA_CHECK(cudaSetDevice(device));
411-
setup_seccomp(sock, install_notify, lo, hi);
412-
413-
nb::gil_scoped_acquire guard;
414-
415-
kernel = kernel_from_qualname(qualname);
416-
CUDA_CHECK(cudaDeviceSynchronize());
417-
kernel(*call_args); // trigger JIT compile
418-
419-
time_estimate = run_warmup_loop(kernel, call_args, stream,
420-
cc_memory, l2_clear_size, discard_cache,
421-
warmup_seconds);
422-
} catch (...) {
423-
thread_exception = std::current_exception();
424-
}
425-
});
426-
worker.join();
427-
}
396+
PROTECT_RANGE(lo, hi-lo, PROT_NONE);
397+
setup_seccomp(sock, install_notify, lo, hi);
428398

429-
protect_range(reinterpret_cast<void*>(lo), hi - lo, PROT_READ | PROT_WRITE);
399+
nb::callable kernel = kernel_from_qualname(qualname);
400+
CUDA_CHECK(cudaDeviceSynchronize());
401+
kernel(*call_args); // trigger JIT compile
402+
403+
time_estimate = run_warmup_loop(kernel, call_args, stream,
404+
cc_memory, l2_clear_size, discard_cache,
405+
warmup_seconds);
406+
407+
PROTECT_RANGE(lo, hi - lo, PROT_READ | PROT_WRITE);
430408
mSupervisorSock = -1;
431409
nvtx_pop();
432410

433-
if (thread_exception)
434-
std::rethrow_exception(thread_exception);
435-
436411
return kernel;
437412
}
438413

414+
/// Randomizes the per-run state used to detect tampering before the tests execute.
///
/// Two things are set up here:
///  1. The device error area starting at `mDeviceErrorBase` is overwritten with
///     `ArenaSize` bytes of random `unsigned` values (copied with
///     `cudaMemcpyAsync` on `stream`), one slot of it is chosen at random as
///     `mDeviceErrorCounter`, and the random value initially stored in that slot
///     is remembered in `mErrorCountShift`.
///  2. `mTestOrder` is resized to `num_calls`, filled with 1..num_calls and shuffled.
///
/// @param num_calls number of entries to place in `mTestOrder`.
/// @param rng       random engine used for the noise, the slot choice and the shuffle.
/// @param stream    CUDA stream on which the host-to-device copy is enqueued.
void BenchmarkManager::randomize_before_test(int num_calls, std::mt19937& rng, cudaStream_t stream) {
415+
// pick a random spot for the unsigned error counter inside the device error area
416+
// initialize the whole area with random junk; the error counter
417+
// will be shifted by the initial value, so just writing zero
418+
// won't result in passing the tests.
419+
std::uniform_int_distribution<std::ptrdiff_t> dist(0, ArenaSize / sizeof(unsigned) - 1);
420+
std::uniform_int_distribution<unsigned> noise_generator(0, std::numeric_limits<unsigned>::max());
421+
std::vector<unsigned> noise(ArenaSize / sizeof(unsigned));
422+
std::generate(noise.begin(), noise.end(), [&]() -> unsigned { return noise_generator(rng); });
423+
// NOTE(review): `noise` is an ordinary (pageable) host vector that is destroyed when
// this function returns; this relies on cudaMemcpyAsync having consumed the source
// buffer by the time it returns for pageable memory - confirm against the CUDA docs.
CUDA_CHECK(cudaMemcpyAsync(mDeviceErrorBase, noise.data(), noise.size() * sizeof(unsigned), cudaMemcpyHostToDevice, stream));
424+
std::ptrdiff_t offset = dist(rng);
425+
mDeviceErrorCounter = mDeviceErrorBase + offset;
426+
// remember the junk value that sits in the chosen slot
mErrorCountShift = noise.at(offset);
427+
428+
// create a randomized order for running the tests
429+
mTestOrder.resize(num_calls);
430+
// test indices are 1-based: the sequence starts at 1, not 0
std::iota(mTestOrder.begin(), mTestOrder.end(), 1);
431+
std::shuffle(mTestOrder.begin(), mTestOrder.end(), rng);
432+
}
433+
439434
void BenchmarkManager::do_bench_py(
440435
const std::string& kernel_qualname,
441436
const std::vector<nb::tuple>& args,
@@ -472,25 +467,13 @@ void BenchmarkManager::do_bench_py(
472467
"meaningful benchmark numbers: " + std::to_string(time_estimate));
473468
}
474469

475-
// pick a random spot for the unsigned
476-
// initialize the whole area with random junk; the error counter
477-
// will be shifted by the initial value, so just writing zero
478-
// won't result in passing the tests.
479470
std::random_device rd;
480471
std::mt19937 rng(rd());
481-
std::uniform_int_distribution<std::ptrdiff_t> dist(0, ArenaSize / sizeof(unsigned) - 1);
482-
std::uniform_int_distribution<unsigned> noise_generator(0, std::numeric_limits<unsigned>::max());
483-
std::vector<unsigned> noise(ArenaSize / sizeof(unsigned));
484-
std::generate(noise.begin(), noise.end(), [&]() -> unsigned { return noise_generator(rng); });
485-
CUDA_CHECK(cudaMemcpyAsync(mDeviceErrorBase, noise.data(), noise.size() * sizeof(unsigned), cudaMemcpyHostToDevice, stream));
486-
std::ptrdiff_t offset = dist(rng);
487-
mDeviceErrorCounter = mDeviceErrorBase + offset;
488-
mErrorCountShift = noise.at(offset);
489472

490-
// create a randomized order for running the tests
491-
mTestOrder.resize(actual_calls);
492-
std::iota(mTestOrder.begin(), mTestOrder.end(), 1);
493-
std::shuffle(mTestOrder.begin(), mTestOrder.end(), rng);
473+
randomize_before_test(actual_calls, rng, stream);
474+
// from this point on, even the benchmark thread won't write to the arena anymore
475+
PROTECT_RANGE(mArena, BenchmarkManagerArenaSize, PROT_READ);
476+
PROTECT_RANGE(mSignature.page_ptr(), 4096, PROT_NONE); // make the key fully inaccessible
494477

495478
std::uniform_int_distribution<unsigned> check_seed_generator(0, 0xffffffff);
496479

@@ -540,12 +523,18 @@ void BenchmarkManager::send_report() {
540523
error_count -= mErrorCountShift;
541524

542525
std::string message = build_result_message(mTestOrder, error_count, mMedianEventTime);
526+
PROTECT_RANGE(mSignature.page_ptr(), 4096, PROT_READ);
543527
message = encrypt_message(mSignature.data(), 32, message);
528+
PROTECT_RANGE(mSignature.page_ptr(), 4096, PROT_WRITE);
529+
cleanse(mSignature.data(), 32);
530+
PROTECT_RANGE(mSignature.page_ptr(), 4096, PROT_NONE);
544531
fwrite(message.data(), 1, message.size(), mOutputPipe);
545532
fflush(mOutputPipe);
546533
}
547534

548535
void BenchmarkManager::clean_up() {
536+
PROTECT_RANGE(mArena, BenchmarkManagerArenaSize, PROT_READ | PROT_WRITE);
537+
549538
for (auto& event : mStartEvents) CUDA_CHECK(cudaEventDestroy(event));
550539
for (auto& event : mEndEvents) CUDA_CHECK(cudaEventDestroy(event));
551540
mStartEvents.clear();

csrc/manager.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct BenchmarkManagerDeleter {
4242
using BenchmarkManagerPtr = std::unique_ptr<BenchmarkManager, BenchmarkManagerDeleter>;
4343

4444
BenchmarkManagerPtr make_benchmark_manager(
45-
int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed,
45+
int result_fd, const std::vector<char>& signature, std::uint64_t seed,
4646
bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket);
4747

4848

@@ -53,13 +53,13 @@ class BenchmarkManager {
5353
void send_report();
5454
void clean_up();
5555
private:
56-
friend BenchmarkManagerPtr make_benchmark_manager(int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed, bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket);
56+
friend BenchmarkManagerPtr make_benchmark_manager(int result_fd, const std::vector<char>& signature, std::uint64_t seed, bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket);
5757
friend BenchmarkManagerDeleter;
5858
/// `arena` is the mmap region that owns all memory for this object and its vectors.
5959
/// The BenchmarkManager must have been placement-newed into the front of that region;
6060
/// the rest is used as a monotonic PMR arena for internal vectors.
6161
BenchmarkManager(std::byte* arena, std::size_t arena_size,
62-
int result_fd, ObfuscatedHexDigest signature, std::uint64_t seed,
62+
int result_fd, const std::vector<char>& signature, std::uint64_t seed,
6363
bool discard, bool nvtx, bool landlock, bool mseal, int supervisor_socket);
6464
~BenchmarkManager();
6565

@@ -135,6 +135,7 @@ class BenchmarkManager {
135135
void setup_test_cases(const std::vector<nb::tuple>& args, const std::vector<nb::tuple>& expected, cudaStream_t stream);
136136

137137
void install_protections();
138+
void randomize_before_test(int num_calls, std::mt19937& rng, cudaStream_t stream);
138139
nb::callable initial_kernel_setup(double& time_estimate, const std::string& qualname, const nb::tuple& call_args, cudaStream_t stream);
139140

140141
[[nodiscard]] std::string build_result_message(const std::pmr::vector<int>& test_order, unsigned error_count, float median_event_time) const;

csrc/obfuscate.cpp

Lines changed: 18 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,48 +18,11 @@
1818
#include <openssl/evp.h>
1919
#include <openssl/rand.h>
2020

21-
constexpr std::size_t PAGE_SIZE = 4096;
21+
constexpr static std::size_t PAGE_SIZE = 4096;
2222

23-
ProtectablePage::ProtectablePage() {
24-
void* page = mmap(nullptr, PAGE_SIZE, PROT_READ | PROT_WRITE,
25-
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
26-
if (page == MAP_FAILED) {
27-
throw std::runtime_error("mmap failed");
28-
}
29-
Page = slow_hash(page);
30-
}
31-
32-
ProtectablePage::~ProtectablePage() {
33-
void* page = page_ptr();
34-
if (page) {
35-
if (mprotect(page, PAGE_SIZE, PROT_READ | PROT_WRITE) != 0) {
36-
std::perror("mprotect restore failed in ~ProtectablePage");
37-
}
38-
if (munmap(page, PAGE_SIZE) != 0) {
39-
std::perror("munmap failed in ~ProtectablePage");
40-
}
41-
}
42-
}
43-
44-
ProtectablePage::ProtectablePage(ProtectablePage&& other) noexcept : Page(std::exchange(other.Page, slow_hash((void*)nullptr))){
45-
}
46-
47-
void ProtectablePage::lock() {
48-
void* page = page_ptr();
49-
if (mprotect(page, PAGE_SIZE, PROT_NONE) != 0) {
50-
throw std::system_error(errno, std::generic_category(), "mprotect(PROT_NONE) failed");
51-
}
52-
}
53-
54-
void ProtectablePage::unlock() {
55-
void* page = page_ptr();
56-
if (mprotect(page, PAGE_SIZE, PROT_READ) != 0) {
57-
throw std::system_error(errno, std::generic_category(), "mprotect(PROT_READ) failed");
58-
}
59-
}
60-
61-
void* ProtectablePage::page_ptr() const {
62-
return reinterpret_cast<void*>(slow_unhash(Page));
23+
/// Reserves the backing page for the digest from `mem`.
///
/// Requests `PAGE_SIZE` bytes aligned to `PAGE_SIZE` from the given monotonic
/// buffer resource and stores the resulting address only in hashed form
/// (`slow_hash`) in `HashedPagePtr`.
///
/// @param mem resource to allocate the page from; dereferenced unconditionally,
///            so it must not be null.
// NOTE(review): PAGE_SIZE is a fixed 4096 in this file; confirm this matches the
// runtime page size on every target, since the page is meant to be protected as a unit.
ObfuscatedHexDigest::ObfuscatedHexDigest(std::pmr::monotonic_buffer_resource* mem) {
24+
void* page = mem->allocate(PAGE_SIZE, PAGE_SIZE);
25+
// keep only the hashed address; accessors recover it with slow_unhash
HashedPagePtr = slow_hash(reinterpret_cast<std::uintptr_t>(page));
6326
}
6427

6528
void ObfuscatedHexDigest::allocate(std::size_t size, std::mt19937& rng) {
@@ -70,7 +33,7 @@ void ObfuscatedHexDigest::allocate(std::size_t size, std::mt19937& rng) {
7033
throw std::runtime_error("already allocated");
7134
}
7235

73-
fill_random_hex(page_ptr(), PAGE_SIZE, rng);
36+
fill_random_hex(reinterpret_cast<void*>(slow_unhash(HashedPagePtr)), PAGE_SIZE, rng);
7437
const std::uintptr_t max_offset = PAGE_SIZE - size - 1;
7538
std::uniform_int_distribution<std::uintptr_t> offset_dist(0, max_offset);
7639

@@ -79,8 +42,12 @@ void ObfuscatedHexDigest::allocate(std::size_t size, std::mt19937& rng) {
7942
HashedLen = slow_hash(size ^ offset);
8043
}
8144

45+
/// Returns the base address of the backing page.
///
/// The address is recovered by applying `slow_unhash` to the stored
/// `HashedPagePtr`; the result is exposed as a pointer-to-const.
const void* ObfuscatedHexDigest::page_ptr() const {
46+
return reinterpret_cast<const void*>(slow_unhash(HashedPagePtr));
47+
}
48+
8249
/// Returns a mutable pointer to the first digest byte.
///
/// The pointer is computed as the un-hashed page base address plus the
/// un-hashed offset stored in `HashedOffset`.
char* ObfuscatedHexDigest::data() {
83-
return reinterpret_cast<char*>(page_ptr()) + slow_unhash(HashedOffset);
50+
return reinterpret_cast<char*>(slow_unhash(HashedPagePtr)) + slow_unhash(HashedOffset);
8451
}
8552

8653
std::size_t ObfuscatedHexDigest::size() const {
@@ -115,20 +82,15 @@ std::uintptr_t slow_unhash(std::uintptr_t p, int rounds) {
11582
return p;
11683
}
11784

118-
std::string encrypt_message(void* key, size_t keyLen, const std::string& plaintext)
85+
/// Wipes `size` bytes starting at `ptr`.
///
/// Thin wrapper that forwards directly to OpenSSL's `OPENSSL_cleanse`.
///
/// @param ptr  start of the buffer to wipe.
/// @param size number of bytes to wipe.
void cleanse(void* ptr, size_t size) {
86+
OPENSSL_cleanse(ptr, size);
87+
}
88+
89+
std::string encrypt_message(const char* key, size_t keyLen, const std::string& plaintext)
11990
{
12091
if (keyLen != 32)
12192
throw std::invalid_argument("encrypt_message: key must be exactly 32 bytes for AES-256");
12293

123-
struct Cleanse
124-
{
125-
void* key;
126-
size_t keyLen;
127-
~Cleanse() {
128-
OPENSSL_cleanse(key, keyLen);
129-
}
130-
} cleanse_guard{key, keyLen};
131-
13294
constexpr int NONCE_LEN = 12;
13395
constexpr int TAG_LEN = 16;
13496

@@ -142,9 +104,8 @@ std::string encrypt_message(void* key, size_t keyLen, const std::string& plainte
142104
struct CtxGuard { EVP_CIPHER_CTX* c; ~CtxGuard() { EVP_CIPHER_CTX_free(c); } } guard{ctx};
143105

144106
if (EVP_EncryptInit_ex(ctx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1 ||
145-
EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, NONCE_LEN, nullptr) != 1 ||
146-
EVP_EncryptInit_ex(ctx, nullptr, nullptr,
147-
static_cast<const unsigned char*>(key), nonce) != 1)
107+
EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, NONCE_LEN, nullptr) != 1 ||
108+
EVP_EncryptInit_ex(ctx, nullptr, nullptr, reinterpret_cast<const unsigned char*>(key), nonce) != 1)
148109
{
149110
throw std::runtime_error("encrypt_message: GCM init failed");
150111
}
@@ -173,4 +134,4 @@ std::string encrypt_message(void* key, size_t keyLen, const std::string& plainte
173134
packet.append(reinterpret_cast<char*>(ciphertext.data()), out_len + final_len);
174135

175136
return packet;
176-
}
137+
}

0 commit comments

Comments
 (0)