Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion pkg/ps/types/types_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ type PS struct {
// IsCreatedFromSystemLogger is the metadata attribute that indicates if the
// process state is created from the event published by the NT kernel logger.
IsCreatedFromSystemLogger bool `json:"-"`

// modules are process modules obtained by direct invocation to the API call.
modules []sys.ProcessModule
onceMods sync.Once
}

// UUID is meant to offer a more robust version of process ID that
Expand Down Expand Up @@ -497,11 +501,15 @@ func (ps *PS) AddModule(mod Module) {
if m != nil {
return
}
ps.Lock()
defer ps.Unlock()
ps.Modules = append(ps.Modules, mod)
}

// RemoveModule removes a specified module from this process state.
func (ps *PS) RemoveModule(addr va.Address) {
ps.Lock()
defer ps.Unlock()
for i, mod := range ps.Modules {
if mod.BaseAddress == addr {
ps.Modules = append(ps.Modules[:i], ps.Modules[i+1:]...)
Expand All @@ -512,6 +520,8 @@ func (ps *PS) RemoveModule(addr va.Address) {

// FindModule finds the module by name.
func (ps *PS) FindModule(path string) *Module {
ps.RLock()
defer ps.RUnlock()
for _, mod := range ps.Modules {
if filepath.Base(mod.Name) == filepath.Base(path) {
return &mod
Expand All @@ -522,6 +532,8 @@ func (ps *PS) FindModule(path string) *Module {

// FindModuleByAddr finds the module by its base address.
func (ps *PS) FindModuleByAddr(addr va.Address) *Module {
ps.RLock()
defer ps.RUnlock()
for _, mod := range ps.Modules {
if mod.BaseAddress == addr {
return &mod
Expand All @@ -530,11 +542,57 @@ func (ps *PS) FindModuleByAddr(addr va.Address) *Module {
return nil
}

var queryLiveModules = func(pid uint32) []sys.ProcessModule {
return sys.EnumProcessModules(pid)
}

// FindModuleByVa finds the module name by
// probing the range of the given virtual address.
func (ps *PS) FindModuleByVa(addr va.Address) *Module {
mod := ps.findModuleByVa(addr)
if mod != nil {
return mod
}

ps.onceMods.Do(func() {
// query live process modules
ps.modules = queryLiveModules(ps.PID)
})

// try to find the module within the VA space
// and if found, add it to process modules for
// future lookups
for _, m := range ps.modules {
b := va.Address(m.BaseOfDll)
size := uint64(m.SizeOfImage)

if addr < b || addr >= b.Inc(size) {
continue
}

mod := Module{
Name: m.Name,
BaseAddress: b,
Size: size,
DefaultBaseAddress: b,
}

ps.Lock()
ps.Modules = append(ps.Modules, mod)
ps.Unlock()

return &mod
}

return nil
}

func (ps *PS) findModuleByVa(addr va.Address) *Module {
ps.RLock()
defer ps.RUnlock()
for _, mod := range ps.Modules {
if addr >= mod.BaseAddress && addr <= mod.BaseAddress.Inc(mod.Size) {
end := mod.BaseAddress.Inc(mod.Size)
if addr >= mod.BaseAddress && addr < end {
return &mod
}
}
Expand Down
167 changes: 167 additions & 0 deletions pkg/ps/types/types_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ package types

import (
"os"
"sync"
"testing"
"time"

"github.com/rabbitstack/fibratus/pkg/sys"
"github.com/rabbitstack/fibratus/pkg/util/bootid"
"github.com/rabbitstack/fibratus/pkg/util/va"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows"
Expand Down Expand Up @@ -130,3 +133,167 @@ func TestIsAppinfoSvc(t *testing.T) {
})
}
}

func TestFindModuleByVa(t *testing.T) {
base := va.Address(0x1000)

tests := []struct {
name string
initialModules []Module
liveModules []sys.ProcessModule
addr va.Address
expectNil bool
expectName string
expectCachedAdd bool
}{
{
name: "hit lower bound inclusive",
initialModules: []Module{
{
Name: "C:\\Windows\\System32\\ntdll.dll",
BaseAddress: base,
Size: 0x200,
},
},
addr: base,
expectName: "C:\\Windows\\System32\\ntdll.dll",
},
{
name: "hit upper bound exclusive",
initialModules: []Module{
{
Name: "C:\\Windows\\System32\\ntdll.dll",
BaseAddress: base,
Size: 0x200,
},
},
addr: base.Inc(0x200),
expectNil: true,
},
{
name: "address inside range",
initialModules: []Module{
{
Name: "C:\\Windows\\System32\\ntdll.dll",
BaseAddress: base,
Size: 0x200,
},
{
Name: "C:\\Windows\\System32\\kernel32.dll",
BaseAddress: base.Inc(10),
Size: 0x100,
},
},
addr: base.Inc(0x100),
expectName: "C:\\Windows\\System32\\ntdll.dll",
},
{
name: "miss cached but hit live modules",
liveModules: []sys.ProcessModule{
{
ModuleInfo: windows.ModuleInfo{
BaseOfDll: 0x2000,
SizeOfImage: 0x300,
},
Name: "C:\\Windows\\System32\\ntdll.dll",
},
},
addr: va.Address(0x2100),
expectName: "C:\\Windows\\System32\\ntdll.dll",
expectCachedAdd: true,
},
{
name: "miss both cached and live",
initialModules: []Module{
{
Name: "C:\\Windows\\System32\\ntdll.dll",
BaseAddress: base,
Size: 0x200,
},
},
liveModules: []sys.ProcessModule{
{
ModuleInfo: windows.ModuleInfo{
BaseOfDll: 0x3000,
SizeOfImage: 0x100,
},
Name: "C:\\Windows\\System32\\ntdll.dll",
},
},
addr: va.Address(0x9999),
expectNil: true,
},
{
name: "multiple modules choose correct one",
initialModules: []Module{
{
Name: "C:\\Windows\\System32\\ntdll.dll",
BaseAddress: va.Address(0x1000),
Size: 0x100,
},
{
Name: "C:\\Windows\\System32\\kernel32.dll",
BaseAddress: va.Address(0x2000),
Size: 0x200,
},
},
addr: va.Address(0x2100),
expectName: "C:\\Windows\\System32\\kernel32.dll",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ps := &PS{
PID: 1234,
Modules: append([]Module{}, tt.initialModules...),
}

ps.onceMods = sync.Once{}

queryLiveModules = func(_ uint32) []sys.ProcessModule {
var mods []sys.ProcessModule
for _, m := range tt.liveModules {
mods = append(mods, sys.ProcessModule{
ModuleInfo: windows.ModuleInfo{
BaseOfDll: m.BaseOfDll,
SizeOfImage: m.SizeOfImage,
},
Name: m.Name,
})
}
return mods
}

mod := ps.FindModuleByVa(tt.addr)

if tt.expectNil {
if mod != nil {
t.Fatalf("expected nil, got %+v", mod)
}
return
}

if mod == nil {
t.Fatalf("expected module %s, got nil", tt.expectName)
}

if mod.Name != tt.expectName {
t.Fatalf("expected module %s, got %s", tt.expectName, mod.Name)
}

if tt.expectCachedAdd {
found := false
for _, m := range ps.Modules {
if m.Name == tt.expectName {
found = true
break
}
}
if !found {
t.Fatalf("expected module to be cached")
}
}
})
}
}
20 changes: 0 additions & 20 deletions pkg/symbolize/symbolizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ var (
// symModulesCount counts the number of loaded module exports
symModulesCount = expvar.NewInt("symbolizer.modules.count")

// symEnumModulesHits counts the number of hits from enumerated modules
symEnumModulesHits = expvar.NewInt("symbolizer.enum.modules.hits")

// debugHelpFallbacks counts how many times we Debug Help API was called
// to resolve symbol information since we fail to do this from process
// modules and PE export directory data
Expand Down Expand Up @@ -461,23 +458,6 @@ func (s *Symbolizer) produceFrame(addr va.Address, e *event.Event) callstack.Fra
if mod == nil && ps.Parent != nil {
mod = ps.Parent.FindModuleByVa(addr)
}
if mod == nil {
// our last resort is to enumerate process modules
modules := sys.EnumProcessModules(pid)
for _, m := range modules {
b := va.Address(m.BaseOfDll)
size := uint64(m.SizeOfImage)
if addr >= b && addr <= b.Inc(size) {
mod = &pstypes.Module{
Name: m.Name,
BaseAddress: b,
Size: size,
}
symEnumModulesHits.Add(1)
break
}
}
}

if mod != nil {
frame.Module = mod.Name
Expand Down
Loading