diff --git a/.github/actions/start-services/action.yml b/.github/actions/start-services/action.yml index e2c01cd3b6..5365085344 100644 --- a/.github/actions/start-services/action.yml +++ b/.github/actions/start-services/action.yml @@ -101,7 +101,6 @@ runs: STORAGE_PROVIDER: "Local" ENVIRONMENT: "local" OTEL_COLLECTOR_GRPC_ENDPOINT: "localhost:4317" - MAX_PARALLEL_MEMFILE_SNAPSHOTTING: "2" SHARED_CHUNK_CACHE_PATH: "./.e2b-chunk-cache" EDGE_TOKEN: "abdcdefghijklmnop" run: | diff --git a/packages/api/internal/cfg/model.go b/packages/api/internal/cfg/model.go index e5f1b5dfb3..5d72ef706c 100644 --- a/packages/api/internal/cfg/model.go +++ b/packages/api/internal/cfg/model.go @@ -6,7 +6,7 @@ const ( DefaultKernelVersion = "vmlinux-6.1.158" // The Firecracker version the last tag + the short SHA (so we can build our dev previews) // TODO: The short tag here has only 7 characters — the one from our build pipeline will likely have exactly 8 so this will break. - DefaultFirecrackerVersion = "v1.12.1_d990331" + DefaultFirecrackerVersion = "v1.12.1_g8769de7cf" ) type Config struct { diff --git a/packages/orchestrator/Makefile b/packages/orchestrator/Makefile index b620e92026..b1fbc4107d 100644 --- a/packages/orchestrator/Makefile +++ b/packages/orchestrator/Makefile @@ -47,7 +47,6 @@ run-debug: GCP_DOCKER_REPOSITORY_NAME=$(GCP_DOCKER_REPOSITORY_NAME) \ GOOGLE_SERVICE_ACCOUNT_BASE64=$(GOOGLE_SERVICE_ACCOUNT_BASE64) \ OTEL_COLLECTOR_GRPC_ENDPOINT=$(OTEL_COLLECTOR_GRPC_ENDPOINT) \ - MAX_PARALLEL_MEMFILE_SNAPSHOTTING=$(MAX_PARALLEL_MEMFILE_SNAPSHOTTING) \ ./bin/orchestrator define setup_local_env diff --git a/packages/orchestrator/internal/sandbox/diffcreator.go b/packages/orchestrator/internal/sandbox/diffcreator.go index 161ac1ac31..14ddb49278 100644 --- a/packages/orchestrator/internal/sandbox/diffcreator.go +++ b/packages/orchestrator/internal/sandbox/diffcreator.go @@ -3,14 +3,11 @@ package sandbox import ( "context" "errors" - "fmt" "io" - "os" 
"github.com/bits-and-blooms/bitset" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/rootfs" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) @@ -28,7 +25,7 @@ func (r *RootfsDiffCreator) process(ctx context.Context, out io.Writer) (*header } type MemoryDiffCreator struct { - memfile *storage.TemporaryMemfile + memory io.ReaderAt dirtyPages *bitset.BitSet blockSize int64 doneHook func(context.Context) error @@ -36,21 +33,17 @@ type MemoryDiffCreator struct { func (r *MemoryDiffCreator) process(ctx context.Context, out io.Writer) (h *header.DiffMetadata, e error) { defer func() { - err := r.doneHook(ctx) - if err != nil { - e = errors.Join(e, err) + if r.doneHook != nil { + err := r.doneHook(ctx) + if err != nil { + e = errors.Join(e, err) + } } }() - memfileSource, err := os.Open(r.memfile.Path()) - if err != nil { - return nil, fmt.Errorf("failed to open memfile: %w", err) - } - defer memfileSource.Close() - return header.WriteDiffWithTrace( ctx, - memfileSource, + r.memory, r.blockSize, r.dirtyPages, out, diff --git a/packages/orchestrator/internal/sandbox/fc/client.go b/packages/orchestrator/internal/sandbox/fc/client.go index 33449b5159..566fa56c39 100644 --- a/packages/orchestrator/internal/sandbox/fc/client.go +++ b/packages/orchestrator/internal/sandbox/fc/client.go @@ -126,13 +126,11 @@ func (c *apiClient) pauseVM(ctx context.Context) error { func (c *apiClient) createSnapshot( ctx context.Context, snapfilePath string, - memfilePath string, ) error { snapshotConfig := operations.CreateSnapshotParams{ Context: ctx, Body: &models.SnapshotCreateParams{ SnapshotType: models.SnapshotCreateParamsSnapshotTypeFull, - MemFilePath: &memfilePath, SnapshotPath: &snapfilePath, }, } @@ -301,3 +299,17 @@ func (c *apiClient) startVM(ctx context.Context) error { return nil } + +// instanceInfo retrieves general information about an instance from the Firecracker API. 
+func (c *apiClient) instanceInfo(ctx context.Context) (*models.InstanceInfo, error) { + req := operations.DescribeInstanceParams{ + Context: ctx, + } + + resp, err := c.client.Operations.DescribeInstance(&req) + if err != nil { + return nil, fmt.Errorf("error retrieving vm info: %w", err) + } + + return resp.Payload, nil +} diff --git a/packages/orchestrator/internal/sandbox/fc/memory.go b/packages/orchestrator/internal/sandbox/fc/memory.go new file mode 100644 index 0000000000..170d98390f --- /dev/null +++ b/packages/orchestrator/internal/sandbox/fc/memory.go @@ -0,0 +1,32 @@ +package fc + +import ( + "context" + "fmt" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" +) + +func (p *Process) Memory(ctx context.Context) (*memory.View, error) { + pid, err := p.Pid() + if err != nil { + return nil, fmt.Errorf("failed to get process pid: %w", err) + } + + info, err := p.client.instanceInfo(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get instance info: %w", err) + } + + mapping, err := memory.NewMappingFromFCInfo(info.MemoryRegions) + if err != nil { + return nil, fmt.Errorf("failed to create memory mapping: %w", err) + } + + view, err := memory.NewView(pid, mapping) + if err != nil { + return nil, fmt.Errorf("failed to create memory view: %w", err) + } + + return view, nil +} diff --git a/packages/orchestrator/internal/sandbox/fc/process.go b/packages/orchestrator/internal/sandbox/fc/process.go index 88e919fb83..5f9d66b511 100644 --- a/packages/orchestrator/internal/sandbox/fc/process.go +++ b/packages/orchestrator/internal/sandbox/fc/process.go @@ -487,9 +487,9 @@ func (p *Process) Pause(ctx context.Context) error { } // CreateSnapshot VM needs to be paused before creating a snapshot. 
-func (p *Process) CreateSnapshot(ctx context.Context, snapfilePath string, memfilePath string) error { +func (p *Process) CreateSnapshot(ctx context.Context, snapfilePath string) error { ctx, childSpan := tracer.Start(ctx, "create-snapshot-fc") defer childSpan.End() - return p.client.createSnapshot(ctx, snapfilePath, memfilePath) + return p.client.createSnapshot(ctx, snapfilePath) } diff --git a/packages/orchestrator/internal/sandbox/sandbox.go b/packages/orchestrator/internal/sandbox/sandbox.go index ea30e3161a..bcdec6c2ff 100644 --- a/packages/orchestrator/internal/sandbox/sandbox.go +++ b/packages/orchestrator/internal/sandbox/sandbox.go @@ -673,10 +673,6 @@ func (s *Sandbox) Shutdown(ctx context.Context) error { return fmt.Errorf("failed to pause VM: %w", err) } - if _, err := s.memory.Disable(ctx); err != nil { - return fmt.Errorf("failed to disable uffd: %w", err) - } - // This is required because the FC API doesn't support passing /dev/null tf, err := storage.TemplateFiles{ BuildID: uuid.New().String(), @@ -688,32 +684,17 @@ func (s *Sandbox) Shutdown(ctx context.Context) error { } defer tf.Close() - // The snapfile is required only because the FC API doesn't support passing /dev/null snapfile := template.NewLocalFileLink(tf.CacheSnapfilePath()) defer snapfile.Close() - // The memfile is required only because the FC API doesn't support passing /dev/null - memfile, err := storage.AcquireTmpMemfile(ctx, s.config, tf.BuildID) - if err != nil { - return fmt.Errorf("failed to acquire memfile snapshot: %w", err) - } - defer memfile.Close() - err = s.process.CreateSnapshot( ctx, snapfile.Path(), - memfile.Path(), ) if err != nil { return fmt.Errorf("error creating snapshot: %w", err) } - // Close the memfile right after the snapshot to release the lock. - err = memfile.Close() - if err != nil { - return fmt.Errorf("error closing memfile: %w", err) - } - // This should properly flush rootfs to the underlying device. 
err = s.Close(ctx) if err != nil { @@ -753,41 +734,16 @@ func (s *Sandbox) Pause( // Stop the health check before pausing the VM s.Checks.Stop() - if err := s.process.Pause(ctx); err != nil { - return nil, fmt.Errorf("failed to pause VM: %w", err) - } - - // This disables the uffd and returns the dirty pages. - // With FC async io engine, there can be some further writes to the memory during the actual create snapshot process, - // but as we are still including even read pages as dirty so this should not introduce more bugs right now. - dirty, err := s.memory.Disable(ctx) + err = s.process.Pause(ctx) if err != nil { - return nil, fmt.Errorf("failed to get dirty pages: %w", err) + return nil, fmt.Errorf("failed to pause VM: %w", err) } // Snapfile is not closed as it's returned and cached for later use (like resume) snapfile := template.NewLocalFileLink(snapshotTemplateFiles.CacheSnapfilePath()) cleanup.AddNoContext(ctx, snapfile.Close) - // Memfile is also closed on diff creation processing - /* The process of snapshotting memory is as follows: - 1. Pause FC via API - 2. Snapshot FC via API—memory dump to “file on disk” that is actually tmpfs, because it is too slow - 3. Create the diff - copy the diff pages from tmpfs to normal disk file - 4. Delete tmpfs file - 5. 
Unlock so another snapshot can use tmpfs space - */ - memfile, err := storage.AcquireTmpMemfile(ctx, s.config, buildID.String()) - if err != nil { - return nil, fmt.Errorf("failed to acquire memfile snapshot: %w", err) - } - // Close the file even if an error occurs - defer memfile.Close() - err = s.process.CreateSnapshot( - ctx, - snapfile.Path(), - memfile.Path(), - ) + err = s.process.CreateSnapshot(ctx, snapfile.Path()) if err != nil { return nil, fmt.Errorf("error creating snapshot: %w", err) } @@ -797,23 +753,32 @@ func (s *Sandbox) Pause( if err != nil { return nil, fmt.Errorf("failed to get original memfile: %w", err) } + originalRootfs, err := s.Template.Rootfs() if err != nil { return nil, fmt.Errorf("failed to get original rootfs: %w", err) } + dirty, err := s.memory.Dirty(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get dirty pages: %w", err) + } + + memoryView, err := s.process.Memory(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get memory view: %w", err) + } + defer memoryView.Close() + // Start POSTPROCESSING memfileDiff, memfileDiffHeader, err := pauseProcessMemory( ctx, buildID, originalMemfile.Header(), &MemoryDiffCreator{ - memfile: memfile, + memory: memoryView, dirtyPages: dirty.BitSet(), blockSize: originalMemfile.BlockSize(), - doneHook: func(context.Context) error { - return memfile.Close() - }, }, s.config.DefaultCacheDir, ) diff --git a/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go b/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go index 688cff5331..f4b20ad5fe 100644 --- a/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go +++ b/packages/orchestrator/internal/sandbox/uffd/fdexit/fdexit.go @@ -7,6 +7,8 @@ import ( "sync" ) +var ErrFdExit = errors.New("fd exit signal") + // FdExit is a wrapper around a pipe that allows to signal the exit of the uffd. 
type FdExit struct { r *os.File diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go index 824a1e4adb..b487edc7ac 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping.go @@ -2,6 +2,8 @@ package memory import ( "fmt" + + "github.com/e2b-dev/infra/packages/shared/pkg/fc/models" ) type AddressNotFoundError struct { @@ -9,7 +11,15 @@ type AddressNotFoundError struct { } func (e AddressNotFoundError) Error() string { - return fmt.Sprintf("address %d not found in any mapping", e.hostVirtAddr) + return fmt.Sprintf("host virtual address %d not found in any mapping", e.hostVirtAddr) +} + +type OffsetNotFoundError struct { + offset int64 +} + +func (e OffsetNotFoundError) Error() string { + return fmt.Sprintf("offset %d not found in any mapping", e.offset) } type Mapping struct { @@ -20,13 +30,43 @@ func NewMapping(regions []Region) *Mapping { return &Mapping{Regions: regions} } -// GetOffset returns the relative offset and the page size of the mapped range for a given address. -func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uint64, error) { +func NewMappingFromFCInfo(regions []*models.GuestMemoryRegionMapping) (*Mapping, error) { + r := make([]Region, len(regions)) + + for i, infoRegion := range regions { + if infoRegion.BaseHostVirtAddr == nil || infoRegion.Size == nil || infoRegion.Offset == nil || infoRegion.PageSize == nil { + return nil, fmt.Errorf("missing required fields for memory region %d", i) + } + + r[i] = Region{ + BaseHostVirtAddr: uintptr(*infoRegion.BaseHostVirtAddr), + Size: uintptr(*infoRegion.Size), + Offset: uintptr(*infoRegion.Offset), + PageSize: uintptr(*infoRegion.PageSize), + } + } + + return NewMapping(r), nil +} + +// GetOffset returns the relative offset and the pagesize of the mapped range for a given address. 
+func (m *Mapping) GetOffset(hostVirtAddr uintptr) (int64, uintptr, error) { for _, r := range m.Regions { if hostVirtAddr >= r.BaseHostVirtAddr && hostVirtAddr < r.endHostVirtAddr() { - return r.shiftedOffset(hostVirtAddr), uint64(r.PageSize), nil + return r.shiftedOffset(hostVirtAddr), r.PageSize, nil } } return 0, 0, AddressNotFoundError{hostVirtAddr: hostVirtAddr} } + +// GetHostVirtAddr returns the host virtual address and size of the remaining contiguous mapped host range for the given offset. +func (m *Mapping) GetHostVirtAddr(off int64) (uintptr, int64, error) { + for _, r := range m.Regions { + if off >= int64(r.Offset) && off < r.endOffset() { + return r.shiftedHostVirtAddr(off), r.endOffset() - off, nil + } + } + + return 0, 0, OffsetNotFoundError{offset: off} +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_host_virt_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_host_virt_test.go new file mode 100644 index 0000000000..05e8cc75f5 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_host_virt_test.go @@ -0,0 +1,230 @@ +package memory + +import ( + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func TestMapping_GetHostVirtAddr(t *testing.T) { + t.Parallel() + + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: header.PageSize, + }, + { + BaseHostVirtAddr: 0x5000, + Size: 0x1000, + Offset: 0x8000, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + + tests := []struct { + name string + offset int64 + expectedHostVirt uintptr + remainingRegionSize int64 + expectError error + }{ + { + name: "valid offset in first region", + offset: 0x5500, // 0x5000 + (0x1500 - 0x1000) + expectedHostVirt: 0x1500, // 0x1000 + (0x5500 - 0x5000) + // region ends at 0x7000; remaining = 0x7000 - 0x5500 = 0x1b00 + 
remainingRegionSize: 0x1b00, + }, + { + name: "valid offset at start of first region", + offset: 0x5000, + expectedHostVirt: 0x1000, // 0x1000 + (0x5000 - 0x5000) + remainingRegionSize: 0x2000, // 0x7000 - 0x5000 + }, + { + name: "valid offset near end of first region", + offset: 0x6FFF, // 0x7000 - 1 + expectedHostVirt: 0x2FFF, // 0x1000 + (0x6FFF - 0x5000) + remainingRegionSize: 0x1, // 0x7000 - 0x6FFF + }, + { + name: "valid offset at start of second region", + offset: 0x8000, + expectedHostVirt: 0x5000, // 0x5000 + (0x8000 - 0x8000) + remainingRegionSize: 0x1000, // 0x9000 - 0x8000 + }, + { + name: "offset before first region", + offset: 0x4000, + expectError: OffsetNotFoundError{offset: 0x4000}, + }, + { + name: "offset after last region", + offset: 0xA000, + expectError: OffsetNotFoundError{offset: 0xA000}, + }, + { + name: "offset in gap between regions", + offset: 0x7000, + expectError: OffsetNotFoundError{offset: 0x7000}, + }, + { + name: "offset at exact end of first region (exclusive)", + offset: 0x7000, // 0x5000 + 0x2000 + expectError: OffsetNotFoundError{offset: 0x7000}, + }, + { + name: "offset at exact end of second region (exclusive)", + offset: 0x9000, // 0x8000 + 0x1000 + expectError: OffsetNotFoundError{offset: 0x9000}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + hostVirt, size, err := mapping.GetHostVirtAddr(tt.offset) + if tt.expectError != nil { + require.ErrorIs(t, err, tt.expectError) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedHostVirt, hostVirt, "hostVirt: %d, expectedHostVirt: %d", hostVirt, tt.expectedHostVirt) + assert.Equal(t, tt.remainingRegionSize, size, "size: %d, expectedSize: %d", size, tt.remainingRegionSize) + } + }) + } +} + +func TestMapping_GetHostVirtAddr_EmptyRegions(t *testing.T) { + t.Parallel() + + mapping := NewMapping([]Region{}) + + // Test GetHostVirtAddr with empty regions + _, _, err := mapping.GetHostVirtAddr(0x1000) + require.ErrorIs(t, 
err, OffsetNotFoundError{offset: 0x1000}) +} + +func TestMapping_GetHostVirtAddr_BoundaryConditions(t *testing.T) { + t.Parallel() + + regions := []Region{ + { + BaseHostVirtAddr: 0x1000, + Size: 0x2000, + Offset: 0x5000, + PageSize: header.PageSize, + }, + } + + mapping := NewMapping(regions) + + // Test exact start boundary + hostVirt, size, err := mapping.GetHostVirtAddr(0x5000) + require.NoError(t, err) + assert.Equal(t, uintptr(0x1000), hostVirt) // 0x1000 + (0x5000 - 0x5000) + assert.Equal(t, int64(0x7000-0x5000), size) // 0x2000 + + // Test offset before end boundary + hostVirt, size, err = mapping.GetHostVirtAddr(0x6FFF) // just before end + require.NoError(t, err) + assert.Equal(t, uintptr(0x1000+(0x6FFF-0x5000)), hostVirt) + assert.Equal(t, int64(0x7000-0x6FFF), size) + + // Test exact end boundary (should fail - exclusive) + _, _, err = mapping.GetHostVirtAddr(0x7000) + require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x7000}) + + // Test below start boundary (should fail) + _, _, err = mapping.GetHostVirtAddr(0x4000) + require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x4000}) +} + +func TestMapping_GetHostVirtAddr_SingleLargeRegion(t *testing.T) { + t.Parallel() + + // Entire 64-bit address space region + regions := []Region{ + { + BaseHostVirtAddr: 0x0, + Size: math.MaxInt64 - 0x100, + Offset: 0x100, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + + hostVirt, size, err := mapping.GetHostVirtAddr(0x100 + 0x1000) // Offset 0x1100 + require.NoError(t, err) + assert.Equal(t, uintptr(0x1000), hostVirt) // 0x1000 + assert.Equal(t, int64(math.MaxInt64-0x100-0x1000), size) +} + +func TestMapping_GetHostVirtAddr_ZeroSizeRegion(t *testing.T) { + t.Parallel() + + regions := []Region{ + { + BaseHostVirtAddr: 0x2000, + Size: 0, + Offset: 0x1000, + PageSize: header.PageSize, + }, + } + + mapping := NewMapping(regions) + + _, _, err := mapping.GetHostVirtAddr(0x1000) + require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x1000}) +} + 
+func TestMapping_GetHostVirtAddr_MultipleRegionsSparse(t *testing.T) { + t.Parallel() + + regions := []Region{ + { + BaseHostVirtAddr: 0x100, + Size: 0x100, + Offset: 0x1000, + PageSize: header.PageSize, + }, + { + BaseHostVirtAddr: 0x10000, + Size: 0x100, + Offset: 0x2000, + PageSize: header.PageSize, + }, + } + mapping := NewMapping(regions) + + // Should succeed for start of first region + hostVirt, size, err := mapping.GetHostVirtAddr(0x1000) + require.NoError(t, err) + assert.Equal(t, uintptr(0x100), hostVirt) // 0x100 + (0x1000 - 0x1000) + assert.Equal(t, int64(0x1100-0x1000), size) // 0x100 + + // Should succeed for just before end of first region + hostVirt, size, err = mapping.GetHostVirtAddr(0x10FF) // 0x1100 - 1 + require.NoError(t, err) + assert.Equal(t, uintptr(0x100+(0x10FF-0x1000)), hostVirt) + assert.Equal(t, int64(0x1100-0x10FF), size) // 1 + + // Should succeed for start of second region + hostVirt, size, err = mapping.GetHostVirtAddr(0x2000) + require.NoError(t, err) + assert.Equal(t, uintptr(0x10000), hostVirt) // 0x10000 + (0x2000 - 0x2000) + assert.Equal(t, int64(0x2100-0x2000), size) // 0x100 + + // In gap + _, _, err = mapping.GetHostVirtAddr(0x1500) + require.ErrorIs(t, err, OffsetNotFoundError{offset: 0x1500}) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_offset_test.go similarity index 50% rename from packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go rename to packages/orchestrator/internal/sandbox/uffd/memory/mapping_offset_test.go index 7b7d87e06f..4be1c1e6e0 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/mapping_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory/mapping_offset_test.go @@ -10,6 +10,8 @@ import ( ) func TestMapping_GetOffset(t *testing.T) { + t.Parallel() + regions := []Region{ { BaseHostVirtAddr: 0x1000, @@ -24,50 +26,51 @@ func TestMapping_GetOffset(t *testing.T) { PageSize: 
header.PageSize, }, } + mapping := NewMapping(regions) tests := []struct { - name string - hostVirtAddr uintptr - expectedOffset int64 - expectedSize uint64 - expectError error + name string + hostVirtAddr uintptr + expectedOffset int64 + expectedPagesize uintptr + expectError error }{ { - name: "valid address in first region", - hostVirtAddr: 0x1500, - expectedOffset: 0x5500, // 0x5000 + (0x1500 - 0x1000) - expectedSize: 0x1000, + name: "valid address in first region", + hostVirtAddr: 0x1500, + expectedOffset: 0x5500, // 0x5000 + (0x1500 - 0x1000) + expectedPagesize: 0x1000, }, { - name: "valid address at start of first region", - hostVirtAddr: 0x1000, - expectedOffset: 0x5000, - expectedSize: 0x1000, + name: "valid address at start of first region", + hostVirtAddr: 0x1000, + expectedOffset: 0x5000, + expectedPagesize: 0x1000, }, { - name: "valid address at end-1 of first region", - hostVirtAddr: 0x2FFF, // 0x1000 + 0x2000 - 1 - expectedOffset: 0x6FFF, // 0x5000 + (0x2FFF - 0x1000) - expectedSize: 0x1000, + name: "valid address at end-1 of first region", + hostVirtAddr: 0x2FFF, // 0x1000 + 0x2000 - 1 + expectedOffset: 0x6FFF, // 0x5000 + (0x2FFF - 0x1000) + expectedPagesize: 0x1000, }, { - name: "valid address in second region", - hostVirtAddr: 0x5500, - expectedOffset: 0x8500, // 0x8000 + (0x5500 - 0x5000) - expectedSize: 0x1000, + name: "valid address in second region", + hostVirtAddr: 0x5500, + expectedOffset: 0x8500, // 0x8000 + (0x5500 - 0x5000) + expectedPagesize: 0x1000, }, { - name: "valid address at start of second region", - hostVirtAddr: 0x5000, - expectedOffset: 0x8000, - expectedSize: 0x1000, + name: "valid address at start of second region", + hostVirtAddr: 0x5000, + expectedOffset: 0x8000, + expectedPagesize: 0x1000, }, { - name: "valid address at end-1 of second region", - hostVirtAddr: 0x5FFF, - expectedOffset: 0x8FFF, // 0x8000 + (0x5FFF - 0x5000) - expectedSize: 0x1000, + name: "valid address at end-1 of second region", + hostVirtAddr: 0x5FFF, + 
expectedOffset: 0x8FFF, // 0x8000 + (0x5FFF - 0x5000) + expectedPagesize: 0x1000, }, { name: "address before first region", @@ -98,60 +101,33 @@ func TestMapping_GetOffset(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - offset, size, err := mapping.GetOffset(tt.hostVirtAddr) + t.Parallel() + + offset, pagesize, err := mapping.GetOffset(tt.hostVirtAddr) if tt.expectError != nil { require.ErrorIs(t, err, tt.expectError) } else { require.NoError(t, err) assert.Equal(t, tt.expectedOffset, offset) - assert.Equal(t, tt.expectedSize, size) + assert.Equal(t, tt.expectedPagesize, pagesize) } }) } } func TestMapping_EmptyRegions(t *testing.T) { + t.Parallel() + mapping := NewMapping([]Region{}) // Test GetOffset with empty regions _, _, err := mapping.GetOffset(0x1000) - require.Error(t, err) -} - -func TestMapping_OverlappingRegions(t *testing.T) { - // Test with overlapping regions (edge case) - regions := []Region{ - { - BaseHostVirtAddr: 0x1000, - Size: 0x2000, - Offset: 0x5000, - PageSize: header.PageSize, - }, - { - BaseHostVirtAddr: 0x2000, // Overlaps with first region - Size: 0x1000, - Offset: 0x8000, - PageSize: header.PageSize, - }, - } - mapping := NewMapping(regions) - - // The first matching region should be returned - offset, size, err := mapping.GetOffset(0x2500) // In overlap area - require.NoError(t, err) - - // Should get result from first region - require.Equal(t, int64(0x5000+(0x2500-0x1000)), offset) // 0x6500 - require.Equal(t, uint64(header.PageSize), size) - - // Also test that the underlying implementation prefers the first region if both regions contain the address - offset2, size2, err2 := mapping.GetOffset(0x2000) - require.NoError(t, err2) - require.Equal(t, int64(0x5000+(0x2000-0x1000)), offset2) // 0x6000 from first region - require.Equal(t, uint64(header.PageSize), size2) + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x1000}) } func TestMapping_BoundaryConditions(t *testing.T) { + t.Parallel() + 
regions := []Region{ { BaseHostVirtAddr: 0x1000, @@ -160,28 +136,33 @@ func TestMapping_BoundaryConditions(t *testing.T) { PageSize: header.PageSize, }, } + mapping := NewMapping(regions) // Test exact start boundary - offset, _, err := mapping.GetOffset(0x1000) + offset, pagesize, err := mapping.GetOffset(0x1000) require.NoError(t, err) - require.Equal(t, int64(0x5000), offset) // 0x5000 + (0x1000 - 0x1000) + assert.Equal(t, int64(0x5000), offset) // 0x5000 + (0x1000 - 0x1000) + assert.Equal(t, uintptr(header.PageSize), pagesize) // Test just before end boundary (exclusive) - offset, _, err = mapping.GetOffset(0x2FFF) // 0x1000 + 0x2000 - 1 + offset, pagesize, err = mapping.GetOffset(0x2FFF) // 0x1000 + 0x2000 - 1 require.NoError(t, err) - require.Equal(t, int64(0x5000+(0x2FFF-0x1000)), offset) // 0x6FFF + assert.Equal(t, int64(0x5000+(0x2FFF-0x1000)), offset) // 0x6FFF + assert.Equal(t, uintptr(header.PageSize), pagesize) // Test exact end boundary (should fail - exclusive) _, _, err = mapping.GetOffset(0x3000) // 0x1000 + 0x2000 - require.Error(t, err) + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x3000}) // Test below start boundary (should fail) - _, _, err = mapping.GetOffset(0x0FFF) - require.Error(t, err) + _, _, err = mapping.GetOffset(0x0FFF) // 0x1000 - 0x1000 + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x0FFF}) } func TestMapping_SingleLargeRegion(t *testing.T) { + t.Parallel() + // Entire 64-bit address space region regions := []Region{ { @@ -193,13 +174,15 @@ func TestMapping_SingleLargeRegion(t *testing.T) { } mapping := NewMapping(regions) - offset, size, err := mapping.GetOffset(0xABCDEF) + offset, pagesize, err := mapping.GetOffset(0xABCDEF) require.NoError(t, err) - require.Equal(t, int64(0x100+0xABCDEF), offset) - require.Equal(t, uint64(header.PageSize), size) + assert.Equal(t, int64(0x100+0xABCDEF), offset) + assert.Equal(t, uintptr(header.PageSize), pagesize) } func TestMapping_ZeroSizeRegion(t *testing.T) { + 
t.Parallel() + regions := []Region{ { BaseHostVirtAddr: 0x2000, @@ -208,12 +191,16 @@ func TestMapping_ZeroSizeRegion(t *testing.T) { PageSize: header.PageSize, }, } + mapping := NewMapping(regions) + _, _, err := mapping.GetOffset(0x2000) - require.Error(t, err) + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x2000}) } func TestMapping_MultipleRegionsSparse(t *testing.T) { + t.Parallel() + regions := []Region{ { BaseHostVirtAddr: 0x100, @@ -229,19 +216,52 @@ func TestMapping_MultipleRegionsSparse(t *testing.T) { }, } mapping := NewMapping(regions) + // Should succeed for start of first region - offset, size, err := mapping.GetOffset(0x100) + offset, pagesize, err := mapping.GetOffset(0x100) require.NoError(t, err) - require.Equal(t, int64(0x1000), offset) - require.Equal(t, uint64(header.PageSize), size) + assert.Equal(t, int64(0x1000), offset) + assert.Equal(t, uintptr(header.PageSize), pagesize) // Should succeed for start of second region - offset, size, err = mapping.GetOffset(0x10000) + offset, pagesize, err = mapping.GetOffset(0x10000) require.NoError(t, err) - require.Equal(t, int64(0x2000), offset) - require.Equal(t, uint64(header.PageSize), size) + assert.Equal(t, int64(0x2000), offset) + assert.Equal(t, uintptr(header.PageSize), pagesize) // In gap _, _, err = mapping.GetOffset(0x5000) - require.Error(t, err) + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x5000}) +} + +// Additional test for hugepage page size +func TestMapping_HugepagePagesize(t *testing.T) { + t.Parallel() + + const hugepageSize = 2 * 1024 * 1024 // 2MB + regions := []Region{ + { + BaseHostVirtAddr: 0x400000, + Size: hugepageSize, + Offset: 0x800000, + PageSize: hugepageSize, + }, + } + mapping := NewMapping(regions) + + // Test valid address in region using hugepages + offset, pagesize, err := mapping.GetOffset(0x401000) + require.NoError(t, err) + assert.Equal(t, int64(0x800000+(0x401000-0x400000)), offset) + assert.Equal(t, uintptr(hugepageSize), 
pagesize) + + // Test start of region + offset, pagesize, err = mapping.GetOffset(0x400000) + require.NoError(t, err) + assert.Equal(t, int64(0x800000), offset) + assert.Equal(t, uintptr(hugepageSize), pagesize) + + // Test end of region (exclusive, should fail) + _, _, err = mapping.GetOffset(0x400000 + uintptr(hugepageSize)) + require.ErrorIs(t, err, AddressNotFoundError{hostVirtAddr: 0x400000 + uintptr(hugepageSize)}) } diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/region.go b/packages/orchestrator/internal/sandbox/uffd/memory/region.go index db1d4f8a3c..821f642471 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory/region.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory/region.go @@ -11,6 +11,12 @@ type Region struct { PageSize uintptr `json:"page_size_kib"` // This is actually in bytes in the deprecated version. } +// endOffset returns the end offset of the region in bytes. +// The end offset is exclusive. +func (r *Region) endOffset() int64 { + return int64(r.Offset + r.Size) +} + // endHostVirtAddr returns the end address of the region in host virtual address. // The end address is exclusive. func (r *Region) endHostVirtAddr() uintptr { @@ -21,3 +27,8 @@ func (r *Region) endHostVirtAddr() uintptr { func (r *Region) shiftedOffset(addr uintptr) int64 { return int64(addr - r.BaseHostVirtAddr + r.Offset) } + +// shiftedHostVirtAddr returns the host virtual address of the given offset in the region. 
+func (r *Region) shiftedHostVirtAddr(off int64) uintptr { + return uintptr(off) + r.BaseHostVirtAddr - r.Offset +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/view.go b/packages/orchestrator/internal/sandbox/uffd/memory/view.go new file mode 100644 index 0000000000..7c0215f1b1 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/view.go @@ -0,0 +1,81 @@ +package memory + +import ( + "errors" + "fmt" + "io" + "os" + "syscall" +) + +var ( + _ io.ReaderAt = (*View)(nil) + _ io.Closer = (*View)(nil) +) + +// MemoryNotFaultedError is returned on read when the page was not faulted in (syscall.EIO error) +type MemoryNotFaultedError struct { + addr uintptr + written int + err error +} + +func (e MemoryNotFaultedError) Error() string { + return fmt.Sprintf("memory not faulted: %v (written %d): %v", e.addr, e.written, e.err) +} + +func (e MemoryNotFaultedError) Unwrap() error { + return e.err +} + +// View exposes memory of the underlying process via offsets, remapped via the mapping to the host virtual address space. +type View struct { + m *Mapping + procMem *os.File + fd int +} + +func NewView(pid int, m *Mapping) (*View, error) { + fd, err := os.Open(fmt.Sprintf("/proc/%d/mem", pid)) + if err != nil { + return nil, fmt.Errorf("failed to open memory file: %w", err) + } + + return &View{ + procMem: fd, + fd: int(fd.Fd()), + m: m, + }, nil +} + +// ReadAt reads data from the memory view at the given offset. +// If this operation crosses a page boundary, it will read the data from the next page. +// +// If you try to read missing pages that are not yet faulted in via UFFD, this will return an error. 
+func (v *View) ReadAt(d []byte, off int64) (n int, err error) { + for n < len(d) { + addr, size, err := v.m.GetHostVirtAddr(off + int64(n)) + if err != nil { + return n, fmt.Errorf("failed to get host virt addr: %w", err) + } + + remainingSize := min(size, int64(len(d)-n)) + + written, err := syscall.Pread(v.fd, d[n:n+int(remainingSize)], int64(addr)) + if errors.Is(err, syscall.EIO) { + return n, MemoryNotFaultedError{addr: addr, written: written, err: err} + } + + if err != nil { + return n, fmt.Errorf("failed to read from %s: %w", v.procMem.Name(), err) + } + + n += written + } + + return n, nil +} + +func (v *View) Close() error { + return v.procMem.Close() +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory/view_test.go b/packages/orchestrator/internal/sandbox/uffd/memory/view_test.go new file mode 100644 index 0000000000..9fd5d64fa0 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/memory/view_test.go @@ -0,0 +1,261 @@ +package memory + +import ( + "bytes" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/testutils" +) + +func TestViewSingleRegionFullRead(t *testing.T) { + t.Parallel() + + pagesize := uint64(4096) + + data := testutils.RandomPages(pagesize, 128) + + size, err := data.Size() + require.NoError(t, err) + + memoryArea, memoryStart, err := testutils.NewPageMmap(t, uint64(size), pagesize) + require.NoError(t, err) + + n := copy(memoryArea[0:size], data.Content()) + require.Equal(t, int(size), n) + + m := NewMapping([]Region{ + { + BaseHostVirtAddr: memoryStart, + Size: uintptr(size), + Offset: uintptr(0), + PageSize: uintptr(pagesize), + }, + }) + + pc, err := NewView(os.Getpid(), m) + require.NoError(t, err) + + defer pc.Close() + + for i := 0; i < int(size); i += int(pagesize) { + readBytes := make([]byte, pagesize) + _, err := pc.ReadAt(readBytes, int64(i)) + require.NoError(t, err) + 
+ expectedBytes := data.Content()[i : i+int(pagesize)] + + if !bytes.Equal(readBytes, expectedBytes) { + assert.Fail(t, testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes).Error()) + } + } +} + +func TestViewSingleRegionPartialRead(t *testing.T) { + t.Parallel() + + pagesize := uint64(4096) + numberOfPages := uint64(32) + + data := testutils.RandomPages(pagesize, numberOfPages) + + size, err := data.Size() + require.NoError(t, err) + + memoryArea, memoryStart, err := testutils.NewPageMmap(t, uint64(size), pagesize) + require.NoError(t, err) + + n := copy(memoryArea[0:size], data.Content()) + require.Equal(t, int(size), n) + + m := NewMapping([]Region{ + { + BaseHostVirtAddr: memoryStart, + Size: uintptr(size), + Offset: uintptr(0), + PageSize: uintptr(pagesize), + }, + }) + + view, err := NewView(os.Getpid(), m) + require.NoError(t, err) + + t.Cleanup(func() { + view.Close() + }) + + t.Run("start of the region", func(t *testing.T) { + t.Parallel() + + readBytes := make([]byte, pagesize) + offset := int64(0) + n, err = view.ReadAt(readBytes, offset) + require.NoError(t, err) + assert.Equal(t, int(pagesize), n) + expectedBytes, err := data.Slice(t.Context(), offset, int64(pagesize)) + require.NoError(t, err) + if !bytes.Equal(readBytes, expectedBytes) { + assert.Fail(t, testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes).Error(), "at offset %d", offset) + } + }) + + t.Run("middle of the region", func(t *testing.T) { + t.Parallel() + + readBytes := make([]byte, pagesize) + offset := int64(numberOfPages / 2 * pagesize) + n, err := view.ReadAt(readBytes, offset) + require.NoError(t, err) + assert.Equal(t, int(pagesize), n) + expectedBytes, err := data.Slice(t.Context(), offset, int64(pagesize)) + require.NoError(t, err) + if !bytes.Equal(readBytes, expectedBytes) { + assert.Fail(t, testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes).Error(), "at offset %d", offset) + } + }) + + t.Run("end of the region", func(t *testing.T) { + 
t.Parallel() + + readBytes := make([]byte, pagesize) + offset := int64(numberOfPages*pagesize - pagesize) + n, err := view.ReadAt(readBytes, offset) + require.NoError(t, err) + assert.Equal(t, int(pagesize), n) + expectedBytes, err := data.Slice(t.Context(), offset, int64(pagesize)) + require.NoError(t, err) + if !bytes.Equal(readBytes, expectedBytes) { + assert.Fail(t, testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes).Error(), "at offset %d", offset) + } + }) +} + +func TestViewMultipleRegions(t *testing.T) { + t.Parallel() + + pagesize := uint64(4096) + + // Create three separate memory regions with gaps between them + region1Pages := uint64(32) + region2Pages := uint64(64) + region3Pages := uint64(16) + + data1 := testutils.RandomPages(pagesize, region1Pages) + data2 := testutils.RandomPages(pagesize, region2Pages) + data3 := testutils.RandomPages(pagesize, region3Pages) + + size1, err := data1.Size() + require.NoError(t, err) + size2, err := data2.Size() + require.NoError(t, err) + size3, err := data3.Size() + require.NoError(t, err) + + // Create three separate memory mappings + memoryArea1, memoryStart1, err := testutils.NewPageMmap(t, uint64(size1), pagesize) + require.NoError(t, err) + memoryArea2, memoryStart2, err := testutils.NewPageMmap(t, uint64(size2), pagesize) + require.NoError(t, err) + memoryArea3, memoryStart3, err := testutils.NewPageMmap(t, uint64(size3), pagesize) + require.NoError(t, err) + + // Copy data to each region + n1 := copy(memoryArea1[0:size1], data1.Content()) + require.Equal(t, int(size1), n1) + n2 := copy(memoryArea2[0:size2], data2.Content()) + require.Equal(t, int(size2), n2) + n3 := copy(memoryArea3[0:size3], data3.Content()) + require.Equal(t, int(size3), n3) + + // Create mapping with three regions at different offsets + // Region 1: offset 0, size size1 + // Region 2: offset size1 + gap, size size2 + // Region 3: offset size1 + gap + size2 + gap, size size3 + gap := uint64(8192) // 2 pages gap between 
regions + offset2 := uint64(size1) + gap + offset3 := offset2 + uint64(size2) + gap + + m := NewMapping([]Region{ + { + BaseHostVirtAddr: memoryStart1, + Size: uintptr(size1), + Offset: 0, + PageSize: uintptr(pagesize), + }, + { + BaseHostVirtAddr: memoryStart2, + Size: uintptr(size2), + Offset: uintptr(offset2), + PageSize: uintptr(pagesize), + }, + { + BaseHostVirtAddr: memoryStart3, + Size: uintptr(size3), + Offset: uintptr(offset3), + PageSize: uintptr(pagesize), + }, + }) + + pc, err := NewView(os.Getpid(), m) + require.NoError(t, err) + + defer pc.Close() + + // Test reading from first region + for i := 0; i < int(size1); i += int(pagesize) { + readBytes := make([]byte, pagesize) + _, err := pc.ReadAt(readBytes, int64(i)) + require.NoError(t, err) + + expectedBytes := data1.Content()[i : i+int(pagesize)] + if !bytes.Equal(readBytes, expectedBytes) { + assert.Fail(t, testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes).Error(), "at offset %d", i) + } + } + + // Test reading from second region + for i := 0; i < int(size2); i += int(pagesize) { + readBytes := make([]byte, pagesize) + _, err := pc.ReadAt(readBytes, int64(offset2)+int64(i)) + require.NoError(t, err) + + expectedBytes := data2.Content()[i : i+int(pagesize)] + if !bytes.Equal(readBytes, expectedBytes) { + assert.Fail(t, testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes).Error(), "at offset %d", int64(offset2)+int64(i)) + } + } + + // Test reading from third region + for i := 0; i < int(size3); i += int(pagesize) { + readBytes := make([]byte, pagesize) + _, err := pc.ReadAt(readBytes, int64(offset3)+int64(i)) + require.NoError(t, err) + + expectedBytes := data3.Content()[i : i+int(pagesize)] + if !bytes.Equal(readBytes, expectedBytes) { + assert.Fail(t, testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes).Error(), "at offset %d", int64(offset3)+int64(i)) + } + } + + // Test reading that spans within a single region (not crossing boundaries) + // Read 2 pages 
from middle of region 2 + readSize := int(2 * pagesize) + readOffset := int64(offset2) + int64(pagesize) + readBytes := make([]byte, readSize) + _, err = pc.ReadAt(readBytes, readOffset) + require.NoError(t, err) + + expectedBytes := data2.Content()[int(pagesize) : int(pagesize)+readSize] + if !bytes.Equal(readBytes, expectedBytes) { + assert.Fail(t, testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes).Error(), "at offset %d", readOffset) + } + + // Test reading that would cross region boundary (should fail at gap) + // Try to read from end of region 1 into gap + readBytes = make([]byte, int(pagesize*2)) + _, err = pc.ReadAt(readBytes, size1-int64(pagesize)) + require.ErrorAs(t, err, &OffsetNotFoundError{offset: size1 - int64(pagesize)}) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go index cf64c83620..dc1463312e 100644 --- a/packages/orchestrator/internal/sandbox/uffd/memory_backend.go +++ b/packages/orchestrator/internal/sandbox/uffd/memory_backend.go @@ -8,9 +8,10 @@ import ( ) type MemoryBackend interface { - // Disable unregisters the uffd from the memory mapping and returns the dirty pages. - // It must be called after FC pause finished and before FC snapshot is created. - Disable(ctx context.Context) (*block.Tracker, error) + // Dirty waits for the current requests to finish and returns the dirty pages. + // + // It *MUST* only be called after the sandbox was successfully paused and the snapshot create endpoint returned as these can still write to the memory. 
+ Dirty(ctx context.Context) (*block.Tracker, error) Start(ctx context.Context, sandboxId string) error Stop() error diff --git a/packages/orchestrator/internal/sandbox/uffd/noop.go b/packages/orchestrator/internal/sandbox/uffd/noop.go index c859c333e4..20483e5bc8 100644 --- a/packages/orchestrator/internal/sandbox/uffd/noop.go +++ b/packages/orchestrator/internal/sandbox/uffd/noop.go @@ -35,7 +35,7 @@ func NewNoopMemory(size, blockSize int64) *NoopMemory { } } -func (m *NoopMemory) Disable(context.Context) (*block.Tracker, error) { +func (m *NoopMemory) Dirty(context.Context) (*block.Tracker, error) { return m.dirty.Clone(), nil } diff --git a/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go b/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go index 68298ea6ea..1acc10a8ed 100644 --- a/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go +++ b/packages/orchestrator/internal/sandbox/uffd/testutils/diff_byte.go @@ -1,20 +1,31 @@ package testutils +import ( + "errors" + "fmt" +) + // FirstDifferentByte returns the first byte index where a and b differ. // It also returns the differing byte values (want, got). // If slices are identical, it returns idx -1. 
-func FirstDifferentByte(a, b []byte) (idx int, want, got byte) { - smallerSize := min(len(a), len(b)) +func ErrorFromByteSlicesDifference(expected, actual []byte) error { + var errs []error - for i := range smallerSize { - if a[i] != b[i] { - return i, b[i], a[i] - } + if len(expected) > len(actual) { + errs = append(errs, fmt.Errorf("expected slice (%d bytes) is longer than actual slice (%d bytes)", len(expected), len(actual))) + } else if len(expected) < len(actual) { + errs = append(errs, fmt.Errorf("actual slice (%d bytes) is longer than expected slice (%d bytes)", len(actual), len(expected))) } - if len(a) != len(b) { - return smallerSize, 0, 0 + smallerSize := min(len(expected), len(actual)) + + for i := range smallerSize { + if expected[i] != actual[i] { + errs = append(errs, fmt.Errorf("first different byte: want '%x', got '%x' at index %d", expected[i], actual[i], i)) + + break + } } - return -1, 0, 0 + return errors.Join(errs...) } diff --git a/packages/orchestrator/internal/sandbox/uffd/uffd.go b/packages/orchestrator/internal/sandbox/uffd/uffd.go index ce9728f034..ea668b90e6 100644 --- a/packages/orchestrator/internal/sandbox/uffd/uffd.go +++ b/packages/orchestrator/internal/sandbox/uffd/uffd.go @@ -138,7 +138,7 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { m := memory.NewMapping(regions) uffd, err := userfaultfd.NewUserfaultfdFromFd( - uintptr(fds[0]), + userfaultfd.Fd(fds[0]), u.memfile, m, logger.L().With(logger.WithSandboxID(sandboxId)), @@ -162,6 +162,10 @@ func (u *Uffd) handle(ctx context.Context, sandboxId string) error { ctx, u.fdExit, ) + if errors.Is(err, fdexit.ErrFdExit) { + return nil + } + if err != nil { return fmt.Errorf("failed handling uffd: %w", err) } @@ -181,32 +185,10 @@ func (u *Uffd) Exit() *utils.ErrorOnce { return u.exit } -// Disable unregisters the uffd from the memory mapping, -// allowing us to create a "diff" snapshot via FC API without dirty tracking enabled, -// and without pagefaulting all 
remaining missing pages. -// -// It should be called *after* Dirty(). -// -// After calling Disable(), this uffd is no longer usable—we won't be able to resume the sandbox via API. -// The uffd itself is not closed though, as that should be done by the sandbox cleanup. -func (u *Uffd) Disable(ctx context.Context) (*block.Tracker, error) { - uffd, err := u.handler.WaitWithContext(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get uffd: %w", err) - } - - err = uffd.Unregister() - if err != nil { - return nil, fmt.Errorf("failed to unregister uffd: %w", err) - } - - return u.dirty(ctx) -} - // Dirty waits for the current requests to finish and returns the dirty pages. // // It *MUST* be only called after the sandbox was successfully paused via API. -func (u *Uffd) dirty(ctx context.Context) (*block.Tracker, error) { +func (u *Uffd) Dirty(ctx context.Context) (*block.Tracker, error) { uffd, err := u.handler.WaitWithContext(ctx) if err != nil { return nil, fmt.Errorf("failed to get uffd: %w", err) diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go index 3ccee96802..ef184d126f 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd.go @@ -23,8 +23,6 @@ import ( "fmt" "syscall" "unsafe" - - "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) const ( @@ -34,13 +32,19 @@ const ( UFFD_EVENT_PAGEFAULT = C.UFFD_EVENT_PAGEFAULT UFFDIO_REGISTER_MODE_MISSING = C.UFFDIO_REGISTER_MODE_MISSING + UFFDIO_REGISTER_MODE_WP = C.UFFDIO_REGISTER_MODE_WP + + UFFDIO_WRITEPROTECT_MODE_WP = C.UFFDIO_WRITEPROTECT_MODE_WP + UFFDIO_COPY_MODE_WP = C.UFFDIO_COPY_MODE_WP - UFFDIO_API = C.UFFDIO_API - UFFDIO_REGISTER = C.UFFDIO_REGISTER - UFFDIO_UNREGISTER = C.UFFDIO_UNREGISTER - UFFDIO_COPY = C.UFFDIO_COPY + UFFDIO_API = C.UFFDIO_API + UFFDIO_REGISTER = C.UFFDIO_REGISTER + UFFDIO_UNREGISTER = 
C.UFFDIO_UNREGISTER + UFFDIO_COPY = C.UFFDIO_COPY + UFFDIO_WRITEPROTECT = C.UFFDIO_WRITEPROTECT UFFD_PAGEFAULT_FLAG_WRITE = C.UFFD_PAGEFAULT_FLAG_WRITE + UFFD_PAGEFAULT_FLAG_WP = C.UFFD_PAGEFAULT_FLAG_WP UFFD_FEATURE_MISSING_HUGETLBFS = C.UFFD_FEATURE_MISSING_HUGETLBFS ) @@ -81,6 +85,13 @@ func newUffdioRegister(start, length, mode CULong) UffdioRegister { } } +func newUffdioWriteProtect(start, length, mode CULong) UffdioWriteProtect { + return UffdioWriteProtect{ + _range: newUffdioRange(start, length), + mode: mode, + } +} + func newUffdioCopy(b []byte, address CULong, pagesize CULong, mode CULong, bytesCopied CLong) UffdioCopy { return UffdioCopy{ src: CULong(uintptr(unsafe.Pointer(&b[0]))), @@ -103,45 +114,16 @@ func getPagefaultAddress(pagefault *UffdPagefault) uintptr { return uintptr(pagefault.address) } -// uffdFd is a helper type that wraps uffd fd. -type uffdFd uintptr - -// flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK -func newFd(flags uintptr) (uffdFd, error) { - uffd, _, errno := syscall.Syscall(NR_userfaultfd, flags, 0, 0) - if errno != 0 { - return 0, fmt.Errorf("userfaultfd syscall failed: %w", errno) - } - - return uffdFd(uffd), nil -} - -// features: UFFD_FEATURE_MISSING_HUGETLBFS -// This is already called by the FC -func (u uffdFd) configureApi(pagesize uint64) error { - var features CULong - - // Only set the hugepage feature if we're using hugepages - if pagesize == header.HugepageSize { - features |= UFFD_FEATURE_MISSING_HUGETLBFS - } - - api := newUffdioAPI(UFFD_API, features) - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_API, uintptr(unsafe.Pointer(&api))) - if errno != 0 { - return fmt.Errorf("UFFDIO_API ioctl failed: %w (ret=%d)", errno, ret) - } - - return nil -} +// Fd is a helper type that wraps uffd fd. 
+type Fd uintptr // mode: UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING // This is already called by the FC, but only with the UFFDIO_REGISTER_MODE_MISSING // We need to call it with UFFDIO_REGISTER_MODE_WP when we use both missing and wp -func (u uffdFd) register(addr uintptr, size uint64, mode CULong) error { +func (f Fd) register(addr uintptr, size uint64, mode CULong) error { register := newUffdioRegister(CULong(addr), CULong(size), mode) - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_REGISTER, uintptr(unsafe.Pointer(®ister))) if errno != 0 { return fmt.Errorf("UFFDIO_REGISTER ioctl failed: %w (ret=%d)", errno, ret) } @@ -149,10 +131,10 @@ func (u uffdFd) register(addr uintptr, size uint64, mode CULong) error { return nil } -func (u uffdFd) unregister(addr uintptr, size uint64) error { +func (f Fd) unregister(addr, size uintptr) error { r := newUffdioRange(CULong(addr), CULong(size)) - ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_UNREGISTER, uintptr(unsafe.Pointer(&r))) + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_UNREGISTER, uintptr(unsafe.Pointer(&r))) if errno != 0 { return fmt.Errorf("UFFDIO_UNREGISTER ioctl failed: %w (ret=%d)", errno, ret) } @@ -162,21 +144,38 @@ func (u uffdFd) unregister(addr uintptr, size uint64) error { // mode: UFFDIO_COPY_MODE_WP // When we use both missing and wp, we need to use UFFDIO_COPY_MODE_WP, otherwise copying would unprotect the page -func (u uffdFd) copy(addr uintptr, data []byte, pagesize uint64, mode CULong) error { +func (f Fd) copy(addr, pagesize uintptr, data []byte, mode CULong) error { cpy := newUffdioCopy(data, CULong(addr)&^CULong(pagesize-1), CULong(pagesize), mode, 0) - if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(u), UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy))); errno != 0 { + if _, _, errno := 
syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_COPY, uintptr(unsafe.Pointer(&cpy))); errno != 0 { return errno } // Check if the copied size matches the requested pagesize - if uint64(cpy.copy) != pagesize { + if cpy.copy != CLong(pagesize) { return fmt.Errorf("UFFDIO_COPY copied %d bytes, expected %d", cpy.copy, pagesize) } return nil } -func (u uffdFd) close() error { - return syscall.Close(int(u)) +// mode: UFFDIO_WRITEPROTECT_MODE_WP +// Passing 0 as the mode will remove the write protection. +func (f Fd) writeProtect(addr, size uintptr, mode CULong) error { + register := newUffdioWriteProtect(CULong(addr), CULong(size), mode) + + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_WRITEPROTECT, uintptr(unsafe.Pointer(®ister))) + if errno != 0 { + return fmt.Errorf("UFFDIO_WRITEPROTECT ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +func (f Fd) close() error { + return syscall.Close(int(f)) +} + +func (f Fd) fd() int32 { + return int32(f) } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go new file mode 100644 index 0000000000..367af9e2aa --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/fd_helpers_test.go @@ -0,0 +1,122 @@ +package userfaultfd + +import ( + "fmt" + "syscall" + "unsafe" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// mockFd is a mock implementation of the Fd interface. +// It allows us to test the handling methods separately from the actual uffd serve loop. +type mockFd struct { + // The channels send back the info about the uffd handled operations + // and also allows us to block the methods to test the flow. 
+ copyCh chan *blockedEvent[UffdioCopy] + writeProtectCh chan *blockedEvent[UffdioWriteProtect] +} + +func newMockFd() *mockFd { + return &mockFd{ + copyCh: make(chan *blockedEvent[UffdioCopy]), + writeProtectCh: make(chan *blockedEvent[UffdioWriteProtect]), + } +} + +func (m *mockFd) register(_ uintptr, _ uint64, _ CULong) error { + return nil +} + +func (m *mockFd) unregister(_, _ uintptr) error { + return nil +} + +func (m *mockFd) copy(addr, pagesize uintptr, _ []byte, mode CULong) error { + // Don't use the uffdioCopy constructor as it unsafely checks slice address and fails for arbitrary pointer. + e := newBlockedEvent(UffdioCopy{ + src: 0, + dst: CULong(addr), + len: CULong(pagesize), + mode: mode, + copy: 0, + }) + + m.copyCh <- e + + <-e.resolved + + return nil +} + +func (m *mockFd) writeProtect(addr, size uintptr, mode CULong) error { + e := newBlockedEvent(UffdioWriteProtect{ + _range: newUffdioRange( + CULong(addr), + CULong(size), + ), + mode: mode, + }) + + m.writeProtectCh <- e + + <-e.resolved + + return nil +} + +func (m *mockFd) close() error { + return nil +} + +func (m *mockFd) fd() int32 { + return 0 +} + +// Used for testing. 
+// flags: syscall.O_CLOEXEC|syscall.O_NONBLOCK +func newFd(flags uintptr) (Fd, error) { + uffd, _, errno := syscall.Syscall(NR_userfaultfd, flags, 0, 0) + if errno != 0 { + return 0, fmt.Errorf("userfaultfd syscall failed: %w", errno) + } + + return Fd(uffd), nil +} + +// Used for testing +// features: UFFD_FEATURE_MISSING_HUGETLBFS +// This is already called by the FC +func configureApi(f Fd, pagesize uint64) error { + var features CULong + + // Only set the hugepage feature if we're using hugepages + if pagesize == header.HugepageSize { + features |= UFFD_FEATURE_MISSING_HUGETLBFS + } + + api := newUffdioAPI(UFFD_API, features) + ret, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f), UFFDIO_API, uintptr(unsafe.Pointer(&api))) + if errno != 0 { + return fmt.Errorf("UFFDIO_API ioctl failed: %w (ret=%d)", errno, ret) + } + + return nil +} + +// This wrapped event allows us to simulate the finish of processing of events by FC on FC API Pause. +type blockedEvent[T UffdioCopy | UffdioWriteProtect] struct { + event T + resolved chan struct{} +} + +func newBlockedEvent[T UffdioCopy | UffdioWriteProtect](event T) *blockedEvent[T] { + return &blockedEvent[T]{ + event: event, + resolved: make(chan struct{}), + } +} + +func (e *blockedEvent[T]) resolve() { + close(e.resolved) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go index 2a1397e534..90c1b6242f 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd.go @@ -1,5 +1,10 @@ package userfaultfd +// flowchart TD +// A[missing page] -- write (WRITE flag) --> B(COPY) --> C[dirty page] +// A -- read (MISSING flag) --> D(COPY + MODE_WP) --> E[faulted page] +// E -- write (WP+[WRITE] flag) --> F(remove MODE_WP) --> C + import ( "context" "errors" @@ -22,15 +27,25 @@ const maxRequestsInProgress = 4096 var 
ErrUnexpectedEventType = errors.New("unexpected event type") +type uffdio interface { + unregister(addr, size uintptr) error + register(addr uintptr, size uint64, mode CULong) error + copy(addr, pagesize uintptr, data []byte, mode CULong) error + writeProtect(addr, size uintptr, mode CULong) error + close() error + fd() int32 +} + type Userfaultfd struct { - fd uffdFd + uffd uffdio src block.Slicer - ma *memory.Mapping + m *memory.Mapping // We don't skip the already mapped pages, because if the memory is swappable the page *might* under some conditions be mapped out. // For hugepages this should not be a problem, but might theoretically happen to normal pages with swap missingRequests *block.Tracker + writeRequests *block.Tracker // We use the settleRequests to guard the missingRequests so we can access a consistent state of the missingRequests after the requests are finished. settleRequests sync.RWMutex @@ -40,20 +55,34 @@ type Userfaultfd struct { } // NewUserfaultfdFromFd creates a new userfaultfd instance with optional configuration. -func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logger logger.Logger) (*Userfaultfd, error) { +func NewUserfaultfdFromFd(uffd uffdio, src block.Slicer, m *memory.Mapping, logger logger.Logger) (*Userfaultfd, error) { blockSize := src.BlockSize() for _, region := range m.Regions { if region.PageSize != uintptr(blockSize) { return nil, fmt.Errorf("block size mismatch: %d != %d for region %d", region.PageSize, blockSize, region.BaseHostVirtAddr) } + + // Register the WP for the regions. + // The memory region is already registered (with missing pages in FC), but registering it again with bigger flag subset should merge these registration flags. 
+ // - https://github.com/firecracker-microvm/firecracker/blob/f335a0adf46f0680a141eb1e76fe31ac258918c5/src/vmm/src/persist.rs#L477 + // - https://github.com/bytecodealliance/userfaultfd-rs/blob/main/src/builder.rs + err := uffd.register( + region.BaseHostVirtAddr, + uint64(region.Size), + UFFDIO_REGISTER_MODE_WP|UFFDIO_REGISTER_MODE_MISSING, + ) + if err != nil { + return nil, fmt.Errorf("failed to reregister memory region with write protection %d-%d: %w", region.Offset, region.Offset+region.Size, err) + } } u := &Userfaultfd{ - fd: uffdFd(fd), + uffd: uffd, src: src, missingRequests: block.NewTracker(blockSize), - ma: m, + writeRequests: block.NewTracker(blockSize), + m: m, logger: logger, } @@ -66,15 +95,17 @@ func NewUserfaultfdFromFd(fd uintptr, src block.Slicer, m *memory.Mapping, logge } func (u *Userfaultfd) Close() error { - return u.fd.close() + return u.uffd.close() } func (u *Userfaultfd) Serve( ctx context.Context, fdExit *fdexit.FdExit, ) error { + uffd := u.uffd.fd() + pollFds := []unix.PollFd{ - {Fd: int32(u.fd), Events: unix.POLLIN}, + {Fd: uffd, Events: unix.POLLIN}, {Fd: fdExit.Reader(), Events: unix.POLLIN}, } @@ -113,7 +144,7 @@ outerLoop: return fmt.Errorf("failed to handle uffd: %w", errMsg) } - return nil + return fdexit.ErrFdExit } uffdFd := pollFds[0] @@ -135,7 +166,7 @@ outerLoop: buf := make([]byte, unsafe.Sizeof(UffdMsg{})) for { - _, err := syscall.Read(int(u.fd), buf) + _, err := syscall.Read(int(uffd), buf) if err == syscall.EINTR { u.logger.Debug(ctx, "uffd: interrupted read, reading again") @@ -176,21 +207,27 @@ outerLoop: addr := getPagefaultAddress(&pagefault) - offset, pagesize, err := u.ma.GetOffset(addr) + offset, pagesize, err := u.m.GetOffset(addr) if err != nil { u.logger.Error(ctx, "UFFD serve get mapping error", zap.Error(err)) return fmt.Errorf("failed to map: %w", err) } + // Handle write to write protected page (WP flag) + // The documentation does not clearly mention if the WRITE flag must be present with the WP flag, 
even though we saw it being present in the events. + // - https://docs.kernel.org/admin-guide/mm/userfaultfd.html#write-protect-notifications + if flags&UFFD_PAGEFAULT_FLAG_WP != 0 { + u.handleWriteProtected(ctx, fdExit.SignalExit, addr, pagesize, offset) + + continue + } + // Handle write to missing page (WRITE flag) // If the event has WRITE flag, it was a write to a missing page. // For the write to be executed, we first need to copy the page from the source to the guest memory. if flags&UFFD_PAGEFAULT_FLAG_WRITE != 0 { - err := u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize) - if err != nil { - return fmt.Errorf("failed to handle missing write: %w", err) - } + u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset, true) continue } @@ -198,10 +235,7 @@ outerLoop: // Handle read to missing page ("MISSING" flag) // If the event has no flags, it was a read to a missing page and we need to copy the page from the source to the guest memory. if flags == 0 { - err := u.handleMissing(ctx, fdExit.SignalExit, addr, offset, pagesize) - if err != nil { - return fmt.Errorf("failed to handle missing: %w", err) - } + u.handleMissing(ctx, fdExit.SignalExit, addr, pagesize, offset, false) continue } @@ -214,21 +248,27 @@ outerLoop: func (u *Userfaultfd) handleMissing( ctx context.Context, onFailure func() error, - addr uintptr, + addr, + pagesize uintptr, offset int64, - pagesize uint64, -) error { + write bool, +) { u.wg.Go(func() error { // The RLock must be called inside the goroutine to ensure RUnlock runs via defer, // even if the errgroup is cancelled or the goroutine returns early. // This check protects us against race condition between marking the request as missing and accessing the missingRequests tracker. - // The Firecracker pause should return only after the requested memory is faulted in, so we don't need to guard the pagefault from the moment it is created. 
+ // The Firecracker pause should return only after the requested memory is copied to the guest memory, so we don't need to guard the pagefault from the moment it is created. u.settleRequests.RLock() defer u.settleRequests.RUnlock() defer func() { if r := recover(); r != nil { u.logger.Error(ctx, "UFFD serve panic", zap.Any("pagesize", pagesize), zap.Any("panic", r)) + + signalErr := onFailure() + if signalErr != nil { + u.logger.Error(ctx, "UFFD handle missing failure error", zap.Error(signalErr)) + } } }() @@ -245,7 +285,12 @@ func (u *Userfaultfd) handleMissing( var copyMode CULong - copyErr := u.fd.copy(addr, b, pagesize, copyMode) + // If the event is not WRITE, we need to add WP to the page, so we can catch the next WRITE+WP and mark the page as dirty. + if !write { + copyMode |= UFFDIO_COPY_MODE_WP + } + + copyErr := u.uffd.copy(addr, pagesize, b, copyMode) if errors.Is(copyErr, unix.EEXIST) { // Page is already mapped @@ -259,28 +304,64 @@ func (u *Userfaultfd) handleMissing( u.logger.Error(ctx, "UFFD serve uffdio copy error", zap.Error(joinedErr)) - return fmt.Errorf("failed uffdio copy %w", joinedErr) + return fmt.Errorf("failed to copy page %d-%d %w", offset, offset+int64(pagesize), joinedErr) } // Add the offset to the missing requests tracker. u.missingRequests.Add(offset) + if write { + // Add the offset to the write requests tracker. + u.writeRequests.Add(offset) + } + return nil }) - - return nil } -func (u *Userfaultfd) Unregister() error { - for _, r := range u.ma.Regions { - if err := u.fd.unregister(r.BaseHostVirtAddr, uint64(r.Size)); err != nil { - return fmt.Errorf("failed to unregister: %w", err) +// Userfaultfd write-protect mode currently behaves differently on none ptes (when e.g. page is missing) over different types of memories (hugepages file backed, etc.).
+// - https://docs.kernel.org/admin-guide/mm/userfaultfd.html#write-protect-notifications - "there will be a userfaultfd write fault message generated when writing to a missing page" +// This should not affect the handling we have in place as all events are being handled. +func (u *Userfaultfd) handleWriteProtected(ctx context.Context, onFailure func() error, addr, pagesize uintptr, offset int64) { + u.wg.Go(func() error { + // The RLock must be called inside the goroutine to ensure RUnlock runs via defer, + // even if the errgroup is cancelled or the goroutine returns early. + // This check protects us against race condition between marking the request as dirty and accessing the writeRequests tracker. + // The Firecracker pause should return only after the requested memory is copied to the guest memory, so we don't need to guard the pagefault from the moment it is created. + u.settleRequests.RLock() + defer u.settleRequests.RUnlock() + + defer func() { + if r := recover(); r != nil { + u.logger.Error(ctx, "UFFD remove write protection panic", zap.Any("offset", offset), zap.Any("pagesize", pagesize), zap.Any("panic", r)) + + signalErr := onFailure() + if signalErr != nil { + u.logger.Error(ctx, "UFFD handle write protected failure error", zap.Error(signalErr)) + } + } + }() + + // Passing 0 as the mode removes the write protection. + wpErr := u.uffd.writeProtect(addr, pagesize, 0) + if wpErr != nil { + signalErr := onFailure() + + joinedErr := errors.Join(wpErr, signalErr) + + u.logger.Error(ctx, "UFFD serve write protect error", zap.Error(joinedErr)) + + return fmt.Errorf("failed to remove write protection from page %d-%d %w", offset, offset+int64(pagesize), joinedErr) } - } - return nil + // Add the offset to the write requests tracker. + u.writeRequests.Add(offset) + + return nil + }) } +// Dirty returns the dirty pages. func (u *Userfaultfd) Dirty() *block.Tracker { // This will be at worst cancelled when the uffd is closed. 
u.settleRequests.Lock() @@ -288,5 +369,5 @@ func (u *Userfaultfd) Dirty() *block.Tracker { // so it is consistent even if there is a another uffd call after. defer u.settleRequests.Unlock() - return u.missingRequests.Clone() + return u.writeRequests.Clone() } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_helpers_test.go similarity index 68% rename from packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_helpers_test.go index 1a65664b86..8631457a49 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_helpers_test.go @@ -9,6 +9,8 @@ import ( "github.com/bits-and-blooms/bitset" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/testutils" ) @@ -41,8 +43,26 @@ type testHandler struct { data *testutils.MemorySlicer // Returns offsets of the pages that were faulted. // It can only be called once. - offsetsOnce func() ([]uint, error) - mutex sync.Mutex + // Sorted in ascending order. + accessedOffsetsOnce func() ([]uint, error) + // Returns offsets of the pages that were dirtied. + // It can only be called once. + // Sorted in ascending order. 
+ dirtyOffsetsOnce func() ([]uint, error) + + mutex sync.Mutex + mapping *memory.Mapping +} + +func (h *testHandler) executeOperation(ctx context.Context, op operation) error { + switch op.mode { + case operationModeRead: + return h.executeRead(ctx, op) + case operationModeWrite: + return h.executeWrite(ctx, op) + default: + return fmt.Errorf("invalid operation mode: %d", op.mode) + } } func (h *testHandler) executeRead(ctx context.Context, op operation) error { @@ -55,9 +75,7 @@ func (h *testHandler) executeRead(ctx context.Context, op operation) error { // The bytes.Equal is the first place in this flow that actually touches the uffd managed memory and triggers the pagefault, so any deadlocks will manifest here. if !bytes.Equal(readBytes, expectedBytes) { - idx, want, got := testutils.FirstDifferentByte(readBytes, expectedBytes) - - return fmt.Errorf("content mismatch: want '%x, got %x at index %d", want, got, idx) + return fmt.Errorf("content mismatch: %w", testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes)) } return nil @@ -82,6 +100,7 @@ func (h *testHandler) executeWrite(ctx context.Context, op operation) error { } // Get a bitset of the offsets of the operations for the given mode. +// Sorted in ascending order. 
func getOperationsOffsets(ops []operation, m operationMode) []uint { b := bitset.New(0) @@ -93,3 +112,10 @@ func getOperationsOffsets(ops []operation, m operationMode) []uint { return slices.Collect(b.EachSet()) } + +func accessed(u *Userfaultfd) *block.Tracker { + u.settleRequests.Lock() + defer u.settleRequests.Unlock() + + return u.missingRequests.Clone() +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_memory_view_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_memory_view_test.go new file mode 100644 index 0000000000..aa26f6d0f6 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_memory_view_test.go @@ -0,0 +1,276 @@ +package userfaultfd + +// The tests for memory.View are reading memory from the same process the memory belongs to, but with the /proc/PID/mem file it should not matter. + +import ( + "bytes" + "os" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/testutils" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func TestUffdMemoryViewFaulted(t *testing.T) { + t.Parallel() + + tests := []testConfig{ + { + name: "standard 4k page, operation at start, read faulted", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + }, + }, + { + name: "standard 4k page, operation at middle, read faulted", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeRead, + }, + }, + }, + { + name: "hugepage, operation at start, read faulted", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + }, + }, + { + name: 
"hugepage, operation at middle, read faulted", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, + { + name: "standard 4k page, operation at start, write faulted", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, operation at middle, write faulted", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, operation at start, write faulted", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, operation at middle, write faulted", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, tt) + require.NoError(t, err) + + for _, operation := range tt.operations { + err := h.executeOperation(t.Context(), operation) + assert.NoError(t, err, "for operation %+v", operation) //nolint:testifylint + } + + view, err := memory.NewView(os.Getpid(), h.mapping) + require.NoError(t, err) + + for _, operation := range tt.operations { + readBytes := make([]byte, tt.pagesize) + n, err := view.ReadAt(readBytes, operation.offset) + require.NoError(t, err) + assert.Len(t, readBytes, n) + + expectedBytes, err := h.data.Slice(t.Context(), operation.offset, int64(tt.pagesize)) + require.NoError(t, err) + + if !bytes.Equal(expectedBytes, readBytes) { + assert.Fail(t, testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes).Error(), "for operation %+v", operation) + } + } + }) + } +} + +func 
TestUffdMemoryViewNotFaultedError(t *testing.T) { + t.Parallel() + + test := testConfig{ + name: "standard 4k page, operation at start", + pagesize: header.PageSize, + numberOfPages: 32, + } + + h, err := configureCrossProcessTest(t, test) + require.NoError(t, err) + + view, err := memory.NewView(os.Getpid(), h.mapping) + require.NoError(t, err) + + readBytes := make([]byte, header.PageSize) + _, err = view.ReadAt(readBytes, 0) + require.ErrorAs(t, err, &memory.MemoryNotFaultedError{}) + require.ErrorIs(t, err, syscall.EIO) +} + +func TestUffdMemoryViewDirty(t *testing.T) { + t.Parallel() + + tests := []testConfig{ + { + name: "standard 4k page, operation at start, write faulted", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, operation at middle, write faulted", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, operation at end, write faulted", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 31 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, operation at start, write faulted", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, operation at middle, write faulted", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, operation at end, write faulted", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() 
+ + h, err := configureCrossProcessTest(t, tt) + require.NoError(t, err) + + writeData := testutils.RandomPages(tt.pagesize, tt.numberOfPages) + + view, err := memory.NewView(os.Getpid(), h.mapping) + require.NoError(t, err) + + for _, op := range tt.operations { + // An unprotected parallel write to map might result in an undefined behavior. + h.mutex.Lock() + + data, err := writeData.Slice(t.Context(), op.offset, int64(h.pagesize)) + require.NoError(t, err) + // We explicitly write to the memory area to make it differ from the default served content. + n := copy((*h.memoryArea)[op.offset:op.offset+int64(h.pagesize)], data) + h.mutex.Unlock() + + assert.Equal(t, int(h.pagesize), n, "copy length mismatch for operation %+v", op) + + readBytes := make([]byte, tt.pagesize) + n, err = view.ReadAt(readBytes, op.offset) + require.NoError(t, err) + assert.Len(t, readBytes, n) + + expectedBytes, err := writeData.Slice(t.Context(), op.offset, int64(tt.pagesize)) + require.NoError(t, err) + + if !bytes.Equal(expectedBytes, readBytes) { + assert.Fail(t, testutils.ErrorFromByteSlicesDifference(expectedBytes, readBytes).Error(), "for operation %+v", op) + } + } + }) + } +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_test.go similarity index 73% rename from packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_test.go index 20c8ddeeb3..cf24e55d51 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_test.go @@ -62,6 +62,33 @@ func TestMissing(t *testing.T) { }, }, }, + { + name: "standard 4k page, reads, varying offsets", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 4 * header.PageSize, + mode: 
operationModeRead, + }, + { + offset: 5 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + }, + }, { name: "hugepage, operation at start", pagesize: header.HugepageSize, @@ -110,6 +137,33 @@ func TestMissing(t *testing.T) { }, }, }, + { + name: "hugepage, reads, varying offsets", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 4 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, } for _, tt := range tests { @@ -120,15 +174,13 @@ func TestMissing(t *testing.T) { require.NoError(t, err) for _, operation := range tt.operations { - if operation.mode == operationModeRead { - err := h.executeRead(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } + err := h.executeOperation(t.Context(), operation) + assert.NoError(t, err, "for operation %+v", operation) //nolint:testifylint } expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") @@ -158,7 +210,7 @@ func TestParallelMissing(t *testing.T) { for range parallelOperations { verr.Go(func() error { - return h.executeRead(t.Context(), readOp) + return h.executeOperation(t.Context(), readOp) }) } @@ -167,7 +219,7 @@ func TestParallelMissing(t *testing.T) { expectedAccessedOffsets := 
getOperationsOffsets([]operation{readOp}, operationModeRead) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") @@ -191,14 +243,14 @@ func TestParallelMissingWithPrefault(t *testing.T) { mode: operationModeRead, } - err = h.executeRead(t.Context(), readOp) + err = h.executeOperation(t.Context(), readOp) require.NoError(t, err) var verr errgroup.Group for range parallelOperations { verr.Go(func() error { - return h.executeRead(t.Context(), readOp) + return h.executeOperation(t.Context(), readOp) }) } @@ -207,7 +259,7 @@ func TestParallelMissingWithPrefault(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") @@ -232,13 +284,13 @@ func TestSerialMissing(t *testing.T) { } for range serialOperations { - err := h.executeRead(t.Context(), readOp) + err := h.executeOperation(t.Context(), readOp) require.NoError(t, err) } expectedAccessedOffsets := getOperationsOffsets([]operation{readOp}, operationModeRead) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_write_test.go similarity index 55% rename from packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_write_test.go index 52a73e5fa2..7505dcc6d8 100644 --- 
a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/missing_write_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_missing_write_test.go @@ -10,6 +10,8 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) +// TODO: Investigate flakyness +// TODO: It is possible the hugepages trigger the automatic WP func TestMissingWrite(t *testing.T) { t.Parallel() @@ -62,6 +64,68 @@ func TestMissingWrite(t *testing.T) { }, }, }, + { + name: "standard 4k page, write after write", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + { + offset: 0, + mode: operationModeRead, + }, + }, + }, + { + name: "standard 4k page, reads after writes, varying offsets", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 4 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 5 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 2 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 4 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + }, + }, { name: "hugepage, operation at start", pagesize: header.HugepageSize, @@ -110,6 +174,68 @@ func TestMissingWrite(t *testing.T) { }, }, }, + { + name: "hugepage, write after write", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeWrite, + }, + { + offset: 0, + mode: operationModeRead, + }, + }, + }, + { + name: "hugepage, reads after writes, varying offsets", + pagesize: 
header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 4 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 4 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + }, + }, } for _, tt := range tests { @@ -120,23 +246,23 @@ func TestMissingWrite(t *testing.T) { require.NoError(t, err) for _, operation := range tt.operations { - if operation.mode == operationModeRead { - err := h.executeRead(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } - - if operation.mode == operationModeWrite { - err := h.executeWrite(t.Context(), operation) - require.NoError(t, err, "for operation %+v", operation) - } + err := h.executeOperation(t.Context(), operation) + assert.NoError(t, err, "for operation %+v", operation) //nolint:testifylint } expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets(tt.operations, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") }) } } @@ -172,10 +298,17 @@ func 
TestParallelMissingWrite(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") } func TestParallelMissingWriteWithPrefault(t *testing.T) { @@ -212,10 +345,17 @@ func TestParallelMissingWriteWithPrefault(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") } func TestSerialMissingWrite(t *testing.T) { @@ -243,8 +383,15 @@ func TestSerialMissingWrite(t *testing.T) { expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) - accessedOffsets, err := h.offsetsOnce() + accessedOffsets, err := h.accessedOffsetsOnce() require.NoError(t, err) assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages 
were dirty") } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_process_helpers_test.go similarity index 61% rename from packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go rename to packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_process_helpers_test.go index be12f71a1f..81ddfbf6ae 100644 --- a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/cross_process_helpers_test.go +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_process_helpers_test.go @@ -16,7 +16,6 @@ import ( "os/exec" "os/signal" "strconv" - "strings" "syscall" "testing" @@ -50,13 +49,15 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error uffdFd.close() }) - err = uffdFd.configureApi(tt.pagesize) + err = configureApi(uffdFd, tt.pagesize) require.NoError(t, err) err = uffdFd.register(memoryStart, uint64(size), UFFDIO_REGISTER_MODE_MISSING) require.NoError(t, err) - cmd := exec.CommandContext(t.Context(), os.Args[0], "-test.run=TestHelperServingProcess") + // We don't use t.Context() here, because we want to be able to kill the process manually and listen to the correct exit code, + // while also handling the cleanup of the uffd. The t.Context seems to trigger before the test cleanup is started. 
+ cmd := exec.CommandContext(context.Background(), os.Args[0], "-test.run=TestHelperServingProcess") cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1") cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_START=%d", memoryStart)) cmd.Env = append(cmd.Env, fmt.Sprintf("GO_MMAP_PAGE_SIZE=%d", tt.pagesize)) @@ -81,11 +82,18 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error assert.NoError(t, closeErr) }() - offsetsReader, offsetsWriter, err := os.Pipe() + accessedOffsetsReader, accessedOffsetsWriter, err := os.Pipe() require.NoError(t, err) t.Cleanup(func() { - offsetsReader.Close() + accessedOffsetsReader.Close() + }) + + dirtyOffsetsReader, dirtyOffsetsWriter, err := os.Pipe() + require.NoError(t, err) + + t.Cleanup(func() { + dirtyOffsetsReader.Close() }) readyReader, readyWriter, err := os.Pipe() @@ -106,8 +114,9 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error cmd.ExtraFiles = []*os.File{ uffdFile, contentReader, - offsetsWriter, + accessedOffsetsWriter, readyWriter, + dirtyOffsetsWriter, } cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr @@ -116,36 +125,59 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error require.NoError(t, err) contentReader.Close() - offsetsWriter.Close() + accessedOffsetsWriter.Close() readyWriter.Close() uffdFile.Close() + dirtyOffsetsWriter.Close() + + go func() { + waitErr := cmd.Wait() + assert.NoError(t, waitErr) + + assert.NotEqual(t, -1, cmd.ProcessState.ExitCode(), "process was not terminated gracefully") + assert.NotEqual(t, 2, cmd.ProcessState.ExitCode(), "fd exit prematurely terminated the serve loop") + assert.NotEqual(t, 1, cmd.ProcessState.ExitCode(), "process exited with unexpected exit code") + + assert.Equal(t, 0, cmd.ProcessState.ExitCode()) + }() t.Cleanup(func() { - signalErr := cmd.Process.Signal(syscall.SIGUSR1) + // We are using SIGHUP to actually get exit code, not -1. 
+ signalErr := cmd.Process.Signal(syscall.SIGTERM) assert.NoError(t, signalErr) - - waitErr := cmd.Wait() - // It can be either nil, an ExitError, a context.Canceled error, or "signal: killed" - assert.True(t, - (waitErr != nil && func(err error) bool { - var exitErr *exec.ExitError - - return errors.As(err, &exitErr) - }(waitErr)) || - errors.Is(waitErr, context.Canceled) || - (waitErr != nil && strings.Contains(waitErr.Error(), "signal: killed")) || - waitErr == nil, - "unexpected error: %v", waitErr, - ) }) - offsetsOnce := func() ([]uint, error) { + accessedOffsetsOnce := func() ([]uint, error) { err := cmd.Process.Signal(syscall.SIGUSR2) if err != nil { return nil, err } - offsetsBytes, err := io.ReadAll(offsetsReader) + offsetsBytes, err := io.ReadAll(accessedOffsetsReader) + if err != nil { + return nil, err + } + + var offsetList []uint + + if len(offsetsBytes)%8 != 0 { + return nil, fmt.Errorf("invalid offsets bytes length: %d", len(offsetsBytes)) + } + + for i := 0; i < len(offsetsBytes); i += 8 { + offsetList = append(offsetList, uint(binary.LittleEndian.Uint64(offsetsBytes[i:i+8]))) + } + + return offsetList, nil + } + + dirtyOffsetsOnce := func() ([]uint, error) { + err := cmd.Process.Signal(syscall.SIGUSR1) + if err != nil { + return nil, err + } + + offsetsBytes, err := io.ReadAll(dirtyOffsetsReader) if err != nil { return nil, err } @@ -169,23 +201,38 @@ func configureCrossProcessTest(t *testing.T, tt testConfig) (*testHandler, error case <-readySignal: } + mapping := memory.NewMapping([]memory.Region{ + { + BaseHostVirtAddr: memoryStart, + Size: uintptr(size), + Offset: 0, + PageSize: uintptr(tt.pagesize), + }, + }) + return &testHandler{ - memoryArea: &memoryArea, - pagesize: tt.pagesize, - data: data, - offsetsOnce: offsetsOnce, + memoryArea: &memoryArea, + pagesize: tt.pagesize, + data: data, + accessedOffsetsOnce: accessedOffsetsOnce, + mapping: mapping, + dirtyOffsetsOnce: dirtyOffsetsOnce, }, nil } -// Secondary process, orchestrator in our 
case +// Secondary process, orchestrator in our case. func TestHelperServingProcess(t *testing.T) { if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" { t.Skip("this is a helper process, skipping direct execution") } err := crossProcessServe() + if errors.Is(err, fdexit.ErrFdExit) { + os.Exit(2) + } + if err != nil { - fmt.Println("exit serving process", err) + fmt.Fprintf(os.Stderr, "error serving: %v", err) os.Exit(1) } @@ -240,29 +287,61 @@ func crossProcessServe() error { return fmt.Errorf("exit creating logger: %w", err) } - uffd, err := NewUserfaultfdFromFd(uffdFd, data, m, l) + uffd, err := NewUserfaultfdFromFd(Fd(uffdFd), data, m, l) if err != nil { return fmt.Errorf("exit creating uffd: %w", err) } - offsetsFile := os.NewFile(uintptr(5), "offsets") + accessedOffsetsFile := os.NewFile(uintptr(5), "accessed-offsets") - offsetsSignal := make(chan os.Signal, 1) - signal.Notify(offsetsSignal, syscall.SIGUSR2) - defer signal.Stop(offsetsSignal) + accessedOffsestsSignal := make(chan os.Signal, 1) + signal.Notify(accessedOffsestsSignal, syscall.SIGUSR2) + defer signal.Stop(accessedOffsestsSignal) go func() { - defer offsetsFile.Close() + defer accessedOffsetsFile.Close() for { select { case <-ctx.Done(): return - case <-offsetsSignal: + case <-accessedOffsestsSignal: + for offset := range accessed(uffd).Offsets() { + writeErr := binary.Write(accessedOffsetsFile, binary.LittleEndian, uint64(offset)) + if writeErr != nil { + msg := fmt.Errorf("error writing accessed offsets to file: %w", writeErr) + + fmt.Fprint(os.Stderr, msg.Error()) + + cancel(msg) + + return + } + } + + return + } + } + }() + + dirtyOffsetsFile := os.NewFile(uintptr(7), "dirty-offsets") + + dirtyOffsetsSignal := make(chan os.Signal, 1) + signal.Notify(dirtyOffsetsSignal, syscall.SIGUSR1) + defer signal.Stop(dirtyOffsetsSignal) + + go func() { + defer dirtyOffsetsFile.Close() + + for { + select { + case <-ctx.Done(): + return + case <-dirtyOffsetsSignal: for offset := range uffd.Dirty().Offsets() { - 
writeErr := binary.Write(offsetsFile, binary.LittleEndian, uint64(offset)) + writeErr := binary.Write(dirtyOffsetsFile, binary.LittleEndian, uint64(offset)) if writeErr != nil { - msg := fmt.Errorf("error writing offsets to file: %w", writeErr) + msg := fmt.Errorf("error writing dirty offsets to file: %w", writeErr) fmt.Fprint(os.Stderr, msg.Error()) @@ -289,6 +368,14 @@ func crossProcessServe() error { }() serverErr := uffd.Serve(ctx, fdExit) + if errors.Is(serverErr, fdexit.ErrFdExit) { + err := fmt.Errorf("serving finished via fd exit: %w", serverErr) + + cancel(err) + + return + } + if serverErr != nil { msg := fmt.Errorf("error serving: %w", serverErr) @@ -298,6 +385,8 @@ func crossProcessServe() error { return } + + fmt.Fprint(os.Stderr, "serving finished") }() cleanup := func() { @@ -318,7 +407,7 @@ func crossProcessServe() error { defer cleanup() exitSignal := make(chan os.Signal, 1) - signal.Notify(exitSignal, syscall.SIGUSR1) + signal.Notify(exitSignal, syscall.SIGTERM) defer signal.Stop(exitSignal) readyFile := os.NewFile(uintptr(6), "ready") @@ -330,7 +419,7 @@ func crossProcessServe() error { select { case <-ctx.Done(): - return fmt.Errorf("context done: %w: %w", ctx.Err(), context.Cause(ctx)) + return context.Cause(ctx) case <-exitSignal: return nil } diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go new file mode 100644 index 0000000000..16333261d3 --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_test.go @@ -0,0 +1,771 @@ +package userfaultfd + +import ( + "context" + "fmt" + "maps" + "math/rand" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/block" + "github.com/e2b-dev/infra/packages/orchestrator/internal/sandbox/uffd/memory" + 
"github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +func TestNoOperations(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(512) + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + logger, err := logger.NewDevelopmentLogger() + require.NoError(t, err) + + // Placeholder uffd that does not serve anything + uffd, err := NewUserfaultfdFromFd(newMockFd(), h.data, &memory.Mapping{}, logger) + require.NoError(t, err) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Empty(t, accessedOffsets, "checking which pages were faulted") + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Empty(t, dirtyOffsets, "checking which pages were dirty") + + dirty := uffd.Dirty() + assert.Empty(t, slices.Collect(dirty.Offsets()), "checking dirty pages") +} + +func TestRandomOperations(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(4096) + numberOfOperations := 2048 + repetitions := 8 + + for i := range repetitions { + t.Run(fmt.Sprintf("Run_%d_of_%d", i+1, repetitions), func(t *testing.T) { + t.Parallel() + + // Use time-based seed for each run to ensure different random sequences + // This increases the chance of catching bugs that only manifest with specific sequences + seed := time.Now().UnixNano() + int64(i) + rng := rand.New(rand.NewSource(seed)) + + t.Logf("Using random seed: %d", seed) + + // Randomly operations on the data + operations := make([]operation, 0, numberOfOperations) + for range numberOfOperations { + operations = append(operations, operation{ + offset: int64(rng.Intn(int(numberOfPages-1)) * int(pagesize)), + mode: operationMode(rng.Intn(2) + 1), + }) + } + + h, err := 
configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + operations: operations, + }) + require.NoError(t, err) + + for _, operation := range operations { + err := h.executeOperation(t.Context(), operation) + require.NoError(t, err, "for operation %+v", operation) + } + + expectedAccessedOffsets := getOperationsOffsets(operations, operationModeRead|operationModeWrite) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted (seed: %d)", seed) + + expectedDirtyOffsets := getOperationsOffsets(operations, operationModeWrite) + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty (seed: %d)", seed) + }) + } +} + +// Badly configured uffd panic recovery caused silent close of uffd—first operation (or parallel first operations) were always handled ok, but subsequent operations would freeze. +// This behavior was flaky with the rest of the tests, because it was racy. +func TestUffdNotClosingAfterOperation(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(4096) + + t.Run("missing write", func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + write1 := operation{ + offset: 0, + mode: operationModeWrite, + } + + // We need different offset, because kernel would cache the same page for read. + write2 := operation{ + offset: int64(2 * header.PageSize), + mode: operationModeWrite, + } + + err = h.executeOperation(t.Context(), write1) + require.NoError(t, err) + + // Use the offset helpers to settle the requests. 
+ + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{0}, accessedOffsets, "checking which pages were faulted") + + err = h.executeOperation(t.Context(), write2) + require.NoError(t, err) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{uint(2 * header.PageSize), 0}, dirtyOffsets, "checking which pages were dirty") + }) + + t.Run("missing read", func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + read1 := operation{ + offset: 0, + mode: operationModeRead, + } + + // We need different offset, because kernel would cache the same page for read. + read2 := operation{ + offset: 2 * header.PageSize, + mode: operationModeRead, + } + + err = h.executeOperation(t.Context(), read1) + require.NoError(t, err) + + // Use the offset helpers to settle the requests. + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Empty(t, dirtyOffsets, "checking which pages were dirty") + + err = h.executeOperation(t.Context(), read2) + require.NoError(t, err) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{uint(2 * header.PageSize), 0}, accessedOffsets, "checking which pages were faulted") + }) + + t.Run("write protected", func(t *testing.T) { + t.Parallel() + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + read1 := operation{ + offset: 0, + mode: operationModeRead, + } + + read2 := operation{ + offset: 1 * header.PageSize, + mode: operationModeRead, + } + + // We need at least 2 wp events to check the wp handler, so we need to write to 2 different pages. 
+ + write1 := operation{ + offset: 0, + mode: operationModeWrite, + } + + write2 := operation{ + offset: 1 * header.PageSize, + mode: operationModeWrite, + } + + err = h.executeOperation(t.Context(), read1) + require.NoError(t, err) + + err = h.executeOperation(t.Context(), read2) + require.NoError(t, err) + + // Use the offset helpers to settle the requests. + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{0, 1 * header.PageSize}, accessedOffsets, "checking which pages were faulted") + + err = h.executeOperation(t.Context(), write1) + require.NoError(t, err) + + err = h.executeOperation(t.Context(), write2) + require.NoError(t, err) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.ElementsMatch(t, []uint{0, 1 * header.PageSize}, dirtyOffsets, "checking which pages were dirty") + }) +} + +func TestUffdEvents(t *testing.T) { + pagesize := uint64(header.PageSize) + numberOfPages := uint64(32) + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + logger, err := logger.NewDevelopmentLogger() + require.NoError(t, err) + + mockFd := newMockFd() + + // Placeholder uffd that does not serve anything + uffd, err := NewUserfaultfdFromFd(mockFd, h.data, &memory.Mapping{}, logger) + require.NoError(t, err) + + events := []event{ + // Same operation and offset, repeated (copies at 0), with both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, + offset: 0, + }, + { + UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: 0, + }, + + // WriteProtect at same offset, repeated + 
{ + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, + offset: 0, + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, + offset: 0, + }, + + // Copy at next offset, include both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(header.PageSize), + }, + + // WriteProtect at next offset, repeated + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(header.PageSize), + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(header.PageSize), + }, + + // Copy at another offset, include both mode 0 and UFFDIO_COPY_MODE_WP + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: 0, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioCopy: &UffdioCopy{ + dst: 2 * header.PageSize, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, + offset: int64(2 * header.PageSize), + }, + + // WriteProtect at 
another offset, repeated + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 2 * header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(2 * header.PageSize), + }, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 2 * header.PageSize, + len: header.PageSize, + }, + }, + offset: int64(2 * header.PageSize), + }, + } + + for _, event := range events { + err := event.trigger(t.Context(), uffd) + require.NoError(t, err, "for event %+v", event) + } + + receivedEvents := make([]event, 0, len(events)) + + for range events { + select { + case copyEvent, ok := <-mockFd.copyCh: + if !ok { + t.FailNow() + } + + copyEvent.resolve() + + // We don't add the offset here, because it is propagated only through the "accessed" and "dirty" sets. + // When later comparing the events, we will compare the events without the offset. + receivedEvents = append(receivedEvents, event{UffdioCopy: ©Event.event}) + case writeProtectEvent, ok := <-mockFd.writeProtectCh: + if !ok { + t.FailNow() + } + + writeProtectEvent.resolve() + + // We don't add the offset here, because it is propagated only through the "accessed" and "dirty" sets. + // When later comparing the events, we will compare the events without the offset. 
+ receivedEvents = append(receivedEvents, event{UffdioWriteProtect: &writeProtectEvent.event}) + case <-t.Context().Done(): + t.FailNow() + } + } + + assert.Len(t, receivedEvents, len(events), "checking received events") + assert.ElementsMatch(t, zeroOffsets(events), receivedEvents, "checking received events") + + select { + case <-mockFd.copyCh: + t.Fatalf("copy channel should not have any events") + case <-mockFd.writeProtectCh: + t.Fatalf("write protect channel should not have any events") + case <-t.Context().Done(): + t.FailNow() + default: + } + + dirty := uffd.Dirty() + + expectedDirtyOffsets := make(map[int64]struct{}) + expectedAccessedOffsets := make(map[int64]struct{}) + + for _, event := range events { + if event.UffdioWriteProtect != nil { + expectedDirtyOffsets[event.offset] = struct{}{} + } + if event.UffdioCopy != nil { + if event.UffdioCopy.mode != UFFDIO_COPY_MODE_WP { + expectedDirtyOffsets[event.offset] = struct{}{} + } + + expectedAccessedOffsets[event.offset] = struct{}{} + } + } + + assert.ElementsMatch(t, slices.Collect(maps.Keys(expectedDirtyOffsets)), slices.Collect(dirty.Offsets()), "checking dirty pages") + + accessed := accessed(uffd) + assert.ElementsMatch(t, slices.Collect(maps.Keys(expectedAccessedOffsets)), slices.Collect(accessed.Offsets()), "checking accessed pages") +} + +func TestUffdSettleRequests(t *testing.T) { + t.Parallel() + + pagesize := uint64(header.PageSize) + numberOfPages := uint64(32) + + h, err := configureCrossProcessTest(t, testConfig{ + pagesize: pagesize, + numberOfPages: numberOfPages, + }) + require.NoError(t, err) + + logger, err := logger.NewDevelopmentLogger() + require.NoError(t, err) + + testEventsSettle := func(t *testing.T, events []event) { + t.Helper() + + mockFd := newMockFd() + + // Placeholder uffd that does not serve anything + uffd, err := NewUserfaultfdFromFd(mockFd, h.data, &memory.Mapping{}, logger) + require.NoError(t, err) + + for _, e := range events { + err = e.trigger(t.Context(), uffd) 
+ require.NoError(t, err, "for event %+v", e) + } + + var blockedCopyEvents []*blockedEvent[UffdioCopy] + var blockedWriteProtectEvents []*blockedEvent[UffdioWriteProtect] + + for range events { + // Wait until the event is blocked + select { + case copyEvent, ok := <-mockFd.copyCh: + if !ok { + t.FailNow() + } + + require.NotNil(t, copyEvent.event, "copy event should not be nil") + assert.Contains(t, zeroOffsets(events), event{UffdioCopy: ©Event.event}, "checking copy event") + + blockedCopyEvents = append(blockedCopyEvents, copyEvent) + case writeProtectEvent, ok := <-mockFd.writeProtectCh: + if !ok { + t.FailNow() + } + + require.NotNil(t, writeProtectEvent.event, "write protect event should not be nil") + assert.Contains(t, zeroOffsets(events), event{UffdioWriteProtect: &writeProtectEvent.event}, "checking write protect event") + + blockedWriteProtectEvents = append(blockedWriteProtectEvents, writeProtectEvent) + case <-t.Context().Done(): + t.FailNow() + } + } + + require.Len(t, events, len(blockedCopyEvents)+len(blockedWriteProtectEvents), "checking blocked events") + + simulatedFCPause := make(chan struct{}) + + d := make(chan *block.Tracker) + + go func() { + acquired := uffd.settleRequests.TryLock() + assert.False(t, acquired, "settleRequests write lock should not be acquired") + + simulatedFCPause <- struct{}{} + + // This should block, until the events are resolved. + dirty := uffd.Dirty() + + select { + case d <- dirty: + case <-t.Context().Done(): + return + } + }() + + // This would be the place where the FC API Pause would return. + <-simulatedFCPause + + // Resolve the events to unblock getting the dirty pages in the goroutine. 
+ for _, e := range blockedCopyEvents { + e.resolve() + } + + for _, e := range blockedWriteProtectEvents { + e.resolve() + } + + select { + case <-mockFd.copyCh: + t.Fatalf("copy channel should not have any events") + case <-mockFd.writeProtectCh: + t.Fatalf("write protect channel should not have any events") + case <-t.Context().Done(): + t.FailNow() + case dirty, ok := <-d: + if !ok { + t.FailNow() + } + + assert.ElementsMatch(t, dirtyOffsets(events), slices.Collect(dirty.Offsets()), "checking dirty pages") + } + } + + t.Run("missing", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: UFFDIO_COPY_MODE_WP, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("write protect", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("missing write", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + } + + testEventsSettle(t, events) + }) + + t.Run("event mix", func(t *testing.T) { + t.Parallel() + + events := []event{ + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + {UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 2 * int64(header.PageSize)}, + {UffdioCopy: &UffdioCopy{ + dst: 0, + len: header.PageSize, + mode: 0, + }, offset: 2 * int64(header.PageSize)}, + { + UffdioWriteProtect: &UffdioWriteProtect{ + _range: UffdioRange{ + start: 0, + len: header.PageSize, + }, + }, offset: 0, + }, + } + + testEventsSettle(t, events) + }) +} + +type event struct { + *UffdioCopy + *UffdioWriteProtect + + 
offset int64 +} + +func (e event) trigger(ctx context.Context, uffd *Userfaultfd) error { + switch { + case e.UffdioCopy != nil: + triggerMissing(ctx, uffd, *e.UffdioCopy, e.offset) + case e.UffdioWriteProtect != nil: + triggerWriteProtected(ctx, uffd, *e.UffdioWriteProtect, e.offset) + default: + return fmt.Errorf("invalid event: %+v", e) + } + + return nil +} + +// Return the event copy without the offset, because the offset is propagated only through the "accessed" and "dirty" sets, so direct comparisons would fail. +func (e event) withoutOffset() event { + return event{ + UffdioCopy: e.UffdioCopy, + UffdioWriteProtect: e.UffdioWriteProtect, + } +} + +// Creates a new slice of events with the offset set to 0, so we can compare the events without the offset. +func zeroOffsets(events []event) []event { + return utils.Map(events, func(e event) event { + return e.withoutOffset() + }) +} + +func triggerMissing(ctx context.Context, uffd *Userfaultfd, c UffdioCopy, offset int64) { + var write bool + + if c.mode != UFFDIO_COPY_MODE_WP { + write = true + } + + uffd.handleMissing( + ctx, + func() error { return nil }, + uintptr(c.dst), + uintptr(uffd.src.BlockSize()), + offset, + write, + ) +} + +func triggerWriteProtected(ctx context.Context, uffd *Userfaultfd, c UffdioWriteProtect, offset int64) { + uffd.handleWriteProtected( + ctx, + func() error { return nil }, + uintptr(c._range.start), + uintptr(uffd.src.BlockSize()), + offset, + ) +} + +func dirtyOffsets(events []event) []int64 { + offsets := make(map[int64]struct{}) + + for _, e := range events { + if e.UffdioWriteProtect != nil { + offsets[e.offset] = struct{}{} + } + + if e.UffdioCopy != nil { + if e.UffdioCopy.mode != UFFDIO_COPY_MODE_WP { + offsets[e.offset] = struct{}{} + } + } + } + + return slices.Collect(maps.Keys(offsets)) +} diff --git a/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_write_protection_test.go 
b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_write_protection_test.go new file mode 100644 index 0000000000..5b987fd95c --- /dev/null +++ b/packages/orchestrator/internal/sandbox/uffd/userfaultfd/userfaultfd_write_protection_test.go @@ -0,0 +1,320 @@ +package userfaultfd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +func TestWriteProtection(t *testing.T) { + t.Parallel() + + tests := []testConfig{ + { + name: "standard 4k page, operation at start", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, operation at middle", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 15 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 15 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, operation at last page", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 31 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 31 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "standard 4k page, writes after reads, varying offsets", + pagesize: header.PageSize, + numberOfPages: 32, + operations: []operation{ + { + offset: 4 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.PageSize, + mode: operationModeRead, + }, + { + offset: 4 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 5 * header.PageSize, + mode: operationModeWrite, + }, 
+ { + offset: 2 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.PageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.PageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, operation at start", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 0, + mode: operationModeRead, + }, + { + offset: 0, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, operation at middle", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 3 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 3 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, operation at last page", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 7 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 7 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + { + name: "hugepage, writes after reads, varying offsets", + pagesize: header.HugepageSize, + numberOfPages: 8, + operations: []operation{ + { + offset: 4 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeRead, + }, + { + offset: 4 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 5 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 2 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + { + offset: 0 * header.HugepageSize, + mode: operationModeWrite, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h, err := 
configureCrossProcessTest(t, tt) + require.NoError(t, err) + + for _, operation := range tt.operations { + err := h.executeOperation(t.Context(), operation) + assert.NoError(t, err, "for operation %+v", operation) //nolint:testifylint + } + + expectedAccessedOffsets := getOperationsOffsets(tt.operations, operationModeRead|operationModeWrite) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets(tt.operations, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") + }) + } +} + +func TestParallelWriteProtection(t *testing.T) { + t.Parallel() + + parallelOperations := 1_000_000 + + tt := testConfig{ + pagesize: header.PageSize, + numberOfPages: 2, + } + + h, err := configureCrossProcessTest(t, tt) + require.NoError(t, err) + + readOp := operation{ + offset: 0, + mode: operationModeRead, + } + + err = h.executeOperation(t.Context(), readOp) + require.NoError(t, err) + + writeOp := operation{ + offset: 0, + mode: operationModeWrite, + } + + var verr errgroup.Group + + for range parallelOperations { + verr.Go(func() error { + return h.executeWrite(t.Context(), writeOp) + }) + } + + err = verr.Wait() + require.NoError(t, err) + + expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") +} + +func 
TestSerialWriteProtection(t *testing.T) { + t.Parallel() + + serialOperations := 10_000 + + tt := testConfig{ + pagesize: header.PageSize, + numberOfPages: 2, + } + + h, err := configureCrossProcessTest(t, tt) + require.NoError(t, err) + + writeOp := operation{ + offset: 0, + mode: operationModeWrite, + } + + for range serialOperations { + err := h.executeWrite(t.Context(), writeOp) + require.NoError(t, err) + } + + expectedAccessedOffsets := getOperationsOffsets([]operation{writeOp}, operationModeRead|operationModeWrite) + + accessedOffsets, err := h.accessedOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedAccessedOffsets, accessedOffsets, "checking which pages were faulted") + + expectedDirtyOffsets := getOperationsOffsets([]operation{writeOp}, operationModeWrite) + + dirtyOffsets, err := h.dirtyOffsetsOnce() + require.NoError(t, err) + + assert.Equal(t, expectedDirtyOffsets, dirtyOffsets, "checking which pages were dirty") +} diff --git a/packages/shared/pkg/fc/firecracker.yml b/packages/shared/pkg/fc/firecracker.yml index 525fb1677b..7dd777317c 100644 --- a/packages/shared/pkg/fc/firecracker.yml +++ b/packages/shared/pkg/fc/firecracker.yml @@ -5,7 +5,7 @@ info: The API is accessible through HTTP calls on specific URLs carrying JSON modeled data. The transport medium is a Unix Domain Socket. - version: 1.7.0-dev + version: 1.10.1 termsOfService: "" contact: email: "compute-capsule@amazon.com" @@ -85,12 +85,12 @@ paths: Will fail if update is not possible. operationId: putBalloon parameters: - - name: body - in: body - description: Balloon properties - required: true - schema: - $ref: "#/definitions/Balloon" + - name: body + in: body + description: Balloon properties + required: true + schema: + $ref: "#/definitions/Balloon" responses: 204: description: Balloon device created/updated @@ -109,12 +109,12 @@ paths: Will fail if update is not possible. 
operationId: patchBalloon parameters: - - name: body - in: body - description: Balloon properties - required: true - schema: - $ref: "#/definitions/BalloonUpdate" + - name: body + in: body + description: Balloon properties + required: true + schema: + $ref: "#/definitions/BalloonUpdate" responses: 204: description: Balloon device updated @@ -151,12 +151,12 @@ paths: Will fail if update is not possible. operationId: patchBalloonStatsInterval parameters: - - name: body - in: body - description: Balloon properties - required: true - schema: - $ref: "#/definitions/BalloonStatsUpdate" + - name: body + in: body + description: Balloon properties + required: true + schema: + $ref: "#/definitions/BalloonStatsUpdate" responses: 204: description: Balloon statistics interval updated @@ -220,6 +220,7 @@ paths: schema: $ref: "#/definitions/Error" + /drives/{drive_id}: put: summary: Creates or updates a drive. Pre-boot only. @@ -487,7 +488,8 @@ paths: /entropy: put: summary: Creates an entropy device. Pre-boot only. - description: Enables an entropy device that provides high-quality random data to the guest. + description: + Enables an entropy device that provides high-quality random data to the guest. operationId: putEntropyDevice parameters: - name: body @@ -504,10 +506,12 @@ paths: schema: $ref: "#/definitions/Error" + /network-interfaces/{iface_id}: put: summary: Creates a network interface. Pre-boot only. - description: Creates new network interface with ID specified by iface_id path parameter. + description: + Creates new network interface with ID specified by iface_id path parameter. operationId: putGuestNetworkInterfaceByID parameters: - name: iface_id @@ -534,7 +538,8 @@ paths: $ref: "#/definitions/Error" patch: summary: Updates the rate limiters applied to a network interface. Post-boot only. - description: Updates the rate limiters applied to a network interface. + description: + Updates the rate limiters applied to a network interface. 
operationId: patchGuestNetworkInterfaceByID parameters: - name: iface_id @@ -589,7 +594,8 @@ paths: /snapshot/load: put: summary: Loads a snapshot. Pre-boot only. - description: Loads the microVM state from a snapshot. + description: + Loads the microVM state from a snapshot. Only accepted on a fresh Firecracker process (before configuring any resource other than the Logger and Metrics). operationId: loadSnapshot @@ -629,7 +635,8 @@ paths: /vm: patch: summary: Updates the microVM state. - description: Sets the desired state (Paused or Resumed) for the microVM. + description: + Sets the desired state (Paused or Resumed) for the microVM. operationId: patchVm parameters: - name: body @@ -700,7 +707,8 @@ definitions: required: - amount_mib - deflate_on_oom - description: Balloon device descriptor. + description: + Balloon device descriptor. properties: amount_mib: type: integer @@ -716,7 +724,8 @@ definitions: type: object required: - amount_mib - description: Balloon device descriptor. + description: + Balloon device descriptor. properties: amount_mib: type: integer @@ -724,7 +733,8 @@ definitions: BalloonStats: type: object - description: Describes the balloon device statistics. + description: + Describes the balloon device statistics. required: - target_pages - actual_pages @@ -788,7 +798,8 @@ definitions: type: object required: - stats_polling_interval_s - description: Update the statistics polling interval, with the first statistics update scheduled immediately. Statistics cannot be turned on/off after boot. + description: + Update the statistics polling interval, with the first statistics update scheduled immediately. Statistics cannot be turned on/off after boot. properties: stats_polling_interval_s: type: integer @@ -798,7 +809,8 @@ definitions: type: object required: - kernel_image_path - description: Boot source descriptor. + description: + Boot source descriptor. 
properties: boot_args: type: string @@ -861,18 +873,21 @@ definitions: type: boolean cache_type: type: string - description: Represents the caching strategy for the block device. + description: + Represents the caching strategy for the block device. enum: ["Unsafe", "Writeback"] default: "Unsafe" # VirtioBlock specific parameters is_read_only: type: boolean - description: Is block read only. + description: + Is block read only. This field is required for virtio-block config and should be omitted for vhost-user-block configuration. path_on_host: type: string - description: Host level path for the guest drive. + description: + Host level path for the guest drive. This field is required for virtio-block config and should be omitted for vhost-user-block configuration. rate_limiter: $ref: "#/definitions/RateLimiter" @@ -888,7 +903,8 @@ definitions: # VhostUserBlock specific parameters socket: type: string - description: Path to the socket of vhost-user-block backend. + description: + Path to the socket of vhost-user-block backend. This field is required for vhost-user-block config should be omitted for virtio-block configuration. Error: @@ -911,6 +927,8 @@ definitions: $ref: "#/definitions/Drive" boot-source: $ref: "#/definitions/BootSource" + cpu-config: + $ref: "#/definitions/CpuConfig" logger: $ref: "#/definitions/Logger" machine-config: @@ -926,10 +944,13 @@ definitions: $ref: "#/definitions/NetworkInterface" vsock: $ref: "#/definitions/Vsock" + entropy: + $ref: "#/definitions/EntropyDevice" InstanceActionInfo: type: object - description: Variant wrapper containing the real action. + description: + Variant wrapper containing the real action. required: - action_type properties: @@ -943,7 +964,8 @@ definitions: InstanceInfo: type: object - description: Describes MicroVM instance information. + description: + Describes MicroVM instance information. required: - app_name - id @@ -968,10 +990,37 @@ definitions: vmm_version: description: MicroVM hypervisor build version. 
type: string + memory_regions: + type: array + description: The regions of the guest memory. + items: + $ref: "#/definitions/GuestMemoryRegionMapping" + + GuestMemoryRegionMapping: + type: object + description: Describes the region of guest memory that can be used for creating the memfile. + required: + - base_host_virt_addr + - size + - offset + - page_size + properties: + base_host_virt_addr: + type: integer + size: + description: The size of the region in bytes. + type: integer + offset: + description: The offset of the region in bytes. + type: integer + page_size: + description: The page size of the region in pages. + type: integer Logger: type: object - description: Describes the configuration option for the logging capability. + description: + Describes the configuration option for the logging capability. properties: level: type: string @@ -1005,6 +1054,9 @@ definitions: properties: cpu_template: $ref: "#/definitions/CpuTemplate" + # gdb_socket_path: + # type: string + # description: Path to the GDB socket. Requires the gdb feature to be enabled. smt: type: boolean description: Flag for enabling/disabling simultaneous multithreading. Can be enabled only on x86. @@ -1053,7 +1105,8 @@ definitions: Metrics: type: object - description: Describes the configuration option for the metrics capability. + description: + Describes the configuration option for the metrics capability. required: - metrics_path properties: @@ -1063,7 +1116,8 @@ definitions: MmdsConfig: type: object - description: Defines the MMDS configuration. + description: + Defines the MMDS configuration. required: - network_interfaces properties: @@ -1094,11 +1148,13 @@ definitions: MmdsContentsObject: type: object - description: Describes the contents of MMDS in JSON format. + description: + Describes the contents of MMDS in JSON format. NetworkInterface: type: object - description: Defines a network interface. + description: + Defines a network interface. 
required: - host_dev_name - iface_id @@ -1124,7 +1180,8 @@ definitions: type: string path_on_host: type: string - description: Host level path for the guest drive. + description: + Host level path for the guest drive. This field is optional for virtio-block config and should be omitted for vhost-user-block configuration. rate_limiter: $ref: "#/definitions/RateLimiter" @@ -1161,7 +1218,6 @@ definitions: SnapshotCreateParams: type: object required: - - mem_file_path - snapshot_path properties: mem_file_path: @@ -1189,7 +1245,8 @@ definitions: properties: enable_diff_snapshots: type: boolean - description: Enable support for incremental (diff) snapshots by tracking dirty guest pages. + description: + Enable support for incremental (diff) snapshots by tracking dirty guest pages. mem_file_path: type: string description: @@ -1207,7 +1264,8 @@ definitions: description: Path to the file that contains the microVM state to be loaded. resume_vm: type: boolean - description: When set to true, the vm is also resumed if the snapshot load is successful. + description: + When set to true, the vm is also resumed if the snapshot load is successful. TokenBucket: type: object @@ -1242,7 +1300,8 @@ definitions: Vm: type: object - description: Defines the microVM running state. It is especially useful in the snapshotting context. + description: + Defines the microVM running state. It is especially useful in the snapshotting context. required: - state properties: @@ -1254,14 +1313,16 @@ definitions: EntropyDevice: type: object - description: Defines an entropy device. + description: + Defines an entropy device. properties: rate_limiter: $ref: "#/definitions/RateLimiter" FirecrackerVersion: type: object - description: Describes the Firecracker version. + description: + Describes the Firecracker version. 
required: - firecracker_version properties: diff --git a/packages/shared/pkg/fc/models/full_vm_configuration.go b/packages/shared/pkg/fc/models/full_vm_configuration.go index 4ae633777e..3d99a9dfd5 100644 --- a/packages/shared/pkg/fc/models/full_vm_configuration.go +++ b/packages/shared/pkg/fc/models/full_vm_configuration.go @@ -25,9 +25,15 @@ type FullVMConfiguration struct { // boot source BootSource *BootSource `json:"boot-source,omitempty"` + // cpu config + CPUConfig *CPUConfig `json:"cpu-config,omitempty"` + // Configurations for all block devices. Drives []*Drive `json:"drives"` + // entropy + Entropy *EntropyDevice `json:"entropy,omitempty"` + // logger Logger *Logger `json:"logger,omitempty"` @@ -59,10 +65,18 @@ func (m *FullVMConfiguration) Validate(formats strfmt.Registry) error { res = append(res, err) } + if err := m.validateCPUConfig(formats); err != nil { + res = append(res, err) + } + if err := m.validateDrives(formats); err != nil { res = append(res, err) } + if err := m.validateEntropy(formats); err != nil { + res = append(res, err) + } + if err := m.validateLogger(formats); err != nil { res = append(res, err) } @@ -131,6 +145,25 @@ func (m *FullVMConfiguration) validateBootSource(formats strfmt.Registry) error return nil } +func (m *FullVMConfiguration) validateCPUConfig(formats strfmt.Registry) error { + if swag.IsZero(m.CPUConfig) { // not required + return nil + } + + if m.CPUConfig != nil { + if err := m.CPUConfig.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("cpu-config") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("cpu-config") + } + return err + } + } + + return nil +} + func (m *FullVMConfiguration) validateDrives(formats strfmt.Registry) error { if swag.IsZero(m.Drives) { // not required return nil @@ -157,6 +190,25 @@ func (m *FullVMConfiguration) validateDrives(formats strfmt.Registry) error { return nil } +func (m *FullVMConfiguration) 
validateEntropy(formats strfmt.Registry) error { + if swag.IsZero(m.Entropy) { // not required + return nil + } + + if m.Entropy != nil { + if err := m.Entropy.Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("entropy") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("entropy") + } + return err + } + } + + return nil +} + func (m *FullVMConfiguration) validateLogger(formats strfmt.Registry) error { if swag.IsZero(m.Logger) { // not required return nil @@ -290,10 +342,18 @@ func (m *FullVMConfiguration) ContextValidate(ctx context.Context, formats strfm res = append(res, err) } + if err := m.contextValidateCPUConfig(ctx, formats); err != nil { + res = append(res, err) + } + if err := m.contextValidateDrives(ctx, formats); err != nil { res = append(res, err) } + if err := m.contextValidateEntropy(ctx, formats); err != nil { + res = append(res, err) + } + if err := m.contextValidateLogger(ctx, formats); err != nil { res = append(res, err) } @@ -366,6 +426,27 @@ func (m *FullVMConfiguration) contextValidateBootSource(ctx context.Context, for return nil } +func (m *FullVMConfiguration) contextValidateCPUConfig(ctx context.Context, formats strfmt.Registry) error { + + if m.CPUConfig != nil { + + if swag.IsZero(m.CPUConfig) { // not required + return nil + } + + if err := m.CPUConfig.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("cpu-config") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("cpu-config") + } + return err + } + } + + return nil +} + func (m *FullVMConfiguration) contextValidateDrives(ctx context.Context, formats strfmt.Registry) error { for i := 0; i < len(m.Drives); i++ { @@ -391,6 +472,27 @@ func (m *FullVMConfiguration) contextValidateDrives(ctx context.Context, formats return nil } +func (m *FullVMConfiguration) contextValidateEntropy(ctx context.Context, formats 
strfmt.Registry) error { + + if m.Entropy != nil { + + if swag.IsZero(m.Entropy) { // not required + return nil + } + + if err := m.Entropy.ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("entropy") + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("entropy") + } + return err + } + } + + return nil +} + func (m *FullVMConfiguration) contextValidateLogger(ctx context.Context, formats strfmt.Registry) error { if m.Logger != nil { diff --git a/packages/shared/pkg/fc/models/guest_memory_region_mapping.go b/packages/shared/pkg/fc/models/guest_memory_region_mapping.go new file mode 100644 index 0000000000..615edbfc50 --- /dev/null +++ b/packages/shared/pkg/fc/models/guest_memory_region_mapping.go @@ -0,0 +1,122 @@ +// Code generated by go-swagger; DO NOT EDIT. + +package models + +// This file was generated by the swagger tool. +// Editing this file might prove futile when you re-run the swagger generate command + +import ( + "context" + + "github.com/go-openapi/errors" + "github.com/go-openapi/strfmt" + "github.com/go-openapi/swag" + "github.com/go-openapi/validate" +) + +// GuestMemoryRegionMapping Describes the region of guest memory that can be used for creating the memfile. +// +// swagger:model GuestMemoryRegionMapping +type GuestMemoryRegionMapping struct { + + // base host virt addr + // Required: true + BaseHostVirtAddr *int64 `json:"base_host_virt_addr"` + + // The offset of the region in bytes. + // Required: true + Offset *int64 `json:"offset"` + + // The page size of the region in pages. + // Required: true + PageSize *int64 `json:"page_size"` + + // The size of the region in bytes. 
+ // Required: true + Size *int64 `json:"size"` +} + +// Validate validates this guest memory region mapping +func (m *GuestMemoryRegionMapping) Validate(formats strfmt.Registry) error { + var res []error + + if err := m.validateBaseHostVirtAddr(formats); err != nil { + res = append(res, err) + } + + if err := m.validateOffset(formats); err != nil { + res = append(res, err) + } + + if err := m.validatePageSize(formats); err != nil { + res = append(res, err) + } + + if err := m.validateSize(formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *GuestMemoryRegionMapping) validateBaseHostVirtAddr(formats strfmt.Registry) error { + + if err := validate.Required("base_host_virt_addr", "body", m.BaseHostVirtAddr); err != nil { + return err + } + + return nil +} + +func (m *GuestMemoryRegionMapping) validateOffset(formats strfmt.Registry) error { + + if err := validate.Required("offset", "body", m.Offset); err != nil { + return err + } + + return nil +} + +func (m *GuestMemoryRegionMapping) validatePageSize(formats strfmt.Registry) error { + + if err := validate.Required("page_size", "body", m.PageSize); err != nil { + return err + } + + return nil +} + +func (m *GuestMemoryRegionMapping) validateSize(formats strfmt.Registry) error { + + if err := validate.Required("size", "body", m.Size); err != nil { + return err + } + + return nil +} + +// ContextValidate validates this guest memory region mapping based on context it is used +func (m *GuestMemoryRegionMapping) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + return nil +} + +// MarshalBinary interface implementation +func (m *GuestMemoryRegionMapping) MarshalBinary() ([]byte, error) { + if m == nil { + return nil, nil + } + return swag.WriteJSON(m) +} + +// UnmarshalBinary interface implementation +func (m *GuestMemoryRegionMapping) UnmarshalBinary(b []byte) error { + var res 
GuestMemoryRegionMapping + if err := swag.ReadJSON(b, &res); err != nil { + return err + } + *m = res + return nil +} diff --git a/packages/shared/pkg/fc/models/instance_info.go b/packages/shared/pkg/fc/models/instance_info.go index 1237bd8ec1..6499a6d0df 100644 --- a/packages/shared/pkg/fc/models/instance_info.go +++ b/packages/shared/pkg/fc/models/instance_info.go @@ -8,6 +8,7 @@ package models import ( "context" "encoding/json" + "strconv" "github.com/go-openapi/errors" "github.com/go-openapi/strfmt" @@ -28,6 +29,9 @@ type InstanceInfo struct { // Required: true ID *string `json:"id"` + // The regions of the guest memory. + MemoryRegions []*GuestMemoryRegionMapping `json:"memory_regions"` + // The current detailed state (Not started, Running, Paused) of the Firecracker instance. This value is read-only for the control-plane. // Required: true // Enum: ["Not started","Running","Paused"] @@ -50,6 +54,10 @@ func (m *InstanceInfo) Validate(formats strfmt.Registry) error { res = append(res, err) } + if err := m.validateMemoryRegions(formats); err != nil { + res = append(res, err) + } + if err := m.validateState(formats); err != nil { res = append(res, err) } @@ -82,6 +90,32 @@ func (m *InstanceInfo) validateID(formats strfmt.Registry) error { return nil } +func (m *InstanceInfo) validateMemoryRegions(formats strfmt.Registry) error { + if swag.IsZero(m.MemoryRegions) { // not required + return nil + } + + for i := 0; i < len(m.MemoryRegions); i++ { + if swag.IsZero(m.MemoryRegions[i]) { // not required + continue + } + + if m.MemoryRegions[i] != nil { + if err := m.MemoryRegions[i].Validate(formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("memory_regions" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("memory_regions" + "." 
+ strconv.Itoa(i)) + } + return err + } + } + + } + + return nil +} + var instanceInfoTypeStatePropEnum []interface{} func init() { @@ -137,8 +171,42 @@ func (m *InstanceInfo) validateVmmVersion(formats strfmt.Registry) error { return nil } -// ContextValidate validates this instance info based on context it is used +// ContextValidate validate this instance info based on the context it is used func (m *InstanceInfo) ContextValidate(ctx context.Context, formats strfmt.Registry) error { + var res []error + + if err := m.contextValidateMemoryRegions(ctx, formats); err != nil { + res = append(res, err) + } + + if len(res) > 0 { + return errors.CompositeValidationError(res...) + } + return nil +} + +func (m *InstanceInfo) contextValidateMemoryRegions(ctx context.Context, formats strfmt.Registry) error { + + for i := 0; i < len(m.MemoryRegions); i++ { + + if m.MemoryRegions[i] != nil { + + if swag.IsZero(m.MemoryRegions[i]) { // not required + return nil + } + + if err := m.MemoryRegions[i].ContextValidate(ctx, formats); err != nil { + if ve, ok := err.(*errors.Validation); ok { + return ve.ValidateName("memory_regions" + "." + strconv.Itoa(i)) + } else if ce, ok := err.(*errors.CompositeError); ok { + return ce.ValidateName("memory_regions" + "." + strconv.Itoa(i)) + } + return err + } + } + + } + return nil } diff --git a/packages/shared/pkg/fc/models/snapshot_create_params.go b/packages/shared/pkg/fc/models/snapshot_create_params.go index 0f06bee0e9..d5aaeb286c 100644 --- a/packages/shared/pkg/fc/models/snapshot_create_params.go +++ b/packages/shared/pkg/fc/models/snapshot_create_params.go @@ -21,8 +21,7 @@ import ( type SnapshotCreateParams struct { // Path to the file that will contain the guest memory. - // Required: true - MemFilePath *string `json:"mem_file_path"` + MemFilePath string `json:"mem_file_path,omitempty"` // Path to the file that will contain the microVM state. 
// Required: true @@ -37,10 +36,6 @@ type SnapshotCreateParams struct { func (m *SnapshotCreateParams) Validate(formats strfmt.Registry) error { var res []error - if err := m.validateMemFilePath(formats); err != nil { - res = append(res, err) - } - if err := m.validateSnapshotPath(formats); err != nil { res = append(res, err) } @@ -55,15 +50,6 @@ func (m *SnapshotCreateParams) Validate(formats strfmt.Registry) error { return nil } -func (m *SnapshotCreateParams) validateMemFilePath(formats strfmt.Registry) error { - - if err := validate.Required("mem_file_path", "body", m.MemFilePath); err != nil { - return err - } - - return nil -} - func (m *SnapshotCreateParams) validateSnapshotPath(formats strfmt.Registry) error { if err := validate.Required("snapshot_path", "body", m.SnapshotPath); err != nil { diff --git a/packages/shared/pkg/storage/temporary_memfile.go b/packages/shared/pkg/storage/temporary_memfile.go deleted file mode 100644 index a574f227e0..0000000000 --- a/packages/shared/pkg/storage/temporary_memfile.go +++ /dev/null @@ -1,61 +0,0 @@ -package storage - -import ( - "context" - "fmt" - "os" - "path/filepath" - "sync" - - "github.com/google/uuid" - "golang.org/x/sync/semaphore" - - "github.com/e2b-dev/infra/packages/shared/pkg/env" - "github.com/e2b-dev/infra/packages/shared/pkg/utils" -) - -var maxParallelMemfileSnapshotting = utils.Must(env.GetEnvAsInt("MAX_PARALLEL_MEMFILE_SNAPSHOTTING", 8)) - -var snapshotCacheQueue = semaphore.NewWeighted(int64(maxParallelMemfileSnapshotting)) - -type TemporaryMemfile struct { - path string - closeFn func() -} - -func AcquireTmpMemfile( - ctx context.Context, - config BuilderConfig, - buildID string, -) (*TemporaryMemfile, error) { - randomID := uuid.NewString() - - err := snapshotCacheQueue.Acquire(ctx, 1) - if err != nil { - return nil, fmt.Errorf("failed to acquire cache: %w", err) - } - releaseOnce := sync.OnceFunc(func() { - snapshotCacheQueue.Release(1) - }) - - return &TemporaryMemfile{ - path: 
cacheMemfileFullSnapshotPath(config, buildID, randomID), - closeFn: releaseOnce, - }, nil -} - -func (f *TemporaryMemfile) Path() string { - return f.path -} - -func (f *TemporaryMemfile) Close() error { - defer f.closeFn() - - return os.Remove(f.path) -} - -func cacheMemfileFullSnapshotPath(config BuilderConfig, buildID string, randomID string) string { - name := fmt.Sprintf("%s-%s-%s.full", buildID, MemfileName, randomID) - - return filepath.Join(config.GetSnapshotCacheDir(), name) -}