diff --git a/.github/workflows/golang.yaml b/.github/workflows/golang.yaml index 9460adfc6..d0cc81ba3 100644 --- a/.github/workflows/golang.yaml +++ b/.github/workflows/golang.yaml @@ -14,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - go-version: [ '1.19', '1.20', '1.21.x', '1.22.x', '1.23.x', '1.24.x', '1.25.x', '1.26.x' ] + go-version: [ '1.25.x', '1.26.x' ] os: [ ubuntu-latest, macos-latest ] runs-on: ${{ matrix.os }} steps: diff --git a/cli/src/walk.rs b/cli/src/walk.rs index f338e1482..674618895 100644 --- a/cli/src/walk.rs +++ b/cli/src/walk.rs @@ -475,25 +475,25 @@ impl<'a> ParWalker<'a> { ); let t_active = start_time.elapsed(); - if let Some(limit) = cpu_limit { - if limit < 100 { - // Calculate the required sleep duration to limit - // CPU usage to the target percentage. - // - // Let T_active be the elapsed time scanning the - // file. Let T_sleep be the sleep time. The target - // utilization percentage is P. - // - // P = 100 * T_active / (T_active + T_sleep) - // P * (T_active + T_sleep) = 100 * T_active - // P * T_sleep = (100 - P) * T_active - // T_sleep = T_active * (100 - P) / P - let t_sleep = t_active.mul_f64( - (100.0 - limit as f64) / limit as f64, - ); - if !t_sleep.is_zero() { - thread::sleep(t_sleep); - } + if let Some(limit) = cpu_limit + && limit < 100 + { + // Calculate the required sleep duration to limit + // CPU usage to the target percentage. + // + // Let T_active be the elapsed time scanning the + // file. Let T_sleep be the sleep time. The target + // utilization percentage is P. 
+ // + // P = 100 * T_active / (T_active + T_sleep) + // P * (T_active + T_sleep) = 100 * T_active + // P * T_sleep = (100 - P) * T_active + // T_sleep = T_active * (100 - P) / P + let t_sleep = t_active.mul_f64( + (100.0 - limit as f64) / limit as f64, + ); + if !t_sleep.is_zero() { + thread::sleep(t_sleep); } } diff --git a/go/compiler.go b/go/compiler.go index 3467f24d3..9b61b7db3 100644 --- a/go/compiler.go +++ b/go/compiler.go @@ -2,6 +2,7 @@ package yara_x // #include import "C" + import ( "encoding/json" "errors" @@ -28,7 +29,7 @@ type CompileOption func(c *Compiler) error // // Valid value types include: int, int32, int64, bool, string, float32 and // float64. -func Globals(vars map[string]interface{}) CompileOption { +func Globals(vars map[string]any) CompileOption { return func(c *Compiler) error { for ident, value := range vars { c.vars[ident] = value @@ -289,7 +290,7 @@ type Compiler struct { includesEnabled bool ignoredModules map[string]bool bannedModules map[string]bannedModule - vars map[string]interface{} + vars map[string]any features []string includeDirs []string maxWarnings *int @@ -301,7 +302,7 @@ func NewCompiler(opts ...CompileOption) (*Compiler, error) { includesEnabled: true, ignoredModules: make(map[string]bool), bannedModules: make(map[string]bannedModule), - vars: make(map[string]interface{}), + vars: make(map[string]any), features: make([]string, 0), includeDirs: make([]string, 0), } @@ -558,7 +559,7 @@ func (c *Compiler) DefineGlobal(ident string, value interface{}) error { ret = C.int(C.yrx_compiler_define_global_float(c.cCompiler, cIdent, C.double(v))) case float64: ret = C.int(C.yrx_compiler_define_global_float(c.cCompiler, cIdent, C.double(v))) - case map[string]interface{}, []interface{}: + case map[string]any, []any: jsonStr, err := json.Marshal(v) if err != nil { return fmt.Errorf("failed to marshal '%s' to json: '%v'", ident, err) diff --git a/go/compiler_test.go b/go/compiler_test.go index 3c85322b3..9f5e44de0 100644 --- 
a/go/compiler_test.go +++ b/go/compiler_test.go @@ -3,16 +3,16 @@ package yara_x import ( "bytes" "fmt" - "io/ioutil" "os" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNamespaces(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.NewNamespace("foo") c.AddSource("rule test { condition: true }") @@ -26,7 +26,7 @@ func TestNamespaces(t *testing.T) { func TestGlobals(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) x := map[string]any{"a": map[string]any{"a": "b"}, "b": "d"} y := []any{"z"} @@ -54,7 +54,7 @@ func TestUnsupportedModules(t *testing.T) { rule test { condition: true }`, IgnoreModule("unsupported_module")) - assert.NoError(t, err) + require.NoError(t, err) scanResults, _ := r.Scan([]byte{}) assert.Len(t, scanResults.MatchingRules(), 1) } @@ -85,8 +85,8 @@ func TestDisabledIncludes(t *testing.T) { } func TestIncludes(t *testing.T) { - file, err := ioutil.TempFile("", "prefix") - assert.NoError(t, err) + file, err := os.CreateTemp(t.TempDir(), "prefix") + require.NoError(t, err) defer os.Remove(file.Name()) @@ -94,14 +94,14 @@ func TestIncludes(t *testing.T) { fmt.Sprintf(`include "%s"`, file.Name()), IncludeDir(os.TempDir())) - assert.NoError(t, err) + require.NoError(t, err) } func TestRelaxedReSyntax(t *testing.T) { r, err := Compile(` rule test { strings: $a = /\Release/ condition: $a }`, RelaxedReSyntax(true)) - assert.NoError(t, err) + require.NoError(t, err) scanResults, _ := r.Scan([]byte("Release")) assert.Len(t, scanResults.MatchingRules(), 1) } @@ -110,7 +110,7 @@ func TestConditionOptimization(t *testing.T) { _, err := Compile(` rule test { condition: true }`, ConditionOptimization(true)) - assert.NoError(t, err) + require.NoError(t, err) } func TestErrorOnSlowPattern(t *testing.T) { @@ -129,14 +129,14 @@ func TestErrorOnSlowLoop(t *testing.T) { func TestSerialization(t *testing.T) { r, err := Compile("rule test 
{ condition: true }") - assert.NoError(t, err) + require.NoError(t, err) var buf bytes.Buffer // Write rules into buffer n, err := r.WriteTo(&buf) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, buf.Bytes(), int(n)) // Read rules from buffer @@ -157,7 +157,7 @@ func TestVariables(t *testing.T) { assert.Len(t, scanResults.MatchingRules(), 1) c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.DefineGlobal("var", 1234) c.AddSource("rule test { condition: var == 1234 }") @@ -177,7 +177,7 @@ func TestVariables(t *testing.T) { c.DefineGlobal("var", false) c.AddSource("rule test { condition: var }") scanResults, _ = NewScanner(c.Build()).Scan([]byte{}) - assert.Len(t, scanResults.MatchingRules(), 0) + assert.Empty(t, scanResults.MatchingRules()) c.DefineGlobal("var", "foo") c.AddSource("rule test { condition: var == \"foo\" }") @@ -207,26 +207,26 @@ func TestCompilerFeatures(t *testing.T) { rules := `import "test_proto2" rule test { condition: test_proto2.requires_foo_and_bar }` _, err := Compile(rules) - assert.EqualError(t, err, `error[E100]: foo is required + require.EqualError(t, err, `error[E100]: foo is required --> line:1:57 | 1 | import "test_proto2" rule test { condition: test_proto2.requires_foo_and_bar } | ^^^^^^^^^^^^^^^^^^^^ this field was used without foo`) _, err = Compile(rules, WithFeature("foo")) - assert.EqualError(t, err, `error[E100]: bar is required + require.EqualError(t, err, `error[E100]: bar is required --> line:1:57 | 1 | import "test_proto2" rule test { condition: test_proto2.requires_foo_and_bar } | ^^^^^^^^^^^^^^^^^^^^ this field was used without bar`) _, err = Compile(rules, WithFeature("foo"), WithFeature("bar")) - assert.NoError(t, err) + require.NoError(t, err) } func TestErrors(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.AddSource("rule test_1 { condition: true }") assert.Equal(t, []CompileError{}, c.Errors()) @@ -291,13 +291,13 @@ func 
TestErrors(t *testing.T) { func TestRules(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.AddSource(`rule test_1 : tag1 tag2 { condition: true }`) - assert.NoError(t, err) + require.NoError(t, err) c.AddSource(`rule test_2 { meta: @@ -308,7 +308,7 @@ func TestRules(t *testing.T) { condition: true }`) - assert.NoError(t, err) + require.NoError(t, err) rules := c.Build() assert.Equal(t, 2, rules.Count()) @@ -324,7 +324,7 @@ func TestRules(t *testing.T) { assert.Equal(t, []string{"tag1", "tag2"}, slice[0].Tags()) assert.Equal(t, []string{}, slice[1].Tags()) - assert.Len(t, slice[0].Metadata(), 0) + assert.Empty(t, slice[0].Metadata()) assert.Len(t, slice[1].Metadata(), 4) assert.Equal(t, "foo", slice[1].Metadata()[0].Identifier()) @@ -342,7 +342,7 @@ func TestRules(t *testing.T) { func TestImportsIter(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.AddSource(` import "pe" @@ -351,7 +351,7 @@ func TestImportsIter(t *testing.T) { condition: true }`) - assert.NoError(t, err) + require.NoError(t, err) rules := c.Build() imports := rules.Imports() @@ -363,7 +363,7 @@ func TestImportsIter(t *testing.T) { func TestWarnings(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.AddSource("rule test { strings: $a = {01 [0-1][0-1] 02 } condition: $a }") diff --git a/go/example_test.go b/go/example_test.go index 735d65830..bf58e9960 100644 --- a/go/example_test.go +++ b/go/example_test.go @@ -50,7 +50,6 @@ func Example_compilerAndScanner() { condition: $bar }`) - if err != nil { panic(err) } diff --git a/go/go.mod b/go/go.mod index c622d42b6..1e53b916b 100644 --- a/go/go.mod +++ b/go/go.mod @@ -1,6 +1,6 @@ module github.com/VirusTotal/yara-x/go -go 1.18 +go 1.25 require ( github.com/stretchr/testify v1.8.4 diff --git a/go/go.sum b/go/go.sum index 533efc0fc..2711800e8 100644 --- a/go/go.sum +++ b/go/go.sum @@ -1,11 +1,13 @@ github.com/davecgh/go-spew 
v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/go/main.go b/go/main.go index 1d77182db..20e18baa7 100644 --- a/go/main.go +++ b/go/main.go @@ -266,7 +266,7 @@ type Pattern struct { // Metadata represents a metadata in a Rule. type Metadata struct { identifier string - value interface{} + value any } // Match contains information about the offset where a match occurred and @@ -358,7 +358,7 @@ func (m *Metadata) Identifier() string { } // Value associated to the metadata. 
-func (m *Metadata) Value() interface{} { +func (m *Metadata) Value() any { return m.value } @@ -474,7 +474,7 @@ func metadataCallback(metadata *C.YRX_METADATA, handle C.uintptr_t) { panic("matchCallback didn't receive a *[]Metadata") } - var value interface{} + var value any switch metadata.value_type { case C.YRX_I64: diff --git a/go/scanner.go b/go/scanner.go index c1926c56e..67c08ad8f 100644 --- a/go/scanner.go +++ b/go/scanner.go @@ -31,9 +31,7 @@ import ( "runtime/cgo" "time" "unsafe" -) -import ( "google.golang.org/protobuf/proto" ) @@ -163,7 +161,7 @@ func (s *Scanner) SetGlobal(ident string, value interface{}) error { ret = C.int(C.yrx_scanner_set_global_str(s.cScanner, cIdent, cValue)) case float64: ret = C.int(C.yrx_scanner_set_global_float(s.cScanner, cIdent, C.double(v))) - case map[string]interface{}, []interface{}: + case map[string]any, []any: jsonStr, err := json.Marshal(v) if err != nil { return fmt.Errorf("failed to marshal '%s' to json: '%v'", ident, err) @@ -316,7 +314,8 @@ func slowestRulesCallback( rule *C.char, patternMatchingTime C.double, conditionExecTime C.double, - handle C.uintptr_t) { + handle C.uintptr_t, +) { h := cgo.Handle(handle) profilingInfo, ok := h.Value().(*[]ProfilingInfo) if !ok { @@ -358,17 +357,17 @@ func (s *Scanner) SlowestRules(n int) []ProfilingInfo { return profilingInfo } -/// ClearProfilingData resets the profiling data collected during rule execution -/// across scanned files. Use it to start a new profiling session, ensuring the -/// results reflect only the data gathered after this method is called. +// ClearProfilingData resets the profiling data collected during rule execution +// across scanned files. Use it to start a new profiling session, ensuring the +// results reflect only the data gathered after this method is called. // // In order to use this function, the YARA-X C library must be built with // support for rules profiling by enabling the `rules-profiling` feature. 
// Otherwise, calling this function will cause a panic. func (s *Scanner) ClearProfilingData() { - if C.yrx_scanner_clear_profiling_data(s.cScanner) == C.YRX_NOT_SUPPORTED { - panic("ClearProfilingData requires that the YARA-X C library is built with the `rules-profiling` feature") - } + if C.yrx_scanner_clear_profiling_data(s.cScanner) == C.YRX_NOT_SUPPORTED { + panic("ClearProfilingData requires that the YARA-X C library is built with the `rules-profiling` feature") + } } // Destroy destroys the scanner. diff --git a/go/scanner_test.go b/go/scanner_test.go index ba8aa3286..49dd363cc 100644 --- a/go/scanner_test.go +++ b/go/scanner_test.go @@ -2,11 +2,13 @@ package yara_x import ( "bytes" - "github.com/stretchr/testify/assert" "os" "runtime" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestScanner1(t *testing.T) { @@ -18,7 +20,7 @@ func TestScanner1(t *testing.T) { assert.Len(t, matchingRules, 1) assert.Equal(t, "t", matchingRules[0].Identifier()) assert.Equal(t, "default", matchingRules[0].Namespace()) - assert.Len(t, matchingRules[0].Patterns(), 0) + assert.Empty(t, matchingRules[0].Patterns()) scanResults, _ = s.Scan(nil) matchingRules = scanResults.MatchingRules() @@ -26,7 +28,7 @@ func TestScanner1(t *testing.T) { assert.Len(t, matchingRules, 1) assert.Equal(t, "t", matchingRules[0].Identifier()) assert.Equal(t, "default", matchingRules[0].Namespace()) - assert.Len(t, matchingRules[0].Patterns(), 0) + assert.Empty(t, matchingRules[0].Patterns()) } func TestScanner2(t *testing.T) { @@ -59,7 +61,7 @@ func TestScanner3(t *testing.T) { s.SetGlobal("var_bool", false) scanResults, _ = s.Scan([]byte{}) - assert.Len(t, scanResults.MatchingRules(), 0) + assert.Empty(t, scanResults.MatchingRules()) } func TestScanner4(t *testing.T) { @@ -69,17 +71,17 @@ func TestScanner4(t *testing.T) { s := NewScanner(r) scanResults, _ := s.Scan([]byte{}) - assert.Len(t, scanResults.MatchingRules(), 0) + assert.Empty(t, 
scanResults.MatchingRules()) - assert.NoError(t, s.SetGlobal("var_int", 1)) + require.NoError(t, s.SetGlobal("var_int", 1)) scanResults, _ = s.Scan([]byte{}) assert.Len(t, scanResults.MatchingRules(), 1) - assert.NoError(t, s.SetGlobal("var_int", int32(1))) + require.NoError(t, s.SetGlobal("var_int", int32(1))) scanResults, _ = s.Scan([]byte{}) assert.Len(t, scanResults.MatchingRules(), 1) - assert.NoError(t, s.SetGlobal("var_int", int64(1))) + require.NoError(t, s.SetGlobal("var_int", int64(1))) scanResults, _ = s.Scan([]byte{}) assert.Len(t, scanResults.MatchingRules(), 1) } @@ -89,12 +91,12 @@ func TestScanFile(t *testing.T) { s := NewScanner(r) // Create a temporary file with some content - f, err := os.CreateTemp("", "example") - assert.NoError(t, err) + f, err := os.CreateTemp(t.TempDir(), "example") + require.NoError(t, err) defer os.Remove(f.Name()) _, err = f.Write([]byte("foobar")) - assert.NoError(t, err) + require.NoError(t, err) f.Close() scanResults, _ := s.ScanFile(f.Name()) diff --git a/lib/src/compiler/ir/ast2ir.rs b/lib/src/compiler/ir/ast2ir.rs index 2bfdf3ccd..bc5371f8f 100644 --- a/lib/src/compiler/ir/ast2ir.rs +++ b/lib/src/compiler/ir/ast2ir.rs @@ -31,8 +31,8 @@ use crate::compiler::ir::{ }; use crate::compiler::report::{Level, ReportBuilder}; use crate::compiler::{ - CompileContext, CompileError, FilesizeBounds, ForVars, PatternIdx, - RegexId, RegexSetId, TextPatternAsHex, warnings, + CompileContext, CompileError, FilesizeBounds, ForVars, HeaderConstraint, + PatternIdx, RegexId, RegexSetId, TextPatternAsHex, warnings, }; use crate::errors::CustomError; use crate::errors::{MethodNotAllowedInWith, PotentiallySlowLoop}; @@ -258,6 +258,7 @@ pub(in crate::compiler) fn text_pattern_from_ast<'src>( base64wide_alphabet, anchored_at: None, filesize_bounds: FilesizeBounds::default(), + header_constraints: HeaderConstraint::default(), }), }) } @@ -312,6 +313,7 @@ pub(in crate::compiler) fn hex_pattern_from_ast<'src>( flags: PatternFlags::Ascii, 
anchored_at: None, filesize_bounds: FilesizeBounds::default(), + header_constraints: HeaderConstraint::default(), }), }) } @@ -448,6 +450,7 @@ pub(in crate::compiler) fn regexp_pattern_from_ast<'src>( hir, anchored_at: None, filesize_bounds: FilesizeBounds::default(), + header_constraints: HeaderConstraint::default(), }), }) } diff --git a/lib/src/compiler/ir/mod.rs b/lib/src/compiler/ir/mod.rs index 5dadd0423..a5dde92a3 100644 --- a/lib/src/compiler/ir/mod.rs +++ b/lib/src/compiler/ir/mod.rs @@ -29,7 +29,8 @@ allows using the same regex engine for matching both types of patterns. [Hir]: regex_syntax::hir::Hir */ -use std::collections::Bound; +use std::collections::btree_map::Entry; +use std::collections::{BTreeMap, Bound}; use std::fmt::{Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::mem; @@ -51,7 +52,7 @@ use crate::compiler::ir::dfs::{ DFSIter, DFSWithScopeIter, Event, EventContext, dfs_common, }; -use crate::compiler::{FilesizeBounds, RegexSetId}; +use crate::compiler::{FilesizeBounds, HeaderConstraint, RegexSetId}; use crate::re; use crate::symbols::Symbol; use crate::types::Value::Const; @@ -310,6 +311,17 @@ impl Pattern { } } } + + pub fn set_header_constraints(&mut self, constraints: &HeaderConstraint) { + match self { + Pattern::Text(literal) => { + literal.header_constraints = constraints.clone(); + } + Pattern::Regexp(regexp) | Pattern::Hex(regexp) => { + regexp.header_constraints = constraints.clone(); + } + } + } } #[derive(Clone, Eq, Hash, PartialEq)] @@ -321,6 +333,7 @@ pub(crate) struct LiteralPattern { pub base64_alphabet: Option, pub base64wide_alphabet: Option, pub filesize_bounds: FilesizeBounds, + pub header_constraints: HeaderConstraint, } #[derive(Clone, Eq, Hash, PartialEq)] @@ -329,6 +342,7 @@ pub(crate) struct RegexpPattern { pub hir: re::hir::Hir, pub anchored_at: Option, pub filesize_bounds: FilesizeBounds, + pub header_constraints: HeaderConstraint, } /// The index of a pattern in the rule that declares it. 
@@ -992,6 +1006,251 @@ impl IR { result } + + pub fn header_constraints( + &self, + pattern_prefix_lookup: impl Fn(PatternIdx) -> Option>, + ) -> HeaderConstraint { + let mut constrained_bytes = BTreeMap::new(); + let mut unsatisfiable = false; + let mut dfs = self.dfs_iter(self.root.unwrap()); + + while let Some(evt) = dfs.next() { + let expr = match evt { + Event::Enter((_, expr, _)) => expr, + _ => continue, + }; + match expr { + Expr::Eq { lhs, rhs } => { + self.extract_header_constraints_from_eq( + *lhs, + *rhs, + &mut constrained_bytes, + &mut unsatisfiable, + ); + } + Expr::PatternMatch { pattern, anchor } => { + if let MatchAnchor::At(offset_expr) = anchor + && let Some(0) = + self.get(*offset_expr).try_as_const_integer() + && let Some(prefix_bytes) = + pattern_prefix_lookup(*pattern) + { + for (i, &b) in prefix_bytes.iter().enumerate() { + match constrained_bytes.entry(i) { + Entry::Occupied(entry) => { + if *entry.get() != b { + unsatisfiable = true; + break; + } + } + Entry::Vacant(entry) => { + entry.insert(b); + } + } + } + } + } + _ => {} + } + if unsatisfiable { + break; + } + if !matches!(expr, Expr::And { .. }) { + dfs.prune(); + } + } + + if unsatisfiable { + return HeaderConstraint::Unsatisfiable; + } + + // If the first byte in `constrained_bytes` is at offset 0, we can + // return HeaderConstraint::Constrained. + if let Some((0, _)) = constrained_bytes.first_key_value() { + HeaderConstraint::Constrained( + // Take only the bytes at consecutive offsets starting at 0. 
+ constrained_bytes + .into_iter() + .enumerate() + .map_while( + |(i, (offset, byte))| { + if i == offset { Some(byte) } else { None } + }, + ) + .collect(), + ) + } else { + HeaderConstraint::Unconstrained + } + } + + fn extract_header_constraints_from_eq( + &self, + lhs: ExprId, + rhs: ExprId, + constrained_bytes: &mut BTreeMap, + unsatisfiable: &mut bool, + ) { + if let Some(val) = self.get(rhs).try_as_const_integer() + && self.apply_int_read_constraint( + constrained_bytes, + unsatisfiable, + lhs, + val, + ) + { + return; + } + if let Some(val) = self.get(lhs).try_as_const_integer() { + self.apply_int_read_constraint( + constrained_bytes, + unsatisfiable, + rhs, + val, + ); + } + } + + fn add_constraint( + &self, + constrained_bytes: &mut BTreeMap, + unsatisfiable: &mut bool, + offset: usize, + value: u8, + ) { + if *unsatisfiable { + return; + } + match constrained_bytes.entry(offset) { + Entry::Occupied(entry) => { + if *entry.get() != value { + *unsatisfiable = true; + } + } + Entry::Vacant(entry) => { + entry.insert(value); + } + } + } + + fn apply_int_read_constraint( + &self, + constrained_bytes: &mut BTreeMap, + unsatisfiable: &mut bool, + expr_id: ExprId, + val: i64, + ) -> bool { + let func_call = match self.get(expr_id) { + Expr::FuncCall(func_call) => func_call, + _ => return false, + }; + + if let Some(offset) = func_call + .args + .first() + .and_then(|arg| self.get(*arg).try_as_const_integer()) + && offset >= 0 + { + match func_call.plain_name() { + "uint8" | "int8" | "uint8be" | "int8be" => { + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize, + val as u8, + ); + return true; + } + "uint16" | "int16" => { + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize, + (val as u16 & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 1, + ((val as u16 >> 8) & 0xff) as u8, + ); + return true; + } + "uint16be" | "int16be" => { + self.add_constraint( + 
constrained_bytes, + unsatisfiable, + offset as usize, + ((val as u16 >> 8) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 1, + (val as u16 & 0xff) as u8, + ); + return true; + } + "uint32" | "int32" => { + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize, + (val as u32 & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 1, + ((val as u32 >> 8) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 2, + ((val as u32 >> 16) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 3, + ((val as u32 >> 24) & 0xff) as u8, + ); + return true; + } + "uint32be" | "int32be" => { + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize, + ((val as u32 >> 24) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 1, + ((val as u32 >> 16) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 2, + ((val as u32 >> 8) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 3, + (val as u32 & 0xff) as u8, + ); + return true; + } + _ => {} + } + } + false + } } impl IR { @@ -2367,6 +2626,12 @@ impl FuncCall { pub fn mangled_name(&self) -> &str { self.signature().mangled_name.as_str() } + + /// Returns the plain function name, without argument or return type + /// information (i.e: everything before the `@` in the name). + pub fn plain_name(&self) -> &str { + self.signature().mangled_name.plain_name() + } } /// An `of` expression with a tuple of expressions (e.g. `1 of (true, false)`). 
diff --git a/lib/src/compiler/mod.rs b/lib/src/compiler/mod.rs index 9f49a9c06..182a41f42 100644 --- a/lib/src/compiler/mod.rs +++ b/lib/src/compiler/mod.rs @@ -337,6 +337,19 @@ pub struct Compiler<'a> { /// `FilesizeBounds{start: Bound::Unbounded, end: Bound:Excluded(1000)}`. filesize_bounds: FxHashMap, + /// Map that associates a `PatternId` to a certain constraint on the + /// file header (e.g. magic bytes at offset 0), if any. + /// + /// A condition like `uint16(0) == 0x5A4D and $a` or `$mz at 0 and $a` + /// (where $mz = "MZ") only matches if the file starts with "MZ" (0x5A4D). + /// In this case the map will contain an entry associating `$a` to a + /// `HeaderConstraint` that requires the file to start with those two + /// bytes. + /// + /// This allows skipping pattern checks entirely if the scanned data doesn't + /// start with the expected header prefix. + header_constraints: FxHashMap, + /// A vector with all the rules that has been compiled. A [`RuleId`] is /// an index in this vector. rules: Vec, @@ -502,6 +515,7 @@ impl<'a> Compiler<'a> { banned_modules: FxHashMap::default(), ignored_rules: FxHashMap::default(), filesize_bounds: FxHashMap::default(), + header_constraints: FxHashMap::default(), root_struct: Struct::new().make_root(), report_builder: ReportBuilder::new(), lit_pool: BStringPool::new(), @@ -813,6 +827,7 @@ impl<'a> Compiler<'a> { re_code: self.re_code, warnings: self.warnings.into(), filesize_bounds: self.filesize_bounds, + header_constraints: self.header_constraints, regex_sets: self.regex_sets, fast_scan_patterns: self.fast_scan_patterns, }; @@ -1668,6 +1683,18 @@ impl Compiler<'_> { // `filesize`, if any. let filesize_bounds = self.ir.filesize_bounds(); + // Analyze the condition and determine if it imposes some constraint + // on the file header (ex: `uint16(0) == 0x5a4d`). 
+ let header_constraints = self.ir.header_constraints(|pat_idx| { + let pat = &rule_patterns[pat_idx.as_usize()]; + match pat.pattern() { + Pattern::Text(lit) => Some(lit.text.as_bytes().to_vec()), + Pattern::Regexp(re) | Pattern::Hex(re) => { + re.hir.as_literal_bytes().map(|bytes| bytes.to_vec()) + } + } + }); + // Set the bounds to all patterns in the rule. This must be done // before assigning the PatternId to each pattern, as the filesize // bounds are taken into account when determining if the pattern @@ -1678,6 +1705,15 @@ impl Compiler<'_> { } } + // Set header constraints to all patterns in the rule. + if !header_constraints.unconstrained() { + for pattern in &mut rule_patterns { + pattern + .pattern_mut() + .set_header_constraints(&header_constraints); + } + } + if let Some(w) = &mut self.ir_writer { writeln!(w, "RULE {}", rule.identifier.name).unwrap(); writeln!(w, "{:?}", self.ir).unwrap(); @@ -1810,6 +1846,17 @@ impl Compiler<'_> { "modifying the file size bounds of an existing pattern" ) } + if !header_constraints.unconstrained() + && self + .header_constraints + .insert(*pattern_id, header_constraints.clone()) + .is_some() + { + // This should not happen. + panic!( + "modifying the header constraints of an existing pattern" + ) + } pending_patterns.remove(pattern_id); } } diff --git a/lib/src/compiler/rules.rs b/lib/src/compiler/rules.rs index 97232d6c0..16a6fafe1 100644 --- a/lib/src/compiler/rules.rs +++ b/lib/src/compiler/rules.rs @@ -119,6 +119,9 @@ pub struct Rules { pub(in crate::compiler) filesize_bounds: FxHashMap, + pub(in crate::compiler) header_constraints: + FxHashMap, + /// Vector that contains the [`SubPatternId`] for sub-patterns that can /// match only at a fixed offset within the scanned data. These sub-patterns /// are not added to the Aho-Corasick automaton. 
@@ -577,6 +580,14 @@ impl Rules { self.filesize_bounds.get(&pattern_id) } + #[inline] + pub(crate) fn header_constraints( + &self, + pattern_id: PatternId, + ) -> Option<&HeaderConstraint> { + self.header_constraints.get(&pattern_id) + } + #[inline] pub(crate) fn is_fast_scan(&self, pattern_id: PatternId) -> bool { *self.fast_scan_patterns.get(usize::from(pattern_id)).unwrap() @@ -851,6 +862,34 @@ impl FilesizeBounds { } } +/// Describes the requirements on the file header imposed by a rule condition. +/// +/// For example, the condition `uint32(0) == 0x464c457f` requires that the first +/// 4 bytes of the file are `0x7f, 0x45, 0x4c, 0x46`. +#[derive( + Debug, PartialEq, Serialize, Deserialize, Clone, Hash, Eq, Default, +)] +pub(crate) enum HeaderConstraint { + #[default] + Unconstrained, + Unsatisfiable, + Constrained(Vec), +} + +impl HeaderConstraint { + pub fn unconstrained(&self) -> bool { + matches!(self, Self::Unconstrained) + } + + pub fn is_satisfied(&self, data: &[u8]) -> bool { + match self { + Self::Unconstrained => true, + Self::Unsatisfiable => false, + Self::Constrained(bytes) => data.starts_with(bytes), + } + } +} + /// Represents an atom extracted from a pattern and added to the Aho-Corasick /// automata. /// diff --git a/lib/src/scanner/context.rs b/lib/src/scanner/context.rs index 7e3786502..f769cbbcb 100644 --- a/lib/src/scanner/context.rs +++ b/lib/src/scanner/context.rs @@ -171,6 +171,9 @@ pub struct ScanContext<'r, 'd> { pub(crate) console_log: Option>, /// Virtual Machines used for executing regexps. pub(crate) vm: VM<'r>, + /// Patterns that are disabled for the current scan (e.g. because they don't + /// comply with filesize bounds or header constraints). + pub(crate) disabled_patterns: FxHashSet, /// Hash map that tracks the time spend on each pattern. Keys are pattern /// PatternIds and values are the cumulative time spent on verifying each /// pattern. 
@@ -541,6 +544,9 @@ impl ScanContext<'_, '_> { // Free all runtime objects left around by previous scans. self.runtime_objects.clear(); + // Clear the set that tracks the disabled patterns. + self.disabled_patterns.clear(); + // Clear the array that tracks the patterns that reached the maximum // number of patterns. self.tracker.limit_reached.clear(); @@ -764,10 +770,8 @@ impl ScanContext<'_, '_> { &mut self, base: usize, data: &[u8], - block_scanning_mode: bool, ) -> Result<(), ScanError> { let ac = self.compiled_rules.ac_automaton(); - let filesize = self.get_filesize(); #[cfg(feature = "logging")] let mut atom_matches = 0_usize; @@ -792,8 +796,6 @@ impl ScanContext<'_, '_> { match_offset, base, data, - filesize, - block_scanning_mode, ); }); } @@ -814,8 +816,6 @@ impl ScanContext<'_, '_> { ac_match.start(), base, data, - filesize, - block_scanning_mode, ); } } @@ -834,8 +834,6 @@ impl ScanContext<'_, '_> { match_start: usize, base: usize, data: &[u8], - filesize: i64, - block_scanning_mode: bool, ) { let atoms = self.compiled_rules.atoms(); let atom = unsafe { atoms.get_unchecked(atom_idx) }; @@ -855,15 +853,11 @@ impl ScanContext<'_, '_> { let (pattern_id, sub_pattern) = &self.compiled_rules.get_sub_pattern(sub_pattern_id); - if self.tracker.limit_reached.contains(pattern_id) { + if self.disabled_patterns.contains(pattern_id) { return; } - if !block_scanning_mode - && let Some(bounds) = - self.compiled_rules.filesize_bounds(*pattern_id) - && !bounds.contains(filesize) - { + if self.tracker.limit_reached.contains(pattern_id) { return; } @@ -1073,6 +1067,31 @@ impl ScanContext<'_, '_> { _ => panic!(), }; + if !block_scanning_mode { + let filesize = self.get_filesize(); + for pattern_id in 0..self.compiled_rules.num_patterns() { + let pattern_id = PatternId::from(pattern_id); + if let Some(bounds) = + self.compiled_rules.filesize_bounds(pattern_id) + && !bounds.contains(filesize) + { + self.disabled_patterns.insert(pattern_id); + } + } + } + + if base == 0 { + for 
pattern_id in 0..self.compiled_rules.num_patterns() { + let pattern_id = PatternId::from(pattern_id); + if let Some(constraints) = + self.compiled_rules.header_constraints(pattern_id) + && !constraints.is_satisfied(data) + { + self.disabled_patterns.insert(pattern_id); + } + } + } + #[cfg(any(feature = "rules-profiling", feature = "logging"))] let scan_start = self.clock.raw(); @@ -1080,8 +1099,7 @@ impl ScanContext<'_, '_> { // match at a single known offset within the data. self.verify_anchored_patterns(base, data); - let result = match self.ac_search_loop(base, data, block_scanning_mode) - { + let result = match self.ac_search_loop(base, data) { Ok(_) => { self.scan_state = state; Ok(()) @@ -1128,6 +1146,10 @@ impl ScanContext<'_, '_> { .anchored_sub_patterns() .iter() .map(|id| (id, self.compiled_rules.get_sub_pattern(*id))) + // Disabled patterns are ignored. + .filter(|(_, (pattern_id, _))| { + !self.disabled_patterns.contains(pattern_id) + }) { match sub_pattern { SubPattern::Literal { @@ -1924,6 +1946,7 @@ pub fn create_wasm_store_and_ctx<'r>( pike_vm: PikeVM::new(rules.re_code()), fast_vm: FastVM::new(rules.re_code()), }, + disabled_patterns: FxHashSet::default(), custom_base64_engine_cache: Vec::new(), #[cfg(feature = "rules-profiling")] time_spent_in_pattern: FxHashMap::default(), diff --git a/lib/src/tests/mod.rs b/lib/src/tests/mod.rs index bc5aa2206..b5c75e1ec 100644 --- a/lib/src/tests/mod.rs +++ b/lib/src/tests/mod.rs @@ -3948,3 +3948,159 @@ fn short_circuit() { b"foobar" ); } + +#[test] +fn header_constraints_optimization() { + // ELF magic check. + rule_true!( + r#" + rule test { + strings: + $a = "ELF" + condition: + uint32(0) == 0x464c457f and $a + } + "#, + b"\x7fELF\0\0\0\0" + ); + + rule_false!( + r#" + rule test { + strings: + $a = "ELF" + condition: + uint32(0) == 0x464c457f and $a + } + "#, + b"\0\0\0\0ELF" + ); + + // PE magic check. 
+ rule_true!( + r#" + rule test { + strings: + $a = "PE" + condition: + uint16(0) == 0x5a4d and $a + } + "#, + b"MZ\0\0PE" + ); + + rule_false!( + r#" + rule test { + strings: + $a = "PE" + condition: + uint16(0) == 0x5a4d and $a + } + "#, + b"\0\0MZPE" + ); + + // Pattern at 0 check. + rule_true!( + r#" + rule test { + strings: + $a = "MZ" + condition: + $a at 0 + } + "#, + b"MZ" + ); + + rule_false!( + r#" + rule test { + strings: + $a = "MZ" + condition: + $a at 0 + } + "#, + b"\0MZ" + ); + + // Multiple constraints combined. + rule_true!( + r#" + rule test { + strings: + $a = "ELF" + condition: + uint32(0) == 0x464c457f and uint16(4) == 0x0102 and $a + } + "#, + b"\x7fELF\x02\x01" + ); + + rule_false!( + r#" + rule test { + strings: + $a = "ELF" + condition: + uint32(0) == 0x464c457f and uint16(4) == 0x0102 and $a + } + "#, + b"\x7fELF\x99\x99" + ); + + // Deduplication test: A pattern used in both a constrained and an unconstrained rule. + // When the header does not match, only the unconstrained rule should match (exactly 1 rule). + rule_true!( + r#" + rule constrained { + strings: + $a = "foo" + condition: + uint32(0) == 0x464c457f and $a + } + rule unconstrained { + strings: + $a = "foo" + condition: + $a + } + "#, + b"\0\0\0\0foo" + ); + + // When the header matches, both rules should match (exactly 2 rules). + test_rule!( + r#" + rule constrained { + strings: + $a = "foo" + condition: + uint32(0) == 0x464c457f and $a + } + rule unconstrained { + strings: + $a = "foo" + condition: + $a + } + "#, + b"\x7fELFfoo", + 2 + ); + + // Non-contiguous offsets for uint8. 
+ rule_true!( + r#" + rule constrained { + strings: + $a = "MZ" + condition: + uint8(0) == 0x00 and uint8(2) == 0x5a and $a + } + "#, + b"\0MZ" + ); +} diff --git a/lib/src/types/func.rs b/lib/src/types/func.rs index 100615790..a62444238 100644 --- a/lib/src/types/func.rs +++ b/lib/src/types/func.rs @@ -86,9 +86,13 @@ impl MangledFnName { pub fn as_str(&self) -> &str { self.0.as_str() } -} -impl MangledFnName { + /// Returns the plain function name, without argument or return type + /// information (i.e: everything before the `@` in the name). + pub fn plain_name(&self) -> &str { + self.0.as_str().split("@").next().unwrap() + } + /// Returns the types of arguments and return value for the function. pub fn unmangle(&self) -> (Vec<(&str, TypeValue)>, TypeValue) { let (_fn_name, arg_names_and_types, ret_type) =