diff --git a/pkg/callstack/callstack.go b/pkg/callstack/callstack.go index b548180d6..6d5839157 100644 --- a/pkg/callstack/callstack.go +++ b/pkg/callstack/callstack.go @@ -24,6 +24,7 @@ import ( "strconv" "strings" + "github.com/rabbitstack/fibratus/pkg/sys" "github.com/rabbitstack/fibratus/pkg/util/va" "golang.org/x/arch/x86/x86asm" "golang.org/x/sys/windows" @@ -66,23 +67,34 @@ func (f *Frame) AllocationSize(proc windows.Handle) uint64 { return 0 } + pageCount := r.Size / pageSize + m := make([]sys.MemoryWorkingSetExInformation, pageCount) + for n := range pageCount { + addr := f.Addr.Inc(n * pageSize) + m[n].VirtualAddress = addr.Uintptr() + } + + ws := va.QueryWorkingSet(proc, m) + if ws == nil { + return 0 + } + var size uint64 // traverse all pages in the region - for n := uint64(0); n < r.Size; n += pageSize { - addr := f.Addr.Inc(n) - ws := va.QueryWorkingSet(proc, addr.Uint64()) - if ws == nil || !ws.Valid() { + for _, r := range ws { + attr := r.VirtualAttributes + if !attr.Valid() { continue } // use SharedOriginal after RS3/1709 if buildNumber >= 16299 { - if !ws.SharedOriginal() { + if !attr.SharedOriginal() { size += pageSize } } else { - if !ws.Shared() { + if !attr.Shared() { size += pageSize } } diff --git a/pkg/util/va/region.go b/pkg/util/va/region.go index e6616aa01..648f10b45 100644 --- a/pkg/util/va/region.go +++ b/pkg/util/va/region.go @@ -20,12 +20,13 @@ package va import ( "expvar" - "github.com/rabbitstack/fibratus/pkg/sys" - "golang.org/x/sys/windows" - "golang.org/x/time/rate" "strconv" "sync" "unsafe" + + "github.com/rabbitstack/fibratus/pkg/sys" + "golang.org/x/sys/windows" + "golang.org/x/time/rate" ) const ( @@ -183,19 +184,6 @@ func VirtualQuery(process windows.Handle, addr uint64) *RegionInfo { } } -// QueryWorkingSet retrieves extended information about -// the pages at specific virtual addresses in the address -// space of the specified process. -func QueryWorkingSet(process windows.Handle, addr uint64) *sys.MemoryWorkingSetExBlock { - var ws sys.MemoryWorkingSetExInformation - ws.VirtualAddress = uintptr(addr) - err := sys.QueryWorkingSet(process, &ws, uint32(unsafe.Sizeof(sys.MemoryWorkingSetExInformation{}))) - if err != nil { - return nil - } - return &ws.VirtualAttributes -} - // Remove removes the process handle from cache and closes it. // It returns true if the handle was closed successfully. func (p *RegionProber) Remove(pid uint32) bool { diff --git a/pkg/util/va/region_test.go b/pkg/util/va/region_test.go index 74be473b2..a0f7a9cef 100644 --- a/pkg/util/va/region_test.go +++ b/pkg/util/va/region_test.go @@ -19,13 +19,14 @@ package va import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/sys/windows" "os" "testing" "time" "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" ) func TestRegionProber(t *testing.T) { @@ -125,19 +126,6 @@ func TestReadArea(t *testing.T) { require.True(t, Zeroed(zeroArea)) } -func TestQueryWorkingSet(t *testing.T) { - addr, err := getModuleBaseAddress(uint32(os.Getpid())) - require.NoError(t, err) - - b := QueryWorkingSet(windows.CurrentProcess(), uint64(addr)) - require.NotNil(t, b) - - require.True(t, b.Valid()) - require.False(t, b.Bad()) - require.True(t, b.SharedOriginal()) - require.True(t, (b.Win32Protection()&windows.PAGE_READONLY) != 0) -} - func getModuleBaseAddress(pid uint32) (uintptr, error) { var moduleHandles [1024]windows.Handle var cbNeeded uint32 diff --git a/pkg/util/va/ws.go b/pkg/util/va/ws.go new file mode 100644 index 000000000..eddbf54de --- /dev/null +++ b/pkg/util/va/ws.go @@ -0,0 +1,329 @@ +/* + * Copyright 2021-present by Nedim Sabic Sabic + * https://www.fibratus.io + * All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package va + +import ( + "expvar" + "runtime" + "slices" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/rabbitstack/fibratus/pkg/sys" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +var ( + workingsetOpsCount = expvar.NewInt("workingset.ops.count") + workingsetTimeoutCount = expvar.NewInt("workingset.timeout.count") +) + +// jitter dynamically adjust the wait timeout for the worker thread. +type jitter struct { + mu sync.Mutex + p50 time.Duration // median observed duration + p99 time.Duration // 99th percentile observed + samples []time.Duration + maxSamples int +} + +// ctx contains all arguments submitted to the QueryWorkingSet API. +type ctx struct { + proc windows.Handle + ws []sys.MemoryWorkingSetExInformation + size uint32 // byte size of the ws slice + done uint32 // 1 if completed successfully +} + +// worker represents a single reusable native OS thread. +type worker struct { + in windows.Handle // caller signals work is ready + out windows.Handle // thread signals work is done + quit windows.Handle // signals the callback to exit + h windows.Handle + ctx *ctx // shared work context + jitter *jitter +} + +// pool is a fixed-size pool of native OS threads to query working set regions. +type pool struct { + mu sync.Mutex + workers []*worker + free chan *worker + size int +} + +func (j *jitter) record(d time.Duration) { + j.mu.Lock() + defer j.mu.Unlock() + j.samples = append(j.samples, d) + if len(j.samples) > j.maxSamples { + j.samples = j.samples[1:] + } + // recompute p99 + sorted := append([]time.Duration{}, j.samples...) + slices.Sort(sorted) + j.p99 = sorted[int(float64(len(sorted))*0.99)] +} + +func (j *jitter) timeout() uint32 { + j.mu.Lock() + defer j.mu.Unlock() + if j.p99 == 0 { + return 100 // default before enough samples + } + // timeout = p99 * 3, clamped between 50ms and 500ms + v := j.p99 * 3 + v = max(v, 50*time.Millisecond) + v = min(v, 500*time.Millisecond) + return uint32(v.Milliseconds()) +} + +var callback = windows.NewCallback(func(param uintptr) uintptr { + w := (*worker)(unsafe.Pointer(param)) + + // wait on both in and quit to never block on INFINITE with a single event + handles := []windows.Handle{w.in, w.quit} + for { + // sleep until caller gives us work + wait, _ := windows.WaitForMultipleObjects(handles, false, windows.INFINITE) + switch wait { + case windows.WAIT_OBJECT_0: + if err := sys.QueryWorkingSet(w.ctx.proc, &w.ctx.ws[0], w.ctx.size); err == nil { + atomic.StoreUint32(&w.ctx.done, 1) + } + // signal caller that work is complete + windows.SetEvent(w.out) + case windows.WAIT_OBJECT_0 + 1: // w.quit signaled so exit cleanly + return 0 + default: + return 1 + } + } +}) + +func newPool(size int) (*pool, error) { + p := &pool{ + free: make(chan *worker, size), + size: size, + } + for i := 0; i < size; i++ { + w, err := newWorker() + if err != nil { + p.close() + return nil, err + } + p.workers = append(p.workers, w) + p.free <- w + } + return p, nil +} + +// acquire grabs a free worker, or returns nil if none available +func (p *pool) acquire(timeout time.Duration) *worker { + select { + case w := <-p.free: + return w + case <-time.After(timeout): + return nil + } +} + +// release returns a healthy worker to the pool. +// If the worker is poisoned (stalled), it is evicted and replaced. +func (p *pool) release(w *worker, poisoned bool) { + if !poisoned { + // reset events and context for reuse + windows.ResetEvent(w.in) + windows.ResetEvent(w.out) + atomic.StoreUint32(&w.ctx.done, 0) + p.free <- w + return + } + + // evict the stuck thread + w.close() + + // spawn a replacement asynchronously to keep pool size stable + go func() { + replacement, err := newWorker() + if err != nil { + // pool shrinks by 1 + log.Warnf("unable to spawn replacement worker: %v", err) + return + } + p.free <- replacement + }() +} + +func (p *pool) close() { + close(p.free) + for w := range p.free { + w.close() + } +} + +func newWorker() (*worker, error) { + w := &worker{ + ctx: &ctx{}, + jitter: &jitter{maxSamples: 1000}, + } + + var err error + w.in, err = windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return nil, err + } + w.out, err = windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + windows.CloseHandle(w.out) + return nil, err + } + w.quit, err = windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + windows.CloseHandle(w.in) + windows.CloseHandle(w.out) + return nil, err + } + + w.h = sys.CreateThread( + nil, + 0, + callback, + uintptr(unsafe.Pointer(w)), + 0, + nil) + if w.h == 0 { + w.close() + return nil, err + } + + return w, nil +} + +func (w *worker) close() { + if w.quit != 0 { + // signal the callback to exit gracefully first + windows.SetEvent(w.quit) + // give it a moment to exit before forcing termination + if w.h != 0 { + if wait, _ := windows.WaitForSingleObject(w.h, 200); wait != windows.WAIT_OBJECT_0 { + sys.TerminateThread(w.h, 0) // last resort + } + } + } + if w.h != 0 { + sys.TerminateThread(w.h, 0) + windows.CloseHandle(w.h) + } + if w.in != 0 { + windows.CloseHandle(w.in) + } + if w.out != 0 { + windows.CloseHandle(w.out) + } + if w.quit != 0 { + windows.CloseHandle(w.quit) + } +} + +func (w *worker) submit(proc windows.Handle, ws []sys.MemoryWorkingSetExInformation) { + size := uint32(len(ws)) * uint32(unsafe.Sizeof(sys.MemoryWorkingSetExInformation{})) + w.ctx.proc = proc + w.ctx.size = size + w.ctx.ws = ws + windows.SetEvent(w.in) +} + +func (w *worker) wait() (uint32, error) { + return windows.WaitForSingleObject(w.out, w.jitter.timeout()) +} + +func poolSize() int { + n := runtime.NumCPU() / 2 + if n < 2 { + return 2 + } + if n > 8 { + return 8 + } + return n +} + +var p *pool +var poolOnce sync.Once + +// QueryWorkingSet returns working set information for a set of addresses. +func QueryWorkingSet(proc windows.Handle, ws []sys.MemoryWorkingSetExInformation) []sys.MemoryWorkingSetExInformation { + poolOnce.Do(func() { + var err error + p, err = newPool(poolSize()) + if err != nil { + log.Errorf("unable to create working set pool: %v", err) + } + }) + + if p == nil { + return nil + } + + // acquire a worker and don't wait forever if pool is exhausted + w := p.acquire(50 * time.Millisecond) + if w == nil { + return nil // pool exhausted + } + + // post work to the thread + w.submit(proc, ws) + + // wait for completion + start := time.Now() + wait, err := w.wait() + if err != nil { + return nil + } + + switch wait { + case windows.WAIT_OBJECT_0: + workingsetOpsCount.Add(1) + // feed successful durations back + w.jitter.record(time.Since(start)) + defer p.release(w, false) + + if atomic.LoadUint32(&w.ctx.done) == 0 { + return nil + } + return w.ctx.ws + + case sys.WaitTimeout: + workingsetTimeoutCount.Add(1) + // thread is stalled inside the kernel syscall. + // Safe to terminate: no user-mode locks held, + // no Go runtime state to corrupt. + p.release(w, true) + return nil + + default: + p.release(w, true) + return nil + } +} diff --git a/pkg/util/va/ws_test.go b/pkg/util/va/ws_test.go new file mode 100644 index 000000000..31db5ca4b --- /dev/null +++ b/pkg/util/va/ws_test.go @@ -0,0 +1,508 @@ +//go:build windows + +package va + +import ( + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + "unsafe" + + "github.com/rabbitstack/fibratus/pkg/sys" + "golang.org/x/sys/windows" +) + +func TestJitterDefaultTimeout(t *testing.T) { + j := &jitter{maxSamples: 10} + // no samples recorded yet + got := j.timeout() + if got != 100 { + t.Fatalf("expected default timeout 100ms, got %d", got) + } +} + +func TestJitterTimeoutClampedToMinimum(t *testing.T) { + j := &jitter{maxSamples: 10} + // record very short durations to p99*3 would be below 50ms floor + for i := 0; i < 10; i++ { + j.record(1 * time.Millisecond) + } + got := j.timeout() + if got < 50 { + t.Fatalf("timeout %dms is below minimum 50ms", got) + } +} + +func TestJitterTimeoutClampedToMaximum(t *testing.T) { + j := &jitter{maxSamples: 10} + // record very long durations (p99*3 would exceed 500ms ceiling) + for i := 0; i < 10; i++ { + j.record(1 * time.Second) + } + got := j.timeout() + if got > 500 { + t.Fatalf("timeout %dms exceeds maximum 500ms", got) + } +} + +func TestJitterTimeoutScalesWithP99(t *testing.T) { + j := &jitter{maxSamples: 100} + // record 99 short samples and 1 long outlier + for i := 0; i < 99; i++ { + j.record(10 * time.Millisecond) + } + j.record(80 * time.Millisecond) // the p99 outlier + + got := j.timeout() + if got < 50 || got > 500 { + t.Fatalf("timeout %dms out of expected range [50, 500]", got) + } + // should be noticeably larger than the default 100ms + if got <= 100 { + t.Fatalf("timeout %dms should reflect the p99 outlier, expected > 100ms", got) + } +} + +func TestJitterSampleWindowIsBounded(t *testing.T) { + max := 5 + j := &jitter{maxSamples: max} + // flood with more samples than the window allows + for i := 0; i < max*3; i++ { + j.record(time.Duration(i) * time.Millisecond) + } + j.mu.Lock() + defer j.mu.Unlock() + if len(j.samples) > max { + t.Fatalf("sample buffer grew beyond maxSamples: got %d, want <= %d", len(j.samples), max) + } +} + +func TestJitterConcurrentRecordAndTimeout(t *testing.T) { + j := &jitter{maxSamples: 100} + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(2) + go func(d time.Duration) { + defer wg.Done() + j.record(d) + }(time.Duration(i) * time.Millisecond) + go func() { + defer wg.Done() + j.timeout() // must not race or panic + }() + } + wg.Wait() +} + +func TestPoolSizeMinimum(t *testing.T) { + // regardless of CPU count the pool must be at least 2 + got := poolSize() + if got < 2 { + t.Fatalf("poolSize returned %d, want >= 2", got) + } +} + +func TestPoolSizeMaximum(t *testing.T) { + got := poolSize() + if got > 8 { + t.Fatalf("poolSize returned %d, want <= 8", got) + } +} + +func TestPoolSizeIsHalfCPU(t *testing.T) { + got := poolSize() + half := runtime.NumCPU() / 2 + // clamp to [2, 8] as the implementation does + if half < 2 { + half = 2 + } + if half > 8 { + half = 8 + } + if got != half { + t.Fatalf("poolSize = %d, want %d (half of %d CPUs, clamped)", got, half, runtime.NumCPU()) + } +} + +func TestNewWorkerCreatesValidHandles(t *testing.T) { + w, err := newWorker() + if err != nil { + t.Fatalf("newWorker failed: %v", err) + } + defer w.close() + + if w.h == 0 { + t.Error("thread handle is zero") + } + if w.in == 0 { + t.Error("input event handle is zero") + } + if w.out == 0 { + t.Error("output event handle is zero") + } + if w.ctx == nil { + t.Error("ctx is nil") + } + if w.jitter == nil { + t.Error("jitter is nil") + } +} + +func TestWorkerCloseIsIdempotent(t *testing.T) { + w, err := newWorker() + if err != nil { + t.Fatalf("newWorker failed: %v", err) + } + // closing twice must not panic or crash + w.close() + w.close() +} + +func TestWorkerSubmitSetsContext(t *testing.T) { + w, err := newWorker() + if err != nil { + t.Fatalf("newWorker failed: %v", err) + } + defer w.close() + + ws := make([]sys.MemoryWorkingSetExInformation, 2) + ws[0].VirtualAddress = 0x1000 + ws[1].VirtualAddress = 0x2000 + + proc := windows.CurrentProcess() + w.submit(proc, ws) + + if w.ctx.proc != proc { + t.Error("ctx.proc not set correctly") + } + expectedSize := uint32(2) * uint32(unsafe.Sizeof(sys.MemoryWorkingSetExInformation{})) + if w.ctx.size != expectedSize { + t.Errorf("ctx.size = %d, want %d", w.ctx.size, expectedSize) + } +} + +func TestNewPoolCreatesCorrectNumberOfWorkers(t *testing.T) { + size := 3 + p, err := newPool(size) + if err != nil { + t.Fatalf("newPool failed: %v", err) + } + defer p.close() + + if len(p.workers) != size { + t.Errorf("worker count = %d, want %d", len(p.workers), size) + } + if cap(p.free) != size { + t.Errorf("free channel capacity = %d, want %d", cap(p.free), size) + } +} + +func TestPoolAcquireReturnsWorker(t *testing.T) { + p, err := newPool(2) + if err != nil { + t.Fatalf("newPool failed: %v", err) + } + defer p.close() + + w := p.acquire(50 * time.Millisecond) + if w == nil { + t.Fatal("acquire returned nil but pool has free workers") + } +} + +func TestPoolAcquireReturnsNilWhenExhausted(t *testing.T) { + p, err := newPool(2) + if err != nil { + t.Fatalf("newPool failed: %v", err) + } + defer p.close() + + // drain all workers + w1 := p.acquire(50 * time.Millisecond) + w2 := p.acquire(50 * time.Millisecond) + if w1 == nil || w2 == nil { + t.Fatal("expected to acquire both workers") + } + + // pool is now empty so must return nil within timeout + start := time.Now() + w3 := p.acquire(50 * time.Millisecond) + elapsed := time.Since(start) + + if w3 != nil { + t.Error("expected nil from exhausted pool, got a worker") + } + if elapsed > 200*time.Millisecond { + t.Errorf("acquire blocked for %v, expected ~50ms timeout", elapsed) + } +} + +func TestPoolReleaseHealthyWorkerReturnsItToPool(t *testing.T) { + p, err := newPool(2) + if err != nil { + t.Fatalf("newPool failed: %v", err) + } + defer p.close() + + w := p.acquire(50 * time.Millisecond) + if w == nil { + t.Fatal("acquire returned nil") + } + + p.release(w, false) + + // worker should be back in the pool + w2 := p.acquire(50 * time.Millisecond) + if w2 == nil { + t.Fatal("expected to re-acquire released worker") + } +} + +func TestPoolReleaseHealthyWorkerResetsContext(t *testing.T) { + p, err := newPool(1) + if err != nil { + t.Fatalf("newPool failed: %v", err) + } + defer p.close() + + w := p.acquire(50 * time.Millisecond) + if w == nil { + t.Fatal("acquire returned nil") + } + // simulate completed work + atomic.StoreUint32(&w.ctx.done, 1) + + p.release(w, false) + + // re-acquire and verify done flag is cleared + w2 := p.acquire(50 * time.Millisecond) + if w2 == nil { + t.Fatal("re-acquire failed") + } + if atomic.LoadUint32(&w2.ctx.done) != 0 { + t.Error("ctx.done was not reset on release") + } +} + +func TestPoolReleasePoisonedWorkerSpawnsReplacement(t *testing.T) { + p, err := newPool(1) + if err != nil { + t.Fatalf("newPool failed: %v", err) + } + defer p.close() + + w := p.acquire(50 * time.Millisecond) + if w == nil { + t.Fatal("acquire returned nil") + } + + // evict as poisoned + p.release(w, true) + + // wait for the replacement goroutine to put a new worker in the pool + replacement := p.acquire(2 * time.Second) + if replacement == nil { + t.Fatal("replacement worker was not added to pool after eviction") + } +} + +func TestPoolConcurrentAcquireRelease(t *testing.T) { + p, err := newPool(4) + if err != nil { + t.Fatalf("newPool failed: %v", err) + } + defer p.close() + + var wg sync.WaitGroup + var acquired atomic.Int32 + + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + defer wg.Done() + w := p.acquire(100 * time.Millisecond) + if w == nil { + return + } + acquired.Add(1) + time.Sleep(5 * time.Millisecond) // simulate work + p.release(w, false) + }() + } + + wg.Wait() + if acquired.Load() == 0 { + t.Error("no workers were ever acquired") + } +} + +func TestQueryWorkingSetCurrentProcess(t *testing.T) { + // allocate a page so we have a known committed address to query + const size = 4096 + addr, err := windows.VirtualAlloc(0, size, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + t.Fatalf("VirtualAlloc failed: %v", err) + } + defer windows.VirtualFree(addr, 0, windows.MEM_RELEASE) + + ws := []sys.MemoryWorkingSetExInformation{ + {VirtualAddress: addr}, + } + + result := QueryWorkingSet(windows.CurrentProcess(), ws) + if result == nil { + t.Fatal("QueryWorkingSet returned nil for a valid committed page") + } + if len(result) != 1 { + t.Fatalf("expected 1 result, got %d", len(result)) + } +} + +func TestQueryWorkingSetBatchAddresses(t *testing.T) { + const pageSize = 4096 + const pages = 8 + + addr, err := windows.VirtualAlloc(0, pageSize*pages, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + t.Fatalf("VirtualAlloc failed: %v", err) + } + defer windows.VirtualFree(addr, 0, windows.MEM_RELEASE) + + ws := make([]sys.MemoryWorkingSetExInformation, pages) + for i := range ws { + ws[i].VirtualAddress = addr + uintptr(i*pageSize) + } + + result := QueryWorkingSet(windows.CurrentProcess(), ws) + if result == nil { + t.Fatal("QueryWorkingSet returned nil for batch of valid pages") + } + if len(result) != pages { + t.Fatalf("expected %d results, got %d", pages, len(result)) + } +} + +func TestQueryWorkingSetNilOnNilPool(t *testing.T) { + // temporarily nil the pool to simulate init failure + saved := p + p = nil + defer func() { p = saved }() + + ws := []sys.MemoryWorkingSetExInformation{{VirtualAddress: 0x1000}} + result := QueryWorkingSet(windows.CurrentProcess(), ws) + if result != nil { + t.Error("expected nil result when pool is nil") + } +} + +func TestQueryWorkingSetPoolExhaustion(t *testing.T) { + // create a tiny pool and exhaust it + tiny, err := newPool(1) + if err != nil { + t.Fatalf("newPool failed: %v", err) + } + defer tiny.close() + + saved := p + p = tiny + defer func() { p = saved }() + + // hold the only worker + held := tiny.acquire(50 * time.Millisecond) + if held == nil { + t.Fatal("could not acquire the only worker") + } + defer tiny.release(held, false) + + // now QueryWorkingSet should fail gracefully + ws := []sys.MemoryWorkingSetExInformation{{VirtualAddress: 0x1000}} + result := QueryWorkingSet(windows.CurrentProcess(), ws) + if result != nil { + t.Error("expected nil when pool is exhausted") + } +} + +func TestQueryWorkingSetPoolSingletonInit(t *testing.T) { + // reset the singleton so poolOnce triggers again + poolOnce = sync.Once{} + p = nil + + ws := []sys.MemoryWorkingSetExInformation{{VirtualAddress: 0x1000}} + // just ensure the lazy init path doesn't panic + _ = QueryWorkingSet(windows.CurrentProcess(), ws) + + if p == nil { + t.Error("pool was not initialized by QueryWorkingSet") + } +} + +func TestQueryWorkingSetConcurrentCalls(t *testing.T) { + const pageSize = 4096 + addr, err := windows.VirtualAlloc(0, pageSize, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + t.Fatalf("VirtualAlloc failed: %v", err) + } + defer windows.VirtualFree(addr, 0, windows.MEM_RELEASE) + + var wg sync.WaitGroup + var successCount atomic.Int32 + + for i := 0; i < 30; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ws := []sys.MemoryWorkingSetExInformation{{VirtualAddress: addr}} + if QueryWorkingSet(windows.CurrentProcess(), ws) != nil { + successCount.Add(1) + } + }() + } + + wg.Wait() + // with a pool of poolSize() workers, some calls will be dropped under + // contention but at least some must succeed + if successCount.Load() == 0 { + t.Error("all concurrent QueryWorkingSet calls failed") + } +} + +func BenchmarkQueryWorkingSetSinglePage(b *testing.B) { + const pageSize = 4096 + addr, err := windows.VirtualAlloc(0, pageSize, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + b.Fatalf("VirtualAlloc failed: %v", err) + } + defer windows.VirtualFree(addr, 0, windows.MEM_RELEASE) + + ws := []sys.MemoryWorkingSetExInformation{{VirtualAddress: addr}} + proc := windows.CurrentProcess() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + QueryWorkingSet(proc, ws) + } +} + +func BenchmarkQueryWorkingSetBatch64Pages(b *testing.B) { + const pageSize = 4096 + const pages = 64 + + addr, err := windows.VirtualAlloc(0, pageSize*pages, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + b.Fatalf("VirtualAlloc failed: %v", err) + } + defer windows.VirtualFree(addr, 0, windows.MEM_RELEASE) + + ws := make([]sys.MemoryWorkingSetExInformation, pages) + for i := range ws { + ws[i].VirtualAddress = addr + uintptr(i*pageSize) + } + proc := windows.CurrentProcess() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + QueryWorkingSet(proc, ws) + } +}