diff --git a/packages/orchestrator/internal/sandbox/block/device.go b/packages/orchestrator/internal/sandbox/block/device.go
index fd6851b691..0930602534 100644
--- a/packages/orchestrator/internal/sandbox/block/device.go
+++ b/packages/orchestrator/internal/sandbox/block/device.go
@@ -16,6 +16,7 @@ func (BytesNotAvailableError) Error() string {
 
 type Slicer interface {
 	Slice(ctx context.Context, off, length int64) ([]byte, error)
+	BlockSize() int64
 }
 
 type ReadonlyDevice interface {
diff --git a/packages/orchestrator/internal/sandbox/block/tracker.go b/packages/orchestrator/internal/sandbox/block/tracker.go
index b0caf19411..b9daaed02a 100644
--- a/packages/orchestrator/internal/sandbox/block/tracker.go
+++ b/packages/orchestrator/internal/sandbox/block/tracker.go
@@ -1,66 +1,83 @@
 package block
 
 import (
-	"context"
-	"fmt"
+	"iter"
 	"sync"
-	"sync/atomic"
 
 	"github.com/bits-and-blooms/bitset"
 
 	"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
+	"github.com/e2b-dev/infra/packages/shared/pkg/utils"
 )
 
-type TrackedSliceDevice struct {
-	data      ReadonlyDevice
-	blockSize int64
+type Tracker struct {
+	b  *bitset.BitSet
+	mu sync.RWMutex
 
-	nilTracking atomic.Bool
-	dirty       *bitset.BitSet
-	dirtyMu     sync.Mutex
-	empty       []byte
+	blockSize int64
 }
 
-func NewTrackedSliceDevice(blockSize int64, device ReadonlyDevice) (*TrackedSliceDevice, error) {
-	return &TrackedSliceDevice{
-		data:      device,
-		empty:     make([]byte, blockSize),
+func NewTracker(blockSize int64) *Tracker {
+	return &Tracker{
+		// The bitset resizes automatically based on the maximum set bit.
+		b:         bitset.New(0),
 		blockSize: blockSize,
-	}, nil
+	}
 }
 
-func (t *TrackedSliceDevice) Disable() error {
-	size, err := t.data.Size()
-	if err != nil {
-		return fmt.Errorf("failed to get device size: %w", err)
+func NewTrackerFromBitset(b *bitset.BitSet, blockSize int64) *Tracker {
+	return &Tracker{
+		b:         b,
+		blockSize: blockSize,
 	}
+}
+
+func (t *Tracker) Has(off int64) bool {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
 
-	t.dirty = bitset.New(uint(header.TotalBlocks(size, t.blockSize)))
-	// We are starting with all being dirty.
-	t.dirty.FlipRange(0, t.dirty.Len())
+	return t.b.Test(uint(header.BlockIdx(off, t.blockSize)))
+}
 
-	t.nilTracking.Store(true)
+func (t *Tracker) Add(off int64) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
 
-	return nil
+	t.b.Set(uint(header.BlockIdx(off, t.blockSize)))
 }
 
-func (t *TrackedSliceDevice) Slice(ctx context.Context, off int64, length int64) ([]byte, error) {
-	if t.nilTracking.Load() {
-		t.dirtyMu.Lock()
-		t.dirty.Clear(uint(header.BlockIdx(off, t.blockSize)))
-		t.dirtyMu.Unlock()
+func (t *Tracker) Reset() {
+	t.mu.Lock()
+	defer t.mu.Unlock()
 
-		return t.empty, nil
-	}
+	t.b.ClearAll()
+}
+
+// BitSet returns a clone of the underlying bitset.
+func (t *Tracker) BitSet() *bitset.BitSet {
+	t.mu.RLock()
+	defer t.mu.RUnlock()
+
+	return t.b.Clone()
+}
 
-	return t.data.Slice(ctx, off, length)
+func (t *Tracker) BlockSize() int64 {
+	return t.blockSize
 }
 
-// Return which bytes were not read since Disable.
-// This effectively returns the bytes that have been requested after paused vm and are not dirty.
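
The Tracker above reduces byte offsets to block indices before touching the bitset. A minimal standalone sketch of that bookkeeping; blockIdx inlines what header.BlockIdx is assumed to compute, and the behavior mirrors TestTracker_MisalignedOffset below:

```go
package main

import (
	"fmt"

	"github.com/bits-and-blooms/bitset"
)

// blockIdx inlines the assumed equivalent of header.BlockIdx:
// the index of the block containing the byte offset.
func blockIdx(off, blockSize int64) uint {
	return uint(off / blockSize)
}

func main() {
	const blockSize = 4096
	b := bitset.New(0) // like Tracker: grows as bits are set

	// Marking any offset inside a block marks the whole block.
	b.Set(blockIdx(123, blockSize))

	fmt.Println(b.Test(blockIdx(1000, blockSize))) // true: same block
	fmt.Println(b.Test(blockIdx(4096, blockSize))) // false: next block
}
```
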
-func (t *TrackedSliceDevice) Dirty() *bitset.BitSet { - t.dirtyMu.Lock() - defer t.dirtyMu.Unlock() +func (t *Tracker) Clone() *Tracker { + return &Tracker{ + b: t.BitSet(), + blockSize: t.BlockSize(), + } +} + +func (t *Tracker) Offsets() iter.Seq[int64] { + return bitsetOffsets(t.BitSet(), t.BlockSize()) +} - return t.dirty.Clone() +func bitsetOffsets(b *bitset.BitSet, blockSize int64) iter.Seq[int64] { + return utils.TransformTo(b.EachSet(), func(idx uint) int64 { + return header.BlockOffset(int64(idx), blockSize) + }) } diff --git a/packages/orchestrator/internal/sandbox/block/tracker_test.go b/packages/orchestrator/internal/sandbox/block/tracker_test.go new file mode 100644 index 0000000000..03204df966 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/block/tracker_test.go @@ -0,0 +1,167 @@ +package block + +import ( + "maps" + "math/rand" + "slices" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTracker_AddAndHas(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offset := int64(pageSize * 4) + + // Initially should not be marked + assert.False(t, tr.Has(offset), "Expected offset %d not to be marked initially", offset) + + // After adding, should be marked + tr.Add(offset) + assert.True(t, tr.Has(offset), "Expected offset %d to be marked after Add", offset) + + // Other offsets should not be marked + otherOffsets := []int64{ + 0, pageSize, 2 * pageSize, 3 * pageSize, 5 * pageSize, 10 * pageSize, + } + for _, other := range otherOffsets { + if other == offset { + continue + } + assert.False(t, tr.Has(other), "Did not expect offset %d to be marked (only %d should be marked)", other, offset) + } +} + +func TestTracker_Reset(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offset := int64(pageSize * 4) + + // Add offset and verify it's marked + tr.Add(offset) + assert.True(t, tr.Has(offset), "Expected offset %d to be marked after Add", offset) + + // After reset, should not be marked + tr.Reset() + assert.False(t, tr.Has(offset), "Expected offset %d to be cleared after Reset", offset) + + // Offsets that were never set should also remain unset + otherOffsets := []int64{0, pageSize, 2 * pageSize, pageSize * 10} + for _, other := range otherOffsets { + assert.False(t, tr.Has(other), "Expected offset %d to not be marked after Reset", other) + } +} + +func TestTracker_MultipleOffsets(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offsets := []int64{0, pageSize, 2 * pageSize, 10 * pageSize} + + // Add multiple offsets + for _, o := range offsets { + tr.Add(o) + } + + // Verify all offsets are marked + for _, o := range offsets { + assert.True(t, tr.Has(o), "Expected offset %d to be marked", o) + } + + // Check offsets in between added offsets are not set + // (Offsets that aren't inside any marked block should not be marked) + nonSetOffsets := []int64{ + 3 * pageSize, + 4 * pageSize, + 5 * pageSize, + 6 * pageSize, + 7 * pageSize, + 8 * pageSize, + 9 * pageSize, + 11 * pageSize, + } + for _, off := range nonSetOffsets { + assert.False(t, tr.Has(off), "Expected offset %d to not be marked (only explicit blocks added)", off) + } +} + +func TestTracker_ResetClearsAll(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + offsets := []int64{0, pageSize, 2 * pageSize, 10 * pageSize} + + // Add multiple offsets + for _, o := range offsets { + tr.Add(o) + } + + // Reset should clear all + tr.Reset() + + // Verify all offsets are cleared + for _, o := range offsets { + 
assert.False(t, tr.Has(o), "Expected offset %d to be cleared after Reset", o) + } + // Check unrelated offsets also not marked + moreOffsets := []int64{3 * pageSize, 7 * pageSize, 100, 4095} + for _, o := range moreOffsets { + assert.False(t, tr.Has(o), "Expected offset %d to not be marked after Reset", o) + } +} + +func TestTracker_MisalignedOffset(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + // Test with misaligned offset + misalignedOffset := int64(123) + tr.Add(misalignedOffset) + + // Should be set for the block containing the offset—that is, block 0 (0..4095) + assert.True(t, tr.Has(misalignedOffset), "Expected misaligned offset %d to be marked (should mark its containing block)", misalignedOffset) + + // Now check that any offset in the same block is also considered marked + anotherOffsetInSameBlock := int64(1000) + assert.True(t, tr.Has(anotherOffsetInSameBlock), "Expected offset %d to be marked as in same block as %d", anotherOffsetInSameBlock, misalignedOffset) + + // But not for a different block + offsetInNextBlock := int64(pageSize) + assert.False(t, tr.Has(offsetInNextBlock), "Did not expect offset %d to be marked", offsetInNextBlock) + + // And not far outside any set block + offsetFar := int64(2 * pageSize) + assert.False(t, tr.Has(offsetFar), "Did not expect offset %d to be marked", offsetFar) +} + +func TestTracker_Offsets(t *testing.T) { + const pageSize = 4096 + tr := NewTracker(pageSize) + + numOffsets := 300 + + offsetsMap := map[int64]struct{}{} + + for range numOffsets { + select { + case <-t.Context().Done(): + t.FailNow() + default: + } + + base := int64(rand.Intn(121)) // 0..120 + offset := base * pageSize + + offsetsMap[offset] = struct{}{} + tr.Add(offset) + } + + expectedOffsets := slices.Collect(maps.Keys(offsetsMap)) + actualOffsets := slices.Collect(tr.Offsets()) + + assert.Len(t, actualOffsets, len(expectedOffsets)) + assert.ElementsMatch(t, expectedOffsets, actualOffsets) +} diff --git a/packages/orchestrator/internal/sandbox/build/diff.go b/packages/orchestrator/internal/sandbox/build/diff.go index e2e929150f..87f63a2194 100644 --- a/packages/orchestrator/internal/sandbox/build/diff.go +++ b/packages/orchestrator/internal/sandbox/build/diff.go @@ -62,3 +62,7 @@ func (n *NoDiff) CacheKey() DiffStoreKey { func (n *NoDiff) Init(context.Context) error { return NoDiffError{} } + +func (n *NoDiff) BlockSize() int64 { + return 0 +} diff --git a/packages/orchestrator/internal/sandbox/build/local_diff.go b/packages/orchestrator/internal/sandbox/build/local_diff.go index d5526ef338..c192e52778 100644 --- a/packages/orchestrator/internal/sandbox/build/local_diff.go +++ b/packages/orchestrator/internal/sandbox/build/local_diff.go @@ -137,3 +137,7 @@ func (b *localDiff) CacheKey() DiffStoreKey { func (b *localDiff) Init(context.Context) error { return nil } + +func (b *localDiff) BlockSize() int64 { + return b.blockSize +} diff --git a/packages/orchestrator/internal/sandbox/build/storage_diff.go b/packages/orchestrator/internal/sandbox/build/storage_diff.go index 5b6fe0e7a0..02723401d5 100644 --- a/packages/orchestrator/internal/sandbox/build/storage_diff.go +++ b/packages/orchestrator/internal/sandbox/build/storage_diff.go @@ -158,3 +158,7 @@ func (b *StorageDiff) FileSize() (int64, error) { return c.FileSize() } + +func (b *StorageDiff) BlockSize() int64 { + return b.blockSize +} diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index 8df45352cf..994e0708c4 100644 --- 
a/packages/orchestrator/internal/sandbox/sandbox.go
+++ b/packages/orchestrator/internal/sandbox/sandbox.go
@@ -672,7 +672,7 @@ func (s *Sandbox) Shutdown(ctx context.Context) error {
 		return fmt.Errorf("failed to pause VM: %w", err)
 	}
 
-	if err := s.memory.Disable(); err != nil {
+	if _, err := s.memory.Disable(ctx); err != nil {
 		return fmt.Errorf("failed to disable uffd: %w", err)
 	}
 
@@ -750,8 +750,12 @@
 		return nil, fmt.Errorf("failed to pause VM: %w", err)
 	}
 
-	if err := s.memory.Disable(); err != nil {
-		return nil, fmt.Errorf("failed to disable uffd: %w", err)
+	// This disables the uffd and returns the dirty pages.
+	// With the FC async io engine, there can be some further writes to the memory during the actual create snapshot process,
+	// but since we are still including even read pages as dirty, this should not introduce more bugs right now.
+	dirty, err := s.memory.Disable(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get dirty pages: %w", err)
 	}
 
 	// Snapfile is not closed as it's returned and cached for later use (like resume)
@@ -798,7 +802,7 @@
 			originalMemfile.Header(),
 			&MemoryDiffCreator{
 				memfile:    memfile,
-				dirtyPages: s.memory.Dirty(),
+				dirtyPages: dirty.BitSet(),
 				blockSize:  originalMemfile.BlockSize(),
 				doneHook: func(context.Context) error {
 					return memfile.Close()
@@ -979,7 +983,7 @@ func serveMemory(
 	ctx, span := tracer.Start(ctx, "serve-memory")
 	defer span.End()
 
-	fcUffd, err := uffd.New(memfile, socketPath, memfile.BlockSize())
+	fcUffd, err := uffd.New(memfile, socketPath)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create uffd: %w", err)
 	}
diff --git a/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go b/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go
index 688cff5331..f4b20ad5fe 100644
--- a/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go
+++ b/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go
@@ -7,6 +7,8 @@ import (
 	"sync"
 )
 
+var ErrFdExit = errors.New("fd exit signal")
+
 // FdExit is a wrapper around a pipe that allows to signal the exit of the uffd.
 type FdExit struct {
 	r *os.File
diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go
index 824a1e4adb..c10e473231 100644
--- a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go
+++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go
@@ -21,10 +21,10 @@ func NewMapping(regions []Region) *Mapping {
 }
 
 // GetOffset returns the relative offset and the page size of the mapped range for a given address.
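
Before the signature change below, a quick illustration of the translation GetOffset performs. This is a sketch compiled against the Region fields used in this diff, with expected values mirroring TestMapping_OverlappingRegions:

```go
package main

import (
	"fmt"

	"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory"
)

func main() {
	// One region: host addresses [0x1000, 0x3000) backed by the memfile at 0x5000.
	m := memory.NewMapping([]memory.Region{{
		BaseHostVirtAddr: 0x1000,
		Size:             0x2000,
		Offset:           0x5000,
		PageSize:         4096,
	}})

	// An address inside the region maps to Offset + (addr - BaseHostVirtAddr).
	off, pagesize, err := m.GetOffset(0x2500)
	fmt.Printf("%#x %d %v\n", off, pagesize, err) // 0x6500 4096 <nil>

	// Addresses outside every region return an error; the end is exclusive.
	_, _, err = m.GetOffset(0x3000)
	fmt.Println(err != nil) // true
}
```
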
-func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uint64, error) { +func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uintptr, error) { for _, r := range m.Regions { if hostVirtAddr >= r.BaseHostVirtAddr && hostVirtAddr < r.endHostVirtAddr() { - return r.shiftedOffset(hostVirtAddr), uint64(r.PageSize), nil + return r.shiftedOffset(hostVirtAddr), r.PageSize, nil } } diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go index 7b7d87e06f..83382a3162 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go @@ -30,7 +30,7 @@ func TestMapping_GetOffset(t *testing.T) { name string hostVirtAddr uintptr expectedOffset int64 - expectedSize uint64 + expectedSize uintptr expectError error }{ { @@ -134,6 +134,7 @@ func TestMapping_OverlappingRegions(t *testing.T) { PageSize: header.PageSize, }, } + mapping := NewMapping(regions) // The first matching region should be returned @@ -141,14 +142,14 @@ func TestMapping_OverlappingRegions(t *testing.T) { require.NoError(t, err) // Should get result from first region - require.Equal(t, int64(0x5000+(0x2500-0x1000)), offset) // 0x6500 - require.Equal(t, uint64(header.PageSize), size) + assert.Equal(t, int64(0x5000+(0x2500-0x1000)), offset) // 0x6500 + assert.Equal(t, uintptr(header.PageSize), size) // Also test that the underlying implementation prefers the first region if both regions contain the address offset2, size2, err2 := mapping.GetOffset(0x2000) require.NoError(t, err2) - require.Equal(t, int64(0x5000+(0x2000-0x1000)), offset2) // 0x6000 from first region - require.Equal(t, uint64(header.PageSize), size2) + assert.Equal(t, int64(0x5000+(0x2000-0x1000)), offset2) // 0x6000 from first region + assert.Equal(t, uintptr(header.PageSize), size2) } func TestMapping_BoundaryConditions(t *testing.T) { @@ -160,17 +161,18 @@ func TestMapping_BoundaryConditions(t *testing.T) { PageSize: header.PageSize, }, } + mapping := NewMapping(regions) // Test exact start boundary offset, _, err := mapping.GetOffset(0x1000) require.NoError(t, err) - require.Equal(t, int64(0x5000), offset) // 0x5000 + (0x1000 - 0x1000) + assert.Equal(t, int64(0x5000), offset) // 0x5000 + (0x1000 - 0x1000) // Test just before end boundary (exclusive) offset, _, err = mapping.GetOffset(0x2FFF) // 0x1000 + 0x2000 - 1 require.NoError(t, err) - require.Equal(t, int64(0x5000+(0x2FFF-0x1000)), offset) // 0x6FFF + assert.Equal(t, int64(0x5000+(0x2FFF-0x1000)), offset) // 0x6FFF // Test exact end boundary (should fail - exclusive) _, _, err = mapping.GetOffset(0x3000) // 0x1000 + 0x2000 @@ -195,8 +197,8 @@ func TestMapping_SingleLargeRegion(t *testing.T) { offset, size, err := mapping.GetOffset(0xABCDEF) require.NoError(t, err) - require.Equal(t, int64(0x100+0xABCDEF), offset) - require.Equal(t, uint64(header.PageSize), size) + assert.Equal(t, int64(0x100+0xABCDEF), offset) + assert.Equal(t, uintptr(header.PageSize), size) } func TestMapping_ZeroSizeRegion(t *testing.T) { @@ -208,7 +210,9 @@ func TestMapping_ZeroSizeRegion(t *testing.T) { PageSize: header.PageSize, }, } + mapping := NewMapping(regions) + _, _, err := mapping.GetOffset(0x2000) require.Error(t, err) } @@ -229,17 +233,18 @@ func TestMapping_MultipleRegionsSparse(t *testing.T) { }, } mapping := NewMapping(regions) + // Should succeed for start of first region offset, size, err := mapping.GetOffset(0x100) require.NoError(t, err) - 
require.Equal(t, int64(0x1000), offset)
-	require.Equal(t, uint64(header.PageSize), size)
+	assert.Equal(t, int64(0x1000), offset)
+	assert.Equal(t, uintptr(header.PageSize), size)
 
 	// Should succeed for start of second region
 	offset, size, err = mapping.GetOffset(0x10000)
 	require.NoError(t, err)
-	require.Equal(t, int64(0x2000), offset)
-	require.Equal(t, uint64(header.PageSize), size)
+	assert.Equal(t, int64(0x2000), offset)
+	assert.Equal(t, uintptr(header.PageSize), size)
 
 	// In gap
 	_, _, err = mapping.GetOffset(0x5000)
diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/region.go b/packages/orchestrator/internal/sandbox/uffd/memory/region.go
index b3deab2006..db1d4f8a3c 100644
--- a/packages/orchestrator/internal/sandbox/uffd/memory/region.go
+++ b/packages/orchestrator/internal/sandbox/uffd/memory/region.go
@@ -8,7 +8,7 @@ type Region struct {
 	Size   uintptr `json:"size"`
 	Offset uintptr `json:"offset"`
 	// This field is deprecated in the newer version of the Firecracker with a new field `page_size`.
-	PageSize uintptr `json:"page_size_kib"`
+	PageSize uintptr `json:"page_size_kib"` // This is actually in bytes in the deprecated version.
 }
 
 // endHostVirtAddr returns the end address of the region in host virtual address.
diff --git a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go
index 4c65f5d977..cf64c83620 100644
--- a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go
+++ b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go
@@ -3,14 +3,14 @@ package uffd
 import (
 	"context"
 
-	"github.com/bits-and-blooms/bitset"
-
+	"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block"
 	"github.com/e2b-dev/infra/packages/shared/pkg/utils"
 )
 
 type MemoryBackend interface {
-	Disable() error
-	Dirty() *bitset.BitSet
+	// Disable unregisters the uffd from the memory mapping and returns the dirty pages.
+	// It must be called after the FC pause has finished and before the FC snapshot is created.
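
To make the ordering contract above concrete, here is a sketch of a compliant caller. The vm interface is hypothetical; the real sequence lives in Sandbox.Pause earlier in this diff:

```go
package example

import (
	"context"

	"github.com/bits-and-blooms/bitset"

	"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd"
)

// vm is a hypothetical stand-in for the Firecracker API client.
type vm interface {
	Pause(ctx context.Context) error
	CreateSnapshot(ctx context.Context) error
}

// pauseAndCollectDirty follows the contract: FC pause first, then Disable,
// then snapshot, so no new faults arrive while the dirty set is read.
func pauseAndCollectDirty(ctx context.Context, fc vm, memory uffd.MemoryBackend) (*bitset.BitSet, error) {
	if err := fc.Pause(ctx); err != nil { // finishes in-flight page copies
		return nil, err
	}

	dirty, err := memory.Disable(ctx) // unregisters uffd, settles outstanding requests
	if err != nil {
		return nil, err
	}

	if err := fc.CreateSnapshot(ctx); err != nil { // no faulting of remaining missing pages
		return nil, err
	}

	return dirty.BitSet(), nil // feeds MemoryDiffCreator.dirtyPages in Sandbox.Pause
}
```
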
+ Disable(ctx context.Context) (*block.Tracker, error) Start(ctx context.Context, sandboxId string) error Stop() error diff --git a/packages/orchestrator/internal/sandbox/uffd/noop.go b/packages/orchestrator/internal/sandbox/uffd/noop.go index 1e76339ac9..c859c333e4 100644 --- a/packages/orchestrator/internal/sandbox/uffd/noop.go +++ b/packages/orchestrator/internal/sandbox/uffd/noop.go @@ -5,6 +5,7 @@ import ( "github.com/bits-and-blooms/bitset" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -13,7 +14,7 @@ type NoopMemory struct { size int64 blockSize int64 - dirty *bitset.BitSet + dirty *block.Tracker exit *utils.ErrorOnce } @@ -23,23 +24,19 @@ var _ MemoryBackend = (*NoopMemory)(nil) func NewNoopMemory(size, blockSize int64) *NoopMemory { blocks := header.TotalBlocks(size, blockSize) - dirty := bitset.New(uint(blocks)) - dirty.FlipRange(0, dirty.Len()) + b := bitset.New(uint(blocks)) + b.FlipRange(0, b.Len()) return &NoopMemory{ size: size, blockSize: blockSize, - dirty: dirty, + dirty: block.NewTrackerFromBitset(b, blockSize), exit: utils.NewErrorOnce(), } } -func (m *NoopMemory) Disable() error { - return nil -} - -func (m *NoopMemory) Dirty() *bitset.BitSet { - return m.dirty +func (m *NoopMemory) Disable(context.Context) (*block.Tracker, error) { + return m.dirty.Clone(), nil } func (m *NoopMemory) Start(context.Context, string) error { diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/memory_slicer.go b/packages/orchestrator/internal/sandbox/uffd/testutils/memory_slicer.go index 27d786199e..c3b1fee494 100644 --- a/packages/orchestrator/internal/sandbox/uffd/testutils/memory_slicer.go +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/memory_slicer.go @@ -33,3 +33,7 @@ func (s *MemorySlicer) Size() (int64, error) { func (s *MemorySlicer) Content() []byte { return s.content } + +func (s *MemorySlicer) BlockSize() int64 { + return s.pagesize +} diff --git a/packages/orchestrator/internal/sandbox/uffd/uffd.go b/packages/orchestrator/internal/sandbox/uffd/uffd.go index 707aac6ae5..b7b77e909e 100644 --- a/packages/orchestrator/internal/sandbox/uffd/uffd.go +++ b/packages/orchestrator/internal/sandbox/uffd/uffd.go @@ -10,7 +10,6 @@ import ( "syscall" "time" - "github.com/bits-and-blooms/bitset" "go.opentelemetry.io/otel" "go.uber.org/zap" @@ -36,18 +35,13 @@ type Uffd struct { fdExit *fdexit.FdExit lis *net.UnixListener socketPath string - memfile *block.TrackedSliceDevice + memfile block.ReadonlyDevice handler utils.SetOnce[*userfaultfd.Userfaultfd] } var _ MemoryBackend = (*Uffd)(nil) -func New(memfile block.ReadonlyDevice, socketPath string, blockSize int64) (*Uffd, error) { - trackedMemfile, err := block.NewTrackedSliceDevice(blockSize, memfile) - if err != nil { - return nil, fmt.Errorf("failed to create tracked slice device: %w", err) - } - +func New(memfile block.ReadonlyDevice, socketPath string) (*Uffd, error) { fdExit, err := fdexit.New() if err != nil { return nil, fmt.Errorf("failed to create fd exit: %w", err) @@ -58,7 +52,7 @@ func New(memfile block.ReadonlyDevice, socketPath string, blockSize int64) (*Uff readyCh: make(chan struct{}, 1), fdExit: fdExit, socketPath: socketPath, - memfile: trackedMemfile, + memfile: memfile, handler: *utils.NewSetOnce[*userfaultfd.Userfaultfd](), }, nil } @@ -144,7 +138,7 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { m := 
memory.NewMapping(regions)
 
 	uffd, err := userfaultfd.NewUserfaultfdFromFd(
-		uintptr(fds[0]),
+		userfaultfd.Fd(fds[0]),
 		u.memfile,
 		m,
 		logger.L().With(logger.WithSandboxID(sandboxId)),
@@ -168,6 +162,10 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error {
 		ctx,
 		u.fdExit,
 	)
+	if errors.Is(err, fdexit.ErrFdExit) {
+		return nil
+	}
+
 	if err != nil {
 		return fmt.Errorf("failed handling uffd: %w", err)
 	}
@@ -187,10 +185,34 @@ func (u *Uffd) Exit() *utils.ErrorOnce {
 	return u.exit
 }
 
-func (u *Uffd) Disable() error {
-	return u.memfile.Disable()
+// Disable unregisters the uffd from the memory mapping,
+// allowing us to create a "diff" snapshot via the FC API without dirty tracking enabled,
+// and without pagefaulting all remaining missing pages.
+//
+// After calling Disable(), this uffd is no longer usable; we won't be able to resume the sandbox via API.
+// The uffd itself is not closed though, as that should be done by the sandbox cleanup.
+func (u *Uffd) Disable(ctx context.Context) (*block.Tracker, error) {
+	uffd, err := u.handler.WaitWithContext(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get uffd: %w", err)
+	}
+
+	err = uffd.Unregister()
+	if err != nil {
+		return nil, fmt.Errorf("failed to unregister uffd: %w", err)
+	}
+
+	return u.dirty(ctx)
 }
 
-func (u *Uffd) Dirty() *bitset.BitSet {
-	return u.memfile.Dirty()
+// dirty waits for the current requests to finish and returns the dirty pages.
+//
+// It *MUST* only be called after the sandbox was successfully paused via API.
+func (u *Uffd) dirty(ctx context.Context) (*block.Tracker, error) {
+	uffd, err := u.handler.WaitWithContext(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get uffd: %w", err)
+	}
+
+	return uffd.Dirty(), nil
 }
diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go
index 01a745c9eb..ef184d126f 100644
--- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go
+++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go
@@ -23,8 +23,6 @@ import (
 	"fmt"
 	"syscall"
 	"unsafe"
-
-	"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
 )
 
 const (
@@ -34,12 +32,19 @@ const (
 	UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT
 
 	UFFDIO_REGISTER_MODE_MISSING = C.UFFDIO_REGISTER_MODE_MISSING
+	UFFDIO_REGISTER_MODE_WP      = C.UFFDIO_REGISTER_MODE_WP
+
+	UFFDIO_WRITEPROTECT_MODE_WP = C.UFFDIO_WRITEPROTECT_MODE_WP
+	UFFDIO_COPY_MODE_WP         = C.UFFDIO_COPY_MODE_WP
 
-	UFFDIO_API      = C.UFFDIO_API
-	UFFDIO_REGISTER = C.UFFDIO_REGISTER
-	UFFDIO_COPY     = C.UFFDIO_COPY
+	UFFDIO_API          = C.UFFDIO_API
+	UFFDIO_REGISTER     = C.UFFDIO_REGISTER
+	UFFDIO_UNREGISTER   = C.UFFDIO_UNREGISTER
+	UFFDIO_COPY         = C.UFFDIO_COPY
+	UFFDIO_WRITEPROTECT = C.UFFDIO_WRITEPROTECT
 
 	UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE
+	UFFD_PAGEFAULT_FLAG_WP    = C.UFFD_PAGEFAULT_FLAG_WP
 
 	UFFD_FEATURE_MISSING_HUGETLBFS = C.UFFD_FEATURE_MISSING_HUGETLBFS
 )
@@ -66,13 +71,24 @@ func newUffdioAPI(api, features CULong) UffdioAPI {
 	}
 }
 
+func newUffdioRange(start, length CULong) UffdioRange {
+	return UffdioRange{
+		start: start,
+		len:   length,
+	}
+}
+
 func newUffdioRegister(start, length, mode CULong) UffdioRegister {
 	return UffdioRegister{
-		_range: UffdioRange{
-			start: start,
-			len:   length,
-		},
-		mode: mode,
+		_range: newUffdioRange(start, length),
+		mode:   mode,
+	}
+}
+
+func newUffdioWriteProtect(start, length, mode CULong) UffdioWriteProtect {
+	return UffdioWriteProtect{
+		_range: newUffdioRange(start, length),
+		mode:   mode,
 	}
 }
 
@@ -98,47 +114,29
@@ func getPagefaultAddress(pagefault *UffdPagefault) uintptr { return uintptr(pagefault.address) } -// uffdFd is a helper type that wraps uffd fd. -type uffdFd uintptr +// Fd is a helper type that wraps uffd fd. +type Fd uintptr -// flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK -func newFd(flags uintptr) (uffdFd, error) { - uffd, _, errno := syscall.Syscall(NR_userfaultfd, flags, 0, 0) - if errno != 0 { - return 0, fmt.Errorf("userfaultfd syscall failed: %w", errno) - } - - return uffdFd(uffd), nil -} - -// features: UFFD_FEATURE_MISSING_HUGETLBFS -// This is already called by the FC -func (u uffdFd) configureApi(pagesize uint64) error { - var features CULong - - // Only set the hugepage feature if we're using hugepages - if pagesize == header.HugepageSize { - features |= UFFD_FEATURE_MISSING_HUGETLBFS - } +// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING +// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING +// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp +func (f Fd) register(addr uintptr, size uint64, mode CULong) error { + register := newUffdioRegister(CULong(addr), CULong(size), mode) - api := newUffdioAPI(UFFD_API, features) - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_API, uintptr(unsafe.Pointer(&api))) + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) if errno != 0 { - return fmt.Errorf("UFFDIO_API ioctl failed: %w (ret=%d)", errno, ret) + return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) } return nil } -// mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING -// This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING -// We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp -func (u uffdFd) register(addr uintptr, size uint64, mode CULong) error { - register := newUffdioRegister(CULong(addr), CULong(size), mode) +func (f Fd) unregister(addr, size uintptr) error { + r := newUffdioRange(CULong(addr), CULong(size)) - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_UNREGISTER, uintptr(unsafe.Pointer(&r))) if errno != 0 { - return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) + return fmt.Errorf("UFFDIO_UNREGISTER ioctl failed: %w (ret=%d)", errno, ret) } return nil @@ -146,21 +144,38 @@ func (u uffdFd) register(addr uintptr, size uint64, mode CULong) error { // mode: UFFDIO_COPY_MODE_WP // When we use both missing and wp, we need to use UFFDIO_COPY_MODE_WP, otherwise copying would unprotect the page -func (u uffdFd) copy(addr uintptr, data []byte, pagesize uint64, mode CULong) error { +func (f Fd) copy(addr, pagesize uintptr, data []byte, mode CULong) error { cpy := newUffdioCopy(data, CULong(addr)&^CULong(pagesize-1), CULong(pagesize), mode, 0) - if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy))); errno != 0 { + if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy))); errno != 0 { return errno } // Check if the copied size matches the requested pagesize - if uint64(cpy.copy) != pagesize { + if cpy.copy != CLong(pagesize) { return fmt.Errorf("UFFDIO_COPY copied %d bytes, expected %d", cpy.copy, pagesize) } return nil } -func (u uffdFd) close() error { - return 
syscall.Close(int(u)) +// mode: UFFDIO_WRITEPROTECT_MODE_WP +// Passing 0 as the mode will remove the write protection. +func (f Fd) writeProtect(addr, size uintptr, mode CULong) error { + register := newUffdioWriteProtect(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_WRITEPROTECT, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_WRITEPROTECT ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +func (f Fd) close() error { + return syscall.Close(int(f)) +} + +func (f Fd) fd() int32 { + return int32(f) } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go new file mode 100644 index 0000000000..367af9e2aa --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go @@ -0,0 +1,122 @@ +package userfaultfd + +import ( + "fmt" + "syscall" + "unsafe" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// mockFd is a mock implementation of the Fd interface. +// It allows us to test the handling methods separately from the actual uffd serve loop. +type mockFd struct { + // The channels send back the info about the uffd handled operations + // and also allows us to block the methods to test the flow. + copyCh chan *blockedEvent[UffdioCopy] + writeProtectCh chan *blockedEvent[UffdioWriteProtect] +} + +func newMockFd() *mockFd { + return &mockFd{ + copyCh: make(chan *blockedEvent[UffdioCopy]), + writeProtectCh: make(chan *blockedEvent[UffdioWriteProtect]), + } +} + +func (m *mockFd) register(_ uintptr, _ uint64, _ CULong) error { + return nil +} + +func (m *mockFd) unregister(_, _ uintptr) error { + return nil +} + +func (m *mockFd) copy(addr, pagesize uintptr, _ []byte, mode CULong) error { + // Don't use the uffdioCopy constructor as it unsafely checks slice address and fails for arbitrary pointer. + e := newBlockedEvent(UffdioCopy{ + src: 0, + dst: CULong(addr), + len: CULong(pagesize), + mode: mode, + copy: 0, + }) + + m.copyCh <- e + + <-e.resolved + + return nil +} + +func (m *mockFd) writeProtect(addr, size uintptr, mode CULong) error { + e := newBlockedEvent(UffdioWriteProtect{ + _range: newUffdioRange( + CULong(addr), + CULong(size), + ), + mode: mode, + }) + + m.writeProtectCh <- e + + <-e.resolved + + return nil +} + +func (m *mockFd) close() error { + return nil +} + +func (m *mockFd) fd() int32 { + return 0 +} + +// Used for testing. +// flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK +func newFd(flags uintptr) (Fd, error) { + uffd, _, errno := syscall.Syscall(NR_userfaultfd, flags, 0, 0) + if errno != 0 { + return 0, fmt.Errorf("userfaultfd syscall failed: %w", errno) + } + + return Fd(uffd), nil +} + +// Used for testing +// features: UFFD_FEATURE_MISSING_HUGETLBFS +// This is already called by the FC +func configureApi(f Fd, pagesize uint64) error { + var features CULong + + // Only set the hugepage feature if we're using hugepages + if pagesize == header.HugepageSize { + features |= UFFD_FEATURE_MISSING_HUGETLBFS + } + + api := newUffdioAPI(UFFD_API, features) + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_API, uintptr(unsafe.Pointer(&api))) + if errno != 0 { + return fmt.Errorf("UFFDIO_API ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +// This wrapped event allows us to simulate the finish of processing of events by FC on FC API Pause. 
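
As a usage sketch for the mock fd above and the blockedEvent wrapper defined just below, here is how a test might drive the blocking copy. This assumes it runs inside this test package, alongside the existing helpers:

```go
// Sketch: a test driving mockFd.copy. The call blocks until the test
// resolves the event, simulating FC still processing the page copy.
func TestMockFdCopySketch(t *testing.T) {
	m := newMockFd()

	done := make(chan error, 1)
	go func() {
		done <- m.copy(0x1000, 4096, make([]byte, 4096), 0)
	}()

	e := <-m.copyCh // observe the in-flight UffdioCopy
	if e.event.dst != CULong(0x1000) {
		t.Fatalf("unexpected dst: %#x", uint64(e.event.dst))
	}

	e.resolve() // unblock the copy, as the FC pause would

	if err := <-done; err != nil {
		t.Fatal(err)
	}
}
```
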
+type blockedEvent[T UffdioCopy | UffdioWriteProtect] struct {
+	event    T
+	resolved chan struct{}
+}
+
+func newBlockedEvent[T UffdioCopy | UffdioWriteProtect](event T) *blockedEvent[T] {
+	return &blockedEvent[T]{
+		event:    event,
+		resolved: make(chan struct{}),
+	}
+}
+
+func (e *blockedEvent[T]) resolve() {
+	close(e.resolved)
+}
diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go
index 18d82b7849..0bd2b74944 100644
--- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go
+++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go
@@ -1,5 +1,10 @@
 package userfaultfd
 
+// flowchart TD
+//   A[missing page] -- write (WRITE flag) --> B(COPY) --> C[dirty page]
+//   A -- read (MISSING flag) --> D(COPY + MODE_WP) --> E[faulted page]
+//   E -- write (WP+[WRITE] flag) --> F(remove MODE_WP) --> C
+
 import (
 	"context"
 	"errors"
@@ -18,15 +23,31 @@ import (
 	"github.com/e2b-dev/infra/packages/shared/pkg/logger"
 )
 
+const maxRequestsInProgress = 4096
+
 var ErrUnexpectedEventType = errors.New("unexpected event type")
 
+type uffdio interface {
+	unregister(addr, size uintptr) error
+	register(addr uintptr, size uint64, mode CULong) error
+	copy(addr, pagesize uintptr, data []byte, mode CULong) error
+	writeProtect(addr, size uintptr, mode CULong) error
+	close() error
+	fd() int32
+}
+
 type Userfaultfd struct {
-	fd  uffdFd
+	uffd uffdio
 	src block.Slicer
 	ma  *memory.Mapping
 
-	missingRequests sync.Map
+	// We don't skip the already mapped pages, because if the memory is swappable the page *might* under some conditions be swapped out.
+	// For hugepages this should not be a problem, but it might theoretically happen to normal pages when swap is enabled.
+	missingRequests *block.Tracker
+	writeRequests   *block.Tracker
+	// We use settleRequests to guard missingRequests so we can access a consistent state of the missingRequests after the requests are finished.
+	settleRequests sync.RWMutex
 
 	wg errgroup.Group
 
@@ -34,26 +55,57 @@ type Userfaultfd struct {
 }
 
 // NewUserfaultfdFromFd creates a new userfaultfd instance with optional configuration.
-func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logger logger.Logger) (*Userfaultfd, error) {
-	return &Userfaultfd{
-		fd:              uffdFd(fd),
+func NewUserfaultfdFromFd(uffd uffdio, src block.Slicer, m *memory.Mapping, logger logger.Logger) (*Userfaultfd, error) {
+	blockSize := src.BlockSize()
+
+	for _, region := range m.Regions {
+		if region.PageSize != uintptr(blockSize) {
+			return nil, fmt.Errorf("block size mismatch: %d != %d for region %d", region.PageSize, blockSize, region.BaseHostVirtAddr)
+		}
+
+		// Register the WP for the regions.
+		// The memory region is already registered (with missing pages in FC), but registering it again with a larger set of flags should merge these registration flags.
+ // - https://github.com/firecracker-microvm/firecracker/blob/f335a0adf46f0680a141eb1e76fe31ac258918c5/src/vmm/src/persist.rs#L477 + // - https://github.com/bytecodealliance/userfaultfd-rs/blob/main/src/builder.rs + err := uffd.register( + region.BaseHostVirtAddr, + uint64(region.Size), + UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING, + ) + if err != nil { + return nil, fmt.Errorf("failed to reregister memory region with write protection %d-%d: %w", region.Offset, region.Offset+region.Size, err) + } + } + + u := &Userfaultfd{ + uffd: uffd, src: src, - missingRequests: sync.Map{}, + missingRequests: block.NewTracker(blockSize), + writeRequests: block.NewTracker(blockSize), ma: m, logger: logger, - }, nil + } + + // By default this was unlimited. + // Now that we don't skip previously faulted pages we add at least some boundaries to the concurrency. + // Also, in some brief tests, adding a limit actually improved the handling at high concurrency. + u.wg.SetLimit(maxRequestsInProgress) + + return u, nil } func (u *Userfaultfd) Close() error { - return u.fd.close() + return u.uffd.close() } func (u *Userfaultfd) Serve( ctx context.Context, fdExit *fdexit.FdExit, ) error { + uffd := u.uffd.fd() + pollFds := []unix.PollFd{ - {Fd: int32(u.fd), Events: unix.POLLIN}, + {Fd: uffd, Events: unix.POLLIN}, {Fd: fdExit.Reader(), Events: unix.POLLIN}, } @@ -92,7 +144,7 @@ outerLoop: return fmt.Errorf("failed to handle uffd: %w", errMsg) } - return nil + return fdexit.ErrFdExit } uffdFd := pollFds[0] @@ -114,7 +166,7 @@ outerLoop: buf := make([]byte, unsafe.Sizeof(UffdMsg{})) for { - _, err := syscall.Read(int(u.fd), buf) + _, err := syscall.Read(int(uffd), buf) if err == syscall.EINTR { u.logger.Debug(ctx, "uffd: interrupted read, reading again") @@ -162,14 +214,20 @@ outerLoop: return fmt.Errorf("failed to map: %w", err) } + // Handle write to write protected page (WP flag) + // The documentation does not clearly mention if the WRITE flag must be present with the WP flag, even though we saw it being present in the events. + // - https://docs.kernel.org/admin-guide/mm/userfaultfd.html#write-protect-notifications + if flags&UFFD_PAGEFAULT_FLAG_WP != 0 { + u.handleWriteProtected(ctx, fdExit.SignalExit, addr, pagesize, offset) + + continue + } + // Handle write to missing page (WRITE flag) // If the event has WRITE flag, it was a write to a missing page. // For the write to be executed, we first need to copy the page from the source to the guest memory. if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 { - err := u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize) - if err != nil { - return fmt.Errorf("failed to handle missing write: %w", err) - } + u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset, true) continue } @@ -177,10 +235,7 @@ outerLoop: // Handle read to missing page ("MISSING" flag) // If the event has no flags, it was a read to a missing page and we need to copy the page from the source to the guest memory. 
if flags == 0 {
-			err := u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize)
-			if err != nil {
-				return fmt.Errorf("failed to handle missing: %w", err)
-			}
+			u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset, false)
 
 			continue
 		}
@@ -193,20 +248,27 @@ outerLoop:
 func (u *Userfaultfd) handleMissing(
 	ctx context.Context,
 	onFailure func() error,
-	addr uintptr,
+	addr,
+	pagesize uintptr,
 	offset int64,
-	pagesize uint64,
-) error {
-	if _, ok := u.missingRequests.Load(offset); ok {
-		return nil
-	}
-
-	u.missingRequests.Store(offset, struct{}{})
-
+	write bool,
+) {
 	u.wg.Go(func() error {
+		// The RLock must be called inside the goroutine to ensure RUnlock runs via defer,
+		// even if the errgroup is cancelled or the goroutine returns early.
+		// This check protects us against a race condition between marking the request as missing and accessing the missingRequests tracker.
+		// The Firecracker pause should return only after the requested memory is copied to the guest memory, so we don't need to guard the pagefault from the moment it is created.
+		u.settleRequests.RLock()
+		defer u.settleRequests.RUnlock()
+
 		defer func() {
 			if r := recover(); r != nil {
 				u.logger.Error(ctx, "UFFD serve panic", zap.Any("pagesize", pagesize), zap.Any("panic", r))
+
+				signalErr := onFailure()
+				if signalErr != nil {
+					u.logger.Error(ctx, "UFFD handle missing failure error", zap.Error(signalErr))
+				}
 			}
 		}()
 
@@ -223,7 +285,12 @@ func (u *Userfaultfd) handleMissing(
 
 		var copyMode CULong
 
-		copyErr := u.fd.copy(addr, b, pagesize, copyMode)
+		// If the event is not WRITE, we need to add WP to the page, so we can catch the next WRITE+WP and mark the page as dirty.
+		if !write {
+			copyMode |= UFFDIO_COPY_MODE_WP
+		}
+
+		copyErr := u.uffd.copy(addr, pagesize, b, copyMode)
 
 		if errors.Is(copyErr, unix.EEXIST) {
 			// Page is already mapped
@@ -237,11 +304,80 @@ func (u *Userfaultfd) handleMissing(
 
 			u.logger.Error(ctx, "UFFD serve uffdio copy error", zap.Error(joinedErr))
 
-			return fmt.Errorf("failed uffdio copy %w", joinedErr)
+			return fmt.Errorf("failed to copy page %d-%d %w", offset, offset+int64(pagesize), joinedErr)
+		}
+
+		// Add the offset to the missing requests tracker.
+		u.missingRequests.Add(offset)
+
+		if write {
+			// Add the offset to the write requests tracker.
+			u.writeRequests.Add(offset)
+		}
+
+		return nil
+	})
+}
+
+// Userfaultfd write-protect mode currently behaves differently on none PTEs (e.g. when the page is missing) across different types of memory (hugepage, file-backed, etc.).
+// - https://docs.kernel.org/admin-guide/mm/userfaultfd.html#write-protect-notifications - "there will be a userfaultfd write fault message generated when writing to a missing page"
+// This should not affect the handling we have in place as all events are being handled.
+func (u *Userfaultfd) handleWriteProtected(ctx context.Context, onFailure func() error, addr, pagesize uintptr, offset int64) {
+	u.wg.Go(func() error {
+		// The RLock must be called inside the goroutine to ensure RUnlock runs via defer,
+		// even if the errgroup is cancelled or the goroutine returns early.
+		// This check protects us against a race condition between marking the request as dirty and accessing the writeRequests tracker.
+		// The Firecracker pause should return only after the requested memory is copied to the guest memory, so we don't need to guard the pagefault from the moment it is created.
+		u.settleRequests.RLock()
+		defer u.settleRequests.RUnlock()
+
+		defer func() {
+			if r := recover(); r != nil {
+				u.logger.Error(ctx, "UFFD remove write protection panic", zap.Any("offset", offset), zap.Any("pagesize", pagesize), zap.Any("panic", r))
+
+				signalErr := onFailure()
+				if signalErr != nil {
+					u.logger.Error(ctx, "UFFD handle write protected failure error", zap.Error(signalErr))
+				}
+			}
+		}()
+
+		// Passing 0 as the mode removes the write protection.
+		wpErr := u.uffd.writeProtect(addr, pagesize, 0)
+		if wpErr != nil {
+			signalErr := onFailure()
+
+			joinedErr := errors.Join(wpErr, signalErr)
+
+			u.logger.Error(ctx, "UFFD serve write protect error", zap.Error(joinedErr))
+
+			return fmt.Errorf("failed to remove write protection from page %d-%d %w", offset, offset+int64(pagesize), joinedErr)
 		}
 
+		// Add the offset to the write requests tracker.
+		u.writeRequests.Add(offset)
+
 		return nil
 	})
+}
+
+func (u *Userfaultfd) Unregister() error {
+	for _, r := range u.ma.Regions {
+		if err := u.uffd.unregister(r.BaseHostVirtAddr, r.Size); err != nil {
+			return fmt.Errorf("failed to unregister: %w", err)
+		}
+	}
 
 	return nil
 }
+
+// Dirty returns the dirty pages.
+func (u *Userfaultfd) Dirty() *block.Tracker {
+	// At worst, this waits until the uffd is closed.
+	u.settleRequests.Lock()
+	// The locking here would work even without using defer (just lock-then-unlock the mutex), but let's hold the lock through the clone,
+	// so it is consistent even if there is another uffd call after.
+	defer u.settleRequests.Unlock()
+
+	return u.writeRequests.Clone()
+}
diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_helpers_test.go
similarity index 75%
rename from packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go
rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_helpers_test.go
index 1a65664b86..1f1b61ca72 100644
--- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go
+++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_helpers_test.go
@@ -9,6 +9,7 @@ import (
 
 	"github.com/bits-and-blooms/bitset"
 
+	"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block"
 	"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/testutils"
 )
 
@@ -41,8 +42,25 @@ type testHandler struct {
 	data *testutils.MemorySlicer
 	// Returns offsets of the pages that were faulted.
 	// It can only be called once.
-	offsetsOnce func() ([]uint, error)
-	mutex       sync.Mutex
+	// Sorted in ascending order.
+	accessedOffsetsOnce func() ([]uint, error)
+	// Returns offsets of the pages that were dirtied.
+	// It can only be called once.
+	// Sorted in ascending order.
+	dirtyOffsetsOnce func() ([]uint, error)
+
+	mutex sync.Mutex
+}
+
+func (h *testHandler) executeOperation(ctx context.Context, op operation) error {
+	switch op.mode {
+	case operationModeRead:
+		return h.executeRead(ctx, op)
+	case operationModeWrite:
+		return h.executeWrite(ctx, op)
+	default:
+		return fmt.Errorf("invalid operation mode: %d", op.mode)
+	}
 }
 
 func (h *testHandler) executeRead(ctx context.Context, op operation) error {
@@ -82,6 +100,7 @@ func (h *testHandler) executeWrite(ctx context.Context, op operation) error {
 }
 
 // Get a bitset of the offsets of the operations for the given mode.
+// Sorted in ascending order.
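
Before the next helper, a compact model of the dirty-page state machine that handleMissing and handleWriteProtected implement together. This is a sketch of the flowchart at the top of userfaultfd.go, not the implementation itself:

```go
package main

import "fmt"

type pageState int

const (
	missing pageState = iota // never faulted in
	faulted                  // copied in via COPY|UFFDIO_COPY_MODE_WP, still write-protected
	dirty                    // written to; lands in writeRequests
)

// next models one pagefault event for a page in state s.
func next(s pageState, write bool) pageState {
	switch {
	case s == missing && write:
		return dirty // plain COPY; tracked in missingRequests and writeRequests
	case s == missing:
		return faulted // COPY with MODE_WP; tracked in missingRequests
	case s == faulted && write:
		return dirty // WP fault: writeProtect(addr, size, 0); tracked in writeRequests
	default:
		return s // already-populated pages fault no more on reads
	}
}

func main() {
	s := next(missing, false) // read fault populates the page
	s = next(s, true)         // a later write triggers the WP fault
	fmt.Println(s == dirty)   // true
}
```
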
func getOperationsOffsets(ops []operation, m operationMode) []uint { b := bitset.New(0) @@ -93,3 +112,10 @@ func getOperationsOffsets(ops []operation, m operationMode) []uint { return slices.Collect(b.EachSet()) } + +func accessed(u *Userfaultfd) *block.Tracker { + u.settleRequests.Lock() + defer u.settleRequests.Unlock() + + return u.missingRequests.Clone() +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_test.go similarity index 73% rename from packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_test.go index 20c8ddeeb3..cf24e55d51 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_test.go @@ -62,6 +62,33 @@ func TestMissing(t *testing.T) { }, }, }, + { + name: "standard 4k page, reads, varying offsets", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 4 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + }, + }, { name: "hugepage, operation at start", pagesize: header.HugepageSize, @@ -110,6 +137,33 @@ func TestMissing(t *testing.T) { }, }, }, + { + name: "hugepage, reads, varying offsets", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 4 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, } for _, tt := range tests { @@ -120,15 +174,13 @@ func TestMissing(t *testing.T) { require.NoError(t, err) for _, operation := range tt.operations { - if operation.mode == operationModeRead { - err := h.executeRead(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } + err := h.executeOperation(t.Context(), operation) + assert.NoError(t, err, "for operation %+v", operation) //nolint:testifylint } expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") @@ -158,7 +210,7 @@ func TestParallelMissing(t *testing.T) { for range parallelOperations { verr.Go(func() error { - return h.executeRead(t.Context(), readOp) + return h.executeOperation(t.Context(), readOp) }) } @@ -167,7 +219,7 @@ func TestParallelMissing(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") @@ -191,14 +243,14 @@ func TestParallelMissingWithPrefault(t *testing.T) { mode: operationModeRead, } - err = 
h.executeRead(t.Context(), readOp)
+	err = h.executeOperation(t.Context(), readOp)
 	require.NoError(t, err)
 
 	var verr errgroup.Group
 
 	for range parallelOperations {
 		verr.Go(func() error {
-			return h.executeRead(t.Context(), readOp)
+			return h.executeOperation(t.Context(), readOp)
 		})
 	}
 
@@ -207,7 +259,7 @@
 
 	expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead)
 
-	accessedOffsets, err := h.offsetsOnce()
+	accessedOffsets, err := h.accessedOffsetsOnce()
 	require.NoError(t, err)
 
 	assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted")
@@ -232,13 +284,13 @@ func TestSerialMissing(t *testing.T) {
 	}
 
 	for range serialOperations {
-		err := h.executeRead(t.Context(), readOp)
+		err := h.executeOperation(t.Context(), readOp)
 		require.NoError(t, err)
 	}
 
 	expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead)
 
-	accessedOffsets, err := h.offsetsOnce()
+	accessedOffsets, err := h.accessedOffsetsOnce()
 	require.NoError(t, err)
 
 	assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted")
diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_write_test.go
similarity index 55%
rename from packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go
rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_write_test.go
index 52a73e5fa2..7505dcc6d8 100644
--- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go
+++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_write_test.go
@@ -10,6 +10,8 @@ import (
 	"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
 )
 
+// TODO: Investigate flakiness
+// TODO: It is possible that hugepages trigger the automatic WP
 func TestMissingWrite(t *testing.T) {
 	t.Parallel()
 
@@ -62,6 +64,68 @@ func TestMissingWrite(t *testing.T) {
 				},
 			},
 		},
+		{
+			name:          "standard 4k page, write after write",
+			pagesize:      header.PageSize,
+			numberOfPages: 32,
+			operations: []operation{
+				{
+					offset: 0,
+					mode:   operationModeWrite,
+				},
+				{
+					offset: 0,
+					mode:   operationModeRead,
+				},
+			},
+		},
+		{
+			name:          "standard 4k page, reads after writes, varying offsets",
+			pagesize:      header.PageSize,
+			numberOfPages: 32,
+			operations: []operation{
+				{
+					offset: 4 * header.PageSize,
+					mode:   operationModeWrite,
+				},
+				{
+					offset: 5 * header.PageSize,
+					mode:   operationModeWrite,
+				},
+				{
+					offset: 2 * header.PageSize,
+					mode:   operationModeWrite,
+				},
+				{
+					offset: 0 * header.PageSize,
+					mode:   operationModeWrite,
+				},
+				{
+					offset: 0 * header.PageSize,
+					mode:   operationModeWrite,
+				},
+				{
+					offset: 4 * header.PageSize,
+					mode:   operationModeRead,
+				},
+				{
+					offset: 5 * header.PageSize,
+					mode:   operationModeRead,
+				},
+				{
+					offset: 2 * header.PageSize,
+					mode:   operationModeRead,
+				},
+				{
+					offset: 0 * header.PageSize,
+					mode:   operationModeRead,
+				},
+				{
+					offset: 0 * header.PageSize,
+					mode:   operationModeRead,
+				},
+			},
+		},
 		{
 			name:          "hugepage, operation at start",
 			pagesize:      header.HugepageSize,
@@ -110,6 +174,68 @@
 				},
 			},
 		},
+		{
+			name:          "hugepage, write after write",
+			pagesize:      header.HugepageSize,
+			numberOfPages: 8,
+			operations: []operation{
+				{
+					offset: 0,
+					mode:   operationModeWrite,
+				},
+				{
+					offset: 0,
+					mode:   operationModeRead,
+				},
+			},
+		},
+		{
+			name:          "hugepage, reads after
writes, varying offsets", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 4 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 4 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, } for _, tt := range tests { @@ -120,23 +246,23 @@ func TestMissingWrite(t *testing.T) { require.NoError(t, err) for _, operation := range tt.operations { - if operation.mode == operationModeRead { - err := h.executeRead(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } - - if operation.mode == operationModeWrite { - err := h.executeWrite(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } + err := h.executeOperation(t.Context(), operation) + assert.NoError(t, err, "for operation %+v", operation) //nolint:testifylint } expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets(tt.operations, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") }) } } @@ -172,10 +298,17 @@ func TestParallelMissingWrite(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") } func TestParallelMissingWriteWithPrefault(t *testing.T) { @@ -212,10 +345,17 @@ func TestParallelMissingWriteWithPrefault(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") } func TestSerialMissingWrite(t *testing.T) { @@ -243,8 +383,15 @@ func TestSerialMissingWrite(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, 
operationModeRead|operationModeWrite)
 
-	accessedOffsets, err := h.offsetsOnce()
+	accessedOffsets, err := h.accessedOffsetsOnce()
 	require.NoError(t, err)
 
 	assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted")
+
+	expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite)
+
+	dirtyOffsets, err := h.dirtyOffsetsOnce()
+	require.NoError(t, err)
+
+	assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty")
 }
diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_process_helpers_test.go
similarity index 62%
rename from packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go
rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_process_helpers_test.go
index d61b903fd3..ac711d1e80 100644
--- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go
+++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_process_helpers_test.go
@@ -16,8 +16,6 @@ import (
 	"os/exec"
 	"os/signal"
 	"strconv"
-	"strings"
-	"sync"
 	"syscall"
 	"testing"
 
@@ -51,13 +49,15 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error
 		uffdFd.close()
 	})
 
-	err = uffdFd.configureApi(tt.pagesize)
+	err = configureApi(uffdFd, tt.pagesize)
 	require.NoError(t, err)
 
 	err = uffdFd.register(memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING)
 	require.NoError(t, err)
 
-	cmd := exec.CommandContext(t.Context(), os.Args[0], "-test.run=TestHelperServingProcess")
+	// We don't use t.Context() here, because we want to be able to kill the process manually and observe the correct exit code,
+	// while also handling the cleanup of the uffd. t.Context() seems to be cancelled before the test cleanup is started.
+	cmd := exec.CommandContext(context.Background(), os.Args[0], "-test.run=TestHelperServingProcess")
 	cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1")
 	cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_START=%d", memoryStart))
 	cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_PAGE_SIZE=%d", tt.pagesize))
@@ -82,11 +82,18 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error
 		assert.NoError(t, closeErr)
 	}()
 
-	offsetsReader, offsetsWriter, err := os.Pipe()
+	accessedOffsetsReader, accessedOffsetsWriter, err := os.Pipe()
 	require.NoError(t, err)
 
 	t.Cleanup(func() {
-		offsetsReader.Close()
+		accessedOffsetsReader.Close()
+	})
+
+	dirtyOffsetsReader, dirtyOffsetsWriter, err := os.Pipe()
+	require.NoError(t, err)
+
+	t.Cleanup(func() {
+		dirtyOffsetsReader.Close()
 	})
 
 	readyReader, readyWriter, err := os.Pipe()
@@ -107,8 +114,9 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error
 	cmd.ExtraFiles = []*os.File{
 		uffdFile,
 		contentReader,
-		offsetsWriter,
+		accessedOffsetsWriter,
 		readyWriter,
+		dirtyOffsetsWriter,
 	}
 	cmd.Stdout = os.Stdout
 	cmd.Stderr = os.Stderr
@@ -117,36 +125,59 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error
 	require.NoError(t, err)
 
 	contentReader.Close()
-	offsetsWriter.Close()
+	accessedOffsetsWriter.Close()
 	readyWriter.Close()
 	uffdFile.Close()
+	dirtyOffsetsWriter.Close()
+
+	go func() {
+		waitErr := cmd.Wait()
+		assert.NoError(t, waitErr)
+
+		assert.NotEqual(t, -1, cmd.ProcessState.ExitCode(), "process was not terminated gracefully")
+		assert.NotEqual(t, 2, cmd.ProcessState.ExitCode(), "fd exit prematurely terminated the serve loop")
+		assert.NotEqual(t, 1, cmd.ProcessState.ExitCode(), "process exited with unexpected exit code")
+
+		assert.Equal(t, 0, cmd.ProcessState.ExitCode())
+	}()
 
 	t.Cleanup(func() {
-		signalErr := cmd.Process.Signal(syscall.SIGUSR1)
+		// We send SIGTERM so the helper can exit gracefully and report a real exit code, not -1.
+ signalErr := cmd.Process.Signal(syscall.SIGTERM) assert.NoError(t, signalErr) - - waitErr := cmd.Wait() - // It can be either nil, an ExitError, a context.Canceled error, or "signal: killed" - assert.True(t, - (waitErr != nil && func(err error) bool { - var exitErr *exec.ExitError - - return errors.As(err, &exitErr) - }(waitErr)) || - errors.Is(waitErr, context.Canceled) || - (waitErr != nil && strings.Contains(waitErr.Error(), "signal: killed")) || - waitErr == nil, - "unexpected error: %v", waitErr, - ) }) - offsetsOnce := func() ([]uint, error) { + accessedOffsetsOnce := func() ([]uint, error) { err := cmd.Process.Signal(syscall.SIGUSR2) if err != nil { return nil, err } - offsetsBytes, err := io.ReadAll(offsetsReader) + offsetsBytes, err := io.ReadAll(accessedOffsetsReader) + if err != nil { + return nil, err + } + + var offsetList []uint + + if len(offsetsBytes)%8 != 0 { + return nil, fmt.Errorf("invalid offsets bytes length: %d", len(offsetsBytes)) + } + + for i := 0; i < len(offsetsBytes); i += 8 { + offsetList = append(offsetList, uint(binary.LittleEndian.Uint64(offsetsBytes[i:i+8]))) + } + + return offsetList, nil + } + + dirtyOffsetsOnce := func() ([]uint, error) { + err := cmd.Process.Signal(syscall.SIGUSR1) + if err != nil { + return nil, err + } + + offsetsBytes, err := io.ReadAll(dirtyOffsetsReader) if err != nil { return nil, err } @@ -171,22 +202,27 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error } return &testHandler{ - memoryArea: &memoryArea, - pagesize: tt.pagesize, - data: data, - offsetsOnce: offsetsOnce, + memoryArea: &memoryArea, + pagesize: tt.pagesize, + data: data, + accessedOffsetsOnce: accessedOffsetsOnce, + dirtyOffsetsOnce: dirtyOffsetsOnce, }, nil } -// Secondary process, orchestrator in our case +// Secondary process, orchestrator in our case. 
func TestHelperServingProcess(t *testing.T) {
 	if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" {
 		t.Skip("this is a helper process, skipping direct execution")
 	}
 
 	err := crossProcessServe()
+	if errors.Is(err, fdexit.ErrFdExit) {
+		os.Exit(2)
+	}
+
 	if err != nil {
-		fmt.Println("exit serving process", err)
+		fmt.Fprintf(os.Stderr, "error serving: %v", err)
 		os.Exit(1)
 	}
 
@@ -241,40 +277,61 @@ func crossProcessServe() error {
 		return fmt.Errorf("exit creating logger: %w", err)
 	}
 
-	uffd, err := NewUserfaultfdFromFd(uffdFd, data, m, l)
+	uffd, err := NewUserfaultfdFromFd(Fd(uffdFd), data, m, l)
 	if err != nil {
 		return fmt.Errorf("exit creating uffd: %w", err)
 	}
 
-	offsetsFile := os.NewFile(uintptr(5), "offsets")
+	accessedOffsetsFile := os.NewFile(uintptr(5), "accessed-offsets")
 
-	offsetsSignal := make(chan os.Signal, 1)
-	signal.Notify(offsetsSignal, syscall.SIGUSR2)
-	defer signal.Stop(offsetsSignal)
+	accessedOffsetsSignal := make(chan os.Signal, 1)
+	signal.Notify(accessedOffsetsSignal, syscall.SIGUSR2)
+	defer signal.Stop(accessedOffsetsSignal)
 
 	go func() {
-		defer offsetsFile.Close()
+		defer accessedOffsetsFile.Close()
 
 		for {
 			select {
 			case <-ctx.Done():
 				return
-			case <-offsetsSignal:
-				offsets, err := getAccessedOffsets(&uffd.missingRequests)
-				if err != nil {
-					msg := fmt.Errorf("error getting accessed offsets from cross process: %w", err)
+			case <-accessedOffsetsSignal:
+				for offset := range accessed(uffd).Offsets() {
+					writeErr := binary.Write(accessedOffsetsFile, binary.LittleEndian, uint64(offset))
+					if writeErr != nil {
+						msg := fmt.Errorf("error writing accessed offsets to file: %w", writeErr)
 
-					fmt.Fprint(os.Stderr, msg.Error())
+						fmt.Fprint(os.Stderr, msg.Error())
 
-					cancel(msg)
+						cancel(msg)
 
-					return
+						return
+					}
 				}
 
-				for _, offset := range offsets {
-					writeErr := binary.Write(offsetsFile, binary.LittleEndian, uint64(offset))
+				return
+			}
+		}
+	}()
+
+	dirtyOffsetsFile := os.NewFile(uintptr(7), "dirty-offsets")
+
+	dirtyOffsetsSignal := make(chan os.Signal, 1)
+	signal.Notify(dirtyOffsetsSignal, syscall.SIGUSR1)
+	defer signal.Stop(dirtyOffsetsSignal)
+
+	go func() {
+		defer dirtyOffsetsFile.Close()
+
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-dirtyOffsetsSignal:
+				for offset := range uffd.Dirty().Offsets() {
+					writeErr := binary.Write(dirtyOffsetsFile, binary.LittleEndian, uint64(offset))
 					if writeErr != nil {
-						msg := fmt.Errorf("error writing offsets to file: %w", writeErr)
+						msg := fmt.Errorf("error writing dirty offsets to file: %w", writeErr)
 
 						fmt.Fprint(os.Stderr, msg.Error())
 
@@ -301,6 +358,14 @@ func crossProcessServe() error {
 	}()
 
 	serverErr := uffd.Serve(ctx, fdExit)
+		if errors.Is(serverErr, fdexit.ErrFdExit) {
+			err := fmt.Errorf("serving finished via fd exit: %w", serverErr)
+
+			cancel(err)
+
+			return
+		}
+
 		if serverErr != nil {
 			msg := fmt.Errorf("error serving: %w", serverErr)
 
@@ -310,6 +375,8 @@ func crossProcessServe() error {
 
 			return
 		}
+
+		fmt.Fprint(os.Stderr, "serving finished")
 	}()
 
 	cleanup := func() {
@@ -330,7 +397,7 @@ func crossProcessServe() error {
 	defer cleanup()
 
 	exitSignal := make(chan os.Signal, 1)
-	signal.Notify(exitSignal, syscall.SIGUSR1)
+	signal.Notify(exitSignal, syscall.SIGTERM)
 	defer signal.Stop(exitSignal)
 
 	readyFile := os.NewFile(uintptr(6), "ready")
@@ -342,20 +409,8 @@ func crossProcessServe() error {
 
 	select {
 	case <-ctx.Done():
-		return fmt.Errorf("context done: %w: %w", ctx.Err(), context.Cause(ctx))
+		return context.Cause(ctx)
 	case <-exitSignal:
 		return nil
 	}
 }
-
-func getAccessedOffsets(missingRequests *sync.Map) ([]uint, error) {
-
var offsets []uint
-
-	missingRequests.Range(func(key, _ any) bool {
-		offsets = append(offsets, uint(key.(int64)))
-
-		return true
-	})
-
-	return offsets, nil
-}
diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go
new file mode 100644
index 0000000000..16333261d3
--- /dev/null
+++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go
@@ -0,0 +1,771 @@
+package userfaultfd
+
+import (
+	"context"
+	"fmt"
+	"maps"
+	"math/rand"
+	"slices"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block"
+	"github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory"
+	"github.com/e2b-dev/infra/packages/shared/pkg/logger"
+	"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
+	"github.com/e2b-dev/infra/packages/shared/pkg/utils"
+)
+
+func TestNoOperations(t *testing.T) {
+	t.Parallel()
+
+	pagesize := uint64(header.PageSize)
+	numberOfPages := uint64(512)
+
+	h, err := configureCrossProcessTest(t, testConfig{
+		pagesize:      pagesize,
+		numberOfPages: numberOfPages,
+	})
+	require.NoError(t, err)
+
+	logger, err := logger.NewDevelopmentLogger()
+	require.NoError(t, err)
+
+	// Placeholder uffd that does not serve anything
+	uffd, err := NewUserfaultfdFromFd(newMockFd(), h.data, &memory.Mapping{}, logger)
+	require.NoError(t, err)
+
+	accessedOffsets, err := h.accessedOffsetsOnce()
+	require.NoError(t, err)
+
+	assert.Empty(t, accessedOffsets, "checking which pages were faulted")
+
+	dirtyOffsets, err := h.dirtyOffsetsOnce()
+	require.NoError(t, err)
+
+	assert.Empty(t, dirtyOffsets, "checking which pages were dirty")
+
+	dirty := uffd.Dirty()
+	assert.Empty(t, slices.Collect(dirty.Offsets()), "checking dirty pages")
+}
+
+func TestRandomOperations(t *testing.T) {
+	t.Parallel()
+
+	pagesize := uint64(header.PageSize)
+	numberOfPages := uint64(4096)
+	numberOfOperations := 2048
+	repetitions := 8
+
+	for i := range repetitions {
+		t.Run(fmt.Sprintf("Run_%d_of_%d", i+1, repetitions), func(t *testing.T) {
+			t.Parallel()
+
+			// Use a time-based seed for each run to ensure different random sequences.
+			// This increases the chance of catching bugs that only manifest with specific sequences.
+			seed := time.Now().UnixNano() + int64(i)
+			rng := rand.New(rand.NewSource(seed))
+
+			t.Logf("Using random seed: %d", seed)
+
+			// Random operations on the data
+			operations := make([]operation, 0, numberOfOperations)
+			for range numberOfOperations {
+				operations = append(operations, operation{
+					offset: int64(rng.Intn(int(numberOfPages-1)) * int(pagesize)),
+					mode:   operationMode(rng.Intn(2) + 1),
+				})
+			}
+
+			h, err := configureCrossProcessTest(t, testConfig{
+				pagesize:      pagesize,
+				numberOfPages: numberOfPages,
+				operations:    operations,
+			})
+			require.NoError(t, err)
+
+			for _, operation := range operations {
+				err := h.executeOperation(t.Context(), operation)
+				require.NoError(t, err, "for operation %+v", operation)
+			}
+
+			expectedAccessedOffsets := getOperationsOffsets(operations, operationModeRead|operationModeWrite)
+
+			accessedOffsets, err := h.accessedOffsetsOnce()
+			require.NoError(t, err)
+
+			assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted (seed: %d)", seed)
+
+			expectedDirtyOffsets := getOperationsOffsets(operations, operationModeWrite)
+			dirtyOffsets, err := h.dirtyOffsetsOnce()
+
require.NoError(t, err)
+
+			assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty (seed: %d)", seed)
+		})
+	}
+}
+
+// A badly configured uffd panic recovery caused a silent close of the uffd: the first operation (or a batch of parallel first operations) was always handled fine, but any subsequent operation would freeze.
+// This behavior was flaky when run alongside the rest of the tests, because it was racy.
+func TestUffdNotClosingAfterOperation(t *testing.T) {
+	t.Parallel()
+
+	pagesize := uint64(header.PageSize)
+	numberOfPages := uint64(4096)
+
+	t.Run("missing write", func(t *testing.T) {
+		t.Parallel()
+
+		h, err := configureCrossProcessTest(t, testConfig{
+			pagesize:      pagesize,
+			numberOfPages: numberOfPages,
+		})
+		require.NoError(t, err)
+
+		write1 := operation{
+			offset: 0,
+			mode:   operationModeWrite,
+		}
+
+		// We need a different offset, because the kernel would cache the same page for the read.
+		write2 := operation{
+			offset: int64(2 * header.PageSize),
+			mode:   operationModeWrite,
+		}
+
+		err = h.executeOperation(t.Context(), write1)
+		require.NoError(t, err)
+
+		// Use the offset helpers to settle the requests.
+
+		accessedOffsets, err := h.accessedOffsetsOnce()
+		require.NoError(t, err)
+
+		assert.ElementsMatch(t, []uint{0}, accessedOffsets, "checking which pages were faulted")
+
+		err = h.executeOperation(t.Context(), write2)
+		require.NoError(t, err)
+
+		dirtyOffsets, err := h.dirtyOffsetsOnce()
+		require.NoError(t, err)
+
+		assert.ElementsMatch(t, []uint{uint(2 * header.PageSize), 0}, dirtyOffsets, "checking which pages were dirty")
+	})
+
+	t.Run("missing read", func(t *testing.T) {
+		t.Parallel()
+
+		h, err := configureCrossProcessTest(t, testConfig{
+			pagesize:      pagesize,
+			numberOfPages: numberOfPages,
+		})
+		require.NoError(t, err)
+
+		read1 := operation{
+			offset: 0,
+			mode:   operationModeRead,
+		}
+
+		// We need a different offset, because the kernel would cache the same page for the read.
+		read2 := operation{
+			offset: 2 * header.PageSize,
+			mode:   operationModeRead,
+		}
+
+		err = h.executeOperation(t.Context(), read1)
+		require.NoError(t, err)
+
+		// Use the offset helpers to settle the requests.
+
+		dirtyOffsets, err := h.dirtyOffsetsOnce()
+		require.NoError(t, err)
+
+		assert.Empty(t, dirtyOffsets, "checking which pages were dirty")
+
+		err = h.executeOperation(t.Context(), read2)
+		require.NoError(t, err)
+
+		accessedOffsets, err := h.accessedOffsetsOnce()
+		require.NoError(t, err)
+
+		assert.ElementsMatch(t, []uint{uint(2 * header.PageSize), 0}, accessedOffsets, "checking which pages were faulted")
+	})
+
+	t.Run("write protected", func(t *testing.T) {
+		t.Parallel()
+
+		h, err := configureCrossProcessTest(t, testConfig{
+			pagesize:      pagesize,
+			numberOfPages: numberOfPages,
+		})
+		require.NoError(t, err)
+
+		read1 := operation{
+			offset: 0,
+			mode:   operationModeRead,
+		}
+
+		read2 := operation{
+			offset: 1 * header.PageSize,
+			mode:   operationModeRead,
+		}
+
+		// We need at least 2 wp events to check the wp handler, so we need to write to 2 different pages.
+
+		write1 := operation{
+			offset: 0,
+			mode:   operationModeWrite,
+		}
+
+		write2 := operation{
+			offset: 1 * header.PageSize,
+			mode:   operationModeWrite,
+		}
+
+		err = h.executeOperation(t.Context(), read1)
+		require.NoError(t, err)
+
+		err = h.executeOperation(t.Context(), read2)
+		require.NoError(t, err)
+
+		// Use the offset helpers to settle the requests.
+ + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{0, 1 * header.PageSize}, accessedOffsets, "checking which pages were faulted") + + err = h.executeOperation(t.Context(), write1) + require.NoError(t, err) + + err = h.executeOperation(t.Context(), write2) + require.NoError(t, err) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{0, 1 * header.PageSize}, dirtyOffsets, "checking which pages were dirty") + }) +} + +func TestUffdEvents(t *testing.T) { + pagesize := uint64(header.PageSize) + numberOfPages := uint64(32) + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + logger, err := logger.NewDevelopmentLogger() + require.NoError(t, err) + + mockFd := newMockFd() + + // Placeholder uffd that does not serve anything + uffd, err := NewUserfaultfdFromFd(mockFd, h.data, &memory.Mapping{}, logger) + require.NoError(t, err) + + events := []event{ + // Same operation and offset, repeated (copies at 0), with both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: 0, + }, + + // WriteProtect at same offset, repeated + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, + offset: 0, + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, + offset: 0, + }, + + // Copy at next offset, include both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(header.PageSize), + }, + + // WriteProtect at next offset, repeated + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(header.PageSize), + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(header.PageSize), + }, + + // Copy at another offset, include both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: 
UFFDIO_COPY_MODE_WP, + }, + offset: int64(2 * header.PageSize), + }, + + // WriteProtect at another offset, repeated + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 2 * header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 2 * header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(2 * header.PageSize), + }, + } + + for _, event := range events { + err := event.trigger(t.Context(), uffd) + require.NoError(t, err, "for event %+v", event) + } + + receivedEvents := make([]event, 0, len(events)) + + for range events { + select { + case copyEvent, ok := <-mockFd.copyCh: + if !ok { + t.FailNow() + } + + copyEvent.resolve() + + // We don't add the offset here, because it is propagated only through the "accessed" and "dirty" sets. + // When later comparing the events, we will compare the events without the offset. + receivedEvents = append(receivedEvents, event{UffdioCopy: ©Event.event}) + case writeProtectEvent, ok := <-mockFd.writeProtectCh: + if !ok { + t.FailNow() + } + + writeProtectEvent.resolve() + + // We don't add the offset here, because it is propagated only through the "accessed" and "dirty" sets. + // When later comparing the events, we will compare the events without the offset. + receivedEvents = append(receivedEvents, event{UffdioWriteProtect: &writeProtectEvent.event}) + case <-t.Context().Done(): + t.FailNow() + } + } + + assert.Len(t, receivedEvents, len(events), "checking received events") + assert.ElementsMatch(t, zeroOffsets(events), receivedEvents, "checking received events") + + select { + case <-mockFd.copyCh: + t.Fatalf("copy channel should not have any events") + case <-mockFd.writeProtectCh: + t.Fatalf("write protect channel should not have any events") + case <-t.Context().Done(): + t.FailNow() + default: + } + + dirty := uffd.Dirty() + + expectedDirtyOffsets := make(map[int64]struct{}) + expectedAccessedOffsets := make(map[int64]struct{}) + + for _, event := range events { + if event.UffdioWriteProtect != nil { + expectedDirtyOffsets[event.offset] = struct{}{} + } + if event.UffdioCopy != nil { + if event.UffdioCopy.mode != UFFDIO_COPY_MODE_WP { + expectedDirtyOffsets[event.offset] = struct{}{} + } + + expectedAccessedOffsets[event.offset] = struct{}{} + } + } + + assert.ElementsMatch(t, slices.Collect(maps.Keys(expectedDirtyOffsets)), slices.Collect(dirty.Offsets()), "checking dirty pages") + + accessed := accessed(uffd) + assert.ElementsMatch(t, slices.Collect(maps.Keys(expectedAccessedOffsets)), slices.Collect(accessed.Offsets()), "checking accessed pages") +} + +func TestUffdSettleRequests(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(32) + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + logger, err := logger.NewDevelopmentLogger() + require.NoError(t, err) + + testEventsSettle := func(t *testing.T, events []event) { + t.Helper() + + mockFd := newMockFd() + + // Placeholder uffd that does not serve anything + uffd, err := NewUserfaultfdFromFd(mockFd, h.data, &memory.Mapping{}, logger) + require.NoError(t, err) + + for _, e := range events { + err = e.trigger(t.Context(), uffd) + require.NoError(t, err, "for event %+v", e) + } + + var blockedCopyEvents []*blockedEvent[UffdioCopy] + var blockedWriteProtectEvents []*blockedEvent[UffdioWriteProtect] + + for range 
events { + // Wait until the event is blocked + select { + case copyEvent, ok := <-mockFd.copyCh: + if !ok { + t.FailNow() + } + + require.NotNil(t, copyEvent.event, "copy event should not be nil") + assert.Contains(t, zeroOffsets(events), event{UffdioCopy: ©Event.event}, "checking copy event") + + blockedCopyEvents = append(blockedCopyEvents, copyEvent) + case writeProtectEvent, ok := <-mockFd.writeProtectCh: + if !ok { + t.FailNow() + } + + require.NotNil(t, writeProtectEvent.event, "write protect event should not be nil") + assert.Contains(t, zeroOffsets(events), event{UffdioWriteProtect: &writeProtectEvent.event}, "checking write protect event") + + blockedWriteProtectEvents = append(blockedWriteProtectEvents, writeProtectEvent) + case <-t.Context().Done(): + t.FailNow() + } + } + + require.Len(t, events, len(blockedCopyEvents)+len(blockedWriteProtectEvents), "checking blocked events") + + simulatedFCPause := make(chan struct{}) + + d := make(chan *block.Tracker) + + go func() { + acquired := uffd.settleRequests.TryLock() + assert.False(t, acquired, "settleRequests write lock should not be acquired") + + simulatedFCPause <- struct{}{} + + // This should block, until the events are resolved. + dirty := uffd.Dirty() + + select { + case d <- dirty: + case <-t.Context().Done(): + return + } + }() + + // This would be the place where the FC API Pause would return. + <-simulatedFCPause + + // Resolve the events to unblock getting the dirty pages in the goroutine. + for _, e := range blockedCopyEvents { + e.resolve() + } + + for _, e := range blockedWriteProtectEvents { + e.resolve() + } + + select { + case <-mockFd.copyCh: + t.Fatalf("copy channel should not have any events") + case <-mockFd.writeProtectCh: + t.Fatalf("write protect channel should not have any events") + case <-t.Context().Done(): + t.FailNow() + case dirty, ok := <-d: + if !ok { + t.FailNow() + } + + assert.ElementsMatch(t, dirtyOffsets(events), slices.Collect(dirty.Offsets()), "checking dirty pages") + } + } + + t.Run("missing", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("write protect", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("missing write", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("event mix", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + {UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 2 * int64(header.PageSize)}, + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 0, + }, + } + + testEventsSettle(t, events) + }) +} + +type event struct { + *UffdioCopy + *UffdioWriteProtect + + offset int64 +} + +func (e event) trigger(ctx context.Context, uffd 
*Userfaultfd) error { + switch { + case e.UffdioCopy != nil: + triggerMissing(ctx, uffd, *e.UffdioCopy, e.offset) + case e.UffdioWriteProtect != nil: + triggerWriteProtected(ctx, uffd, *e.UffdioWriteProtect, e.offset) + default: + return fmt.Errorf("invalid event: %+v", e) + } + + return nil +} + +// Return the event copy without the offset, because the offset is propagated only through the "accessed" and "dirty" sets, so direct comparisons would fail. +func (e event) withoutOffset() event { + return event{ + UffdioCopy: e.UffdioCopy, + UffdioWriteProtect: e.UffdioWriteProtect, + } +} + +// Creates a new slice of events with the offset set to 0, so we can compare the events without the offset. +func zeroOffsets(events []event) []event { + return utils.Map(events, func(e event) event { + return e.withoutOffset() + }) +} + +func triggerMissing(ctx context.Context, uffd *Userfaultfd, c UffdioCopy, offset int64) { + var write bool + + if c.mode != UFFDIO_COPY_MODE_WP { + write = true + } + + uffd.handleMissing( + ctx, + func() error { return nil }, + uintptr(c.dst), + uintptr(uffd.src.BlockSize()), + offset, + write, + ) +} + +func triggerWriteProtected(ctx context.Context, uffd *Userfaultfd, c UffdioWriteProtect, offset int64) { + uffd.handleWriteProtected( + ctx, + func() error { return nil }, + uintptr(c._range.start), + uintptr(uffd.src.BlockSize()), + offset, + ) +} + +func dirtyOffsets(events []event) []int64 { + offsets := make(map[int64]struct{}) + + for _, e := range events { + if e.UffdioWriteProtect != nil { + offsets[e.offset] = struct{}{} + } + + if e.UffdioCopy != nil { + if e.UffdioCopy.mode != UFFDIO_COPY_MODE_WP { + offsets[e.offset] = struct{}{} + } + } + } + + return slices.Collect(maps.Keys(offsets)) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_write_protection_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_write_protection_test.go new file mode 100644 index 0000000000..5b987fd95c --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_write_protection_test.go @@ -0,0 +1,320 @@ +package userfaultfd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func TestWriteProtection(t *testing.T) { + t.Parallel() + + tests := []testConfig{ + { + name: "standard 4k page, operation at start", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, operation at middle", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 15 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, operation at last page", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 31 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 31 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, writes after reads, varying offsets", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 4 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.PageSize, + 
mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 4 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 5 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 2 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, operation at start", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, operation at middle", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 3 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, operation at last page", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 7 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 7 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, writes after reads, varying offsets", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 4 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 4 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, tt) + require.NoError(t, err) + + for _, operation := range tt.operations { + err := h.executeOperation(t.Context(), operation) + assert.NoError(t, err, "for operation %+v", operation) //nolint:testifylint + } + + expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets(tt.operations, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") + }) + } +} + +func TestParallelWriteProtection(t *testing.T) { + t.Parallel() + + parallelOperations := 1_000_000 + + tt := testConfig{ + pagesize: header.PageSize, + numberOfPages: 2, + } + + h, err := configureCrossProcessTest(t, tt) + require.NoError(t, err) + + readOp := operation{ + offset: 0, + mode: operationModeRead, + } + + err = h.executeOperation(t.Context(), readOp) + require.NoError(t, err) + + writeOp := operation{ + offset: 0, + mode: operationModeWrite, + } + + var verr errgroup.Group + + 
for range parallelOperations { + verr.Go(func() error { + return h.executeWrite(t.Context(), writeOp) + }) + } + + err = verr.Wait() + require.NoError(t, err) + + expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") +} + +func TestSerialWriteProtection(t *testing.T) { + t.Parallel() + + serialOperations := 10_000 + + tt := testConfig{ + pagesize: header.PageSize, + numberOfPages: 2, + } + + h, err := configureCrossProcessTest(t, tt) + require.NoError(t, err) + + writeOp := operation{ + offset: 0, + mode: operationModeWrite, + } + + for range serialOperations { + err := h.executeWrite(t.Context(), writeOp) + require.NoError(t, err) + } + + expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") +} diff --git a/packages/shared/pkg/utils/iter.go b/packages/shared/pkg/utils/iter.go new file mode 100644 index 0000000000..d7197536c0 --- /dev/null +++ b/packages/shared/pkg/utils/iter.go @@ -0,0 +1,13 @@ +package utils + +import "iter" + +func TransformTo[S, T any](iterator iter.Seq[S], f func(S) T) iter.Seq[T] { + return func(yield func(T) bool) { + for v := range iterator { + if !yield(f(v)) { + break + } + } + } +}
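
A usage sketch for the new TransformTo helper, since the diff only shows its definition: it lazily maps one iter.Seq into another, which is the shape of transformation used to turn block indices into byte offsets. Everything below is illustrative only (the indices and the 4096-byte block size are made-up values), not code from this change.

package main

import (
	"fmt"
	"iter"
)

// TransformTo mirrors the helper added in packages/shared/pkg/utils/iter.go:
// it wraps a source sequence and applies f to each element on demand.
func TransformTo[S, T any](iterator iter.Seq[S], f func(S) T) iter.Seq[T] {
	return func(yield func(T) bool) {
		for v := range iterator {
			if !yield(f(v)) {
				break
			}
		}
	}
}

func main() {
	// A toy source sequence of block indices (made-up values).
	indices := iter.Seq[uint](func(yield func(uint) bool) {
		for _, idx := range []uint{0, 2, 5} {
			if !yield(idx) {
				return
			}
		}
	})

	const blockSize = int64(4096) // hypothetical block size

	// Lazily map each block index to its byte offset, the same shape of
	// transformation the diff uses when exposing tracker offsets.
	offsets := TransformTo(indices, func(idx uint) int64 {
		return int64(idx) * blockSize
	})

	for off := range offsets {
		fmt.Println(off) // prints 0, 8192, 20480
	}
}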
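The offsets exchanged between the test processes travel over pipes as a flat stream of 8-byte little-endian words; the writer closes its end so the reader's io.ReadAll terminates, which is why the length-divisible-by-8 check exists. A self-contained sketch of that framing, using an in-memory buffer in place of os.Pipe and made-up offsets:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

func main() {
	// Writer side: emit each offset as an 8-byte little-endian word,
	// the same framing the serving process uses on the offsets pipes.
	var buf bytes.Buffer
	for _, off := range []uint64{0, 8192, 20480} { // made-up offsets
		if err := binary.Write(&buf, binary.LittleEndian, off); err != nil {
			panic(err)
		}
	}

	// Reader side: read until EOF, then decode in 8-byte steps,
	// mirroring accessedOffsetsOnce/dirtyOffsetsOnce in the helpers.
	raw, err := io.ReadAll(&buf)
	if err != nil {
		panic(err)
	}
	if len(raw)%8 != 0 {
		panic(fmt.Sprintf("invalid offsets bytes length: %d", len(raw)))
	}

	var offsets []uint
	for i := 0; i < len(raw); i += 8 {
		offsets = append(offsets, uint(binary.LittleEndian.Uint64(raw[i:i+8])))
	}

	fmt.Println(offsets) // [0 8192 20480]
}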
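For reviewers unfamiliar with it, TestHelperServingProcess uses the common Go re-exec pattern: the test binary launches itself with -test.run pointed at the helper test, gated by an environment variable so normal runs skip it. A stripped-down sketch of only that pattern (no uffd, no pipes; names are illustrative):

package helper

import (
	"os"
	"os/exec"
	"testing"
)

// TestHelperChild only does real work when re-executed with the env guard set.
func TestHelperChild(t *testing.T) {
	if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" {
		t.Skip("this is a helper process, skipping direct execution")
	}
	// A real helper would serve requests here; exiting ends the child binary.
	os.Exit(0)
}

func TestSpawnHelper(t *testing.T) {
	// Re-exec the current test binary, selecting only the helper test.
	cmd := exec.Command(os.Args[0], "-test.run=TestHelperChild")
	cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1")

	out, err := cmd.CombinedOutput()
	if err != nil {
		t.Fatalf("helper process failed: %v\n%s", err, out)
	}
}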