From c4bb16b4274cdce5cfffcdeabc751873ddee2791 Mon Sep 17 00:00:00 2001 From: "Victor M. Alvarez" Date: Mon, 8 Jun 2026 11:24:25 +0200 Subject: [PATCH 1/5] style: fix clippy warning. --- cli/src/walk.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cli/src/walk.rs b/cli/src/walk.rs index f338e1482..e53e50e56 100644 --- a/cli/src/walk.rs +++ b/cli/src/walk.rs @@ -475,8 +475,8 @@ impl<'a> ParWalker<'a> { ); let t_active = start_time.elapsed(); - if let Some(limit) = cpu_limit { - if limit < 100 { + if let Some(limit) = cpu_limit + && limit < 100 { // Calculate the required sleep duration to limit // CPU usage to the target percentage. // @@ -495,7 +495,6 @@ impl<'a> ParWalker<'a> { thread::sleep(t_sleep); } } - } if let Err(err) = res && error(err, &msg_send).is_err() From 868069cb2d0ed820b8358d31be23fdfca05a2589 Mon Sep 17 00:00:00 2001 From: "Victor M. Alvarez" Date: Mon, 8 Jun 2026 11:27:47 +0200 Subject: [PATCH 2/5] style: run `cargo fmt`. --- cli/src/walk.rs | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/cli/src/walk.rs b/cli/src/walk.rs index e53e50e56..674618895 100644 --- a/cli/src/walk.rs +++ b/cli/src/walk.rs @@ -476,25 +476,26 @@ impl<'a> ParWalker<'a> { let t_active = start_time.elapsed(); if let Some(limit) = cpu_limit - && limit < 100 { - // Calculate the required sleep duration to limit - // CPU usage to the target percentage. - // - // Let T_active be the elapsed time scanning the - // file. Let T_sleep be the sleep time. The target - // utilization percentage is P. - // - // P = 100 * T_active / (T_active + T_sleep) - // P * (T_active + T_sleep) = 100 * T_active - // P * T_sleep = (100 - P) * T_active - // T_sleep = T_active * (100 - P) / P - let t_sleep = t_active.mul_f64( - (100.0 - limit as f64) / limit as f64, - ); - if !t_sleep.is_zero() { - thread::sleep(t_sleep); - } + && limit < 100 + { + // Calculate the required sleep duration to limit + // CPU usage to the target percentage. + // + // Let T_active be the elapsed time scanning the + // file. Let T_sleep be the sleep time. The target + // utilization percentage is P. + // + // P = 100 * T_active / (T_active + T_sleep) + // P * (T_active + T_sleep) = 100 * T_active + // P * T_sleep = (100 - P) * T_active + // T_sleep = T_active * (100 - P) / P + let t_sleep = t_active.mul_f64( + (100.0 - limit as f64) / limit as f64, + ); + if !t_sleep.is_zero() { + thread::sleep(t_sleep); } + } if let Err(err) = res && error(err, &msg_send).is_err() From ca7b8e7901aa9aa96bc08efa274d8779a36918e4 Mon Sep 17 00:00:00 2001 From: zdiff Date: Mon, 8 Jun 2026 05:40:45 -0400 Subject: [PATCH 3/5] chore: drop support for Go versions below 1.25 (#675) This change may be overreaching and I understand if it is closed. Each major Go release is supported until there are two newer major releases. Go 1.24 support ended on February 4th, 2026. --- .github/workflows/golang.yaml | 2 +- go/go.mod | 2 +- go/go.sum | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/golang.yaml b/.github/workflows/golang.yaml index 9460adfc6..d0cc81ba3 100644 --- a/.github/workflows/golang.yaml +++ b/.github/workflows/golang.yaml @@ -14,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - go-version: [ '1.19', '1.20', '1.21.x', '1.22.x', '1.23.x', '1.24.x', '1.25.x', '1.26.x' ] + go-version: [ '1.25.x', '1.26.x' ] os: [ ubuntu-latest, macos-latest ] runs-on: ${{ matrix.os }} steps: diff --git a/go/go.mod b/go/go.mod index c622d42b6..1e53b916b 100644 --- a/go/go.mod +++ b/go/go.mod @@ -1,6 +1,6 @@ module github.com/VirusTotal/yara-x/go -go 1.18 +go 1.25 require ( github.com/stretchr/testify v1.8.4 diff --git a/go/go.sum b/go/go.sum index 533efc0fc..2711800e8 100644 --- a/go/go.sum +++ b/go/go.sum @@ -1,11 +1,13 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= From b0d3360bd1bc72eb6c9e5c9146240a99dd108c78 Mon Sep 17 00:00:00 2001 From: "Victor M. Alvarez" Date: Mon, 8 Jun 2026 11:54:38 +0200 Subject: [PATCH 4/5] perf: optimize scans using header constraints (#676) This change introduces a mechanism to detect and leverage "header constraints" derived from YARA rule conditions. These constraints, such as `uint32(0) == 0x464c457f` or `$a at 0`, specify required byte sequences or integer values at fixed offsets from the start of the file. During scanning, if the initial bytes of the data do not satisfy a pattern's header constraints, that pattern is disabled before any detailed matching attempts. This early pruning reduces redundant work and improves scan performance, particularly for files with well-defined magic bytes or headers that don't match specific rule conditions. The existing filesize bounds checks are also refactored to use this new pattern disabling mechanism for consistency and clearer logic. --- lib/src/compiler/ir/ast2ir.rs | 7 +- lib/src/compiler/ir/mod.rs | 269 +++++++++++++++++++++++++++++++++- lib/src/compiler/mod.rs | 47 ++++++ lib/src/compiler/rules.rs | 39 +++++ lib/src/scanner/context.rs | 55 +++++-- lib/src/tests/mod.rs | 156 ++++++++++++++++++++ lib/src/types/func.rs | 8 +- 7 files changed, 559 insertions(+), 22 deletions(-) diff --git a/lib/src/compiler/ir/ast2ir.rs b/lib/src/compiler/ir/ast2ir.rs index 2bfdf3ccd..bc5371f8f 100644 --- a/lib/src/compiler/ir/ast2ir.rs +++ b/lib/src/compiler/ir/ast2ir.rs @@ -31,8 +31,8 @@ use crate::compiler::ir::{ }; use crate::compiler::report::{Level, ReportBuilder}; use crate::compiler::{ - CompileContext, CompileError, FilesizeBounds, ForVars, PatternIdx, - RegexId, RegexSetId, TextPatternAsHex, warnings, + CompileContext, CompileError, FilesizeBounds, ForVars, HeaderConstraint, + PatternIdx, RegexId, RegexSetId, TextPatternAsHex, warnings, }; use crate::errors::CustomError; use crate::errors::{MethodNotAllowedInWith, PotentiallySlowLoop}; @@ -258,6 +258,7 @@ pub(in crate::compiler) fn text_pattern_from_ast<'src>( base64wide_alphabet, anchored_at: None, filesize_bounds: FilesizeBounds::default(), + header_constraints: HeaderConstraint::default(), }), }) } @@ -312,6 +313,7 @@ pub(in crate::compiler) fn hex_pattern_from_ast<'src>( flags: PatternFlags::Ascii, anchored_at: None, filesize_bounds: FilesizeBounds::default(), + header_constraints: HeaderConstraint::default(), }), }) } @@ -448,6 +450,7 @@ pub(in crate::compiler) fn regexp_pattern_from_ast<'src>( hir, anchored_at: None, filesize_bounds: FilesizeBounds::default(), + header_constraints: HeaderConstraint::default(), }), }) } diff --git a/lib/src/compiler/ir/mod.rs b/lib/src/compiler/ir/mod.rs index 5dadd0423..a5dde92a3 100644 --- a/lib/src/compiler/ir/mod.rs +++ b/lib/src/compiler/ir/mod.rs @@ -29,7 +29,8 @@ allows using the same regex engine for matching both types of patterns. [Hir]: regex_syntax::hir::Hir */ -use std::collections::Bound; +use std::collections::btree_map::Entry; +use std::collections::{BTreeMap, Bound}; use std::fmt::{Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::mem; @@ -51,7 +52,7 @@ use crate::compiler::ir::dfs::{ DFSIter, DFSWithScopeIter, Event, EventContext, dfs_common, }; -use crate::compiler::{FilesizeBounds, RegexSetId}; +use crate::compiler::{FilesizeBounds, HeaderConstraint, RegexSetId}; use crate::re; use crate::symbols::Symbol; use crate::types::Value::Const; @@ -310,6 +311,17 @@ impl Pattern { } } } + + pub fn set_header_constraints(&mut self, constraints: &HeaderConstraint) { + match self { + Pattern::Text(literal) => { + literal.header_constraints = constraints.clone(); + } + Pattern::Regexp(regexp) | Pattern::Hex(regexp) => { + regexp.header_constraints = constraints.clone(); + } + } + } } #[derive(Clone, Eq, Hash, PartialEq)] @@ -321,6 +333,7 @@ pub(crate) struct LiteralPattern { pub base64_alphabet: Option, pub base64wide_alphabet: Option, pub filesize_bounds: FilesizeBounds, + pub header_constraints: HeaderConstraint, } #[derive(Clone, Eq, Hash, PartialEq)] @@ -329,6 +342,7 @@ pub(crate) struct RegexpPattern { pub hir: re::hir::Hir, pub anchored_at: Option, pub filesize_bounds: FilesizeBounds, + pub header_constraints: HeaderConstraint, } /// The index of a pattern in the rule that declares it. @@ -992,6 +1006,251 @@ impl IR { result } + + pub fn header_constraints( + &self, + pattern_prefix_lookup: impl Fn(PatternIdx) -> Option>, + ) -> HeaderConstraint { + let mut constrained_bytes = BTreeMap::new(); + let mut unsatisfiable = false; + let mut dfs = self.dfs_iter(self.root.unwrap()); + + while let Some(evt) = dfs.next() { + let expr = match evt { + Event::Enter((_, expr, _)) => expr, + _ => continue, + }; + match expr { + Expr::Eq { lhs, rhs } => { + self.extract_header_constraints_from_eq( + *lhs, + *rhs, + &mut constrained_bytes, + &mut unsatisfiable, + ); + } + Expr::PatternMatch { pattern, anchor } => { + if let MatchAnchor::At(offset_expr) = anchor + && let Some(0) = + self.get(*offset_expr).try_as_const_integer() + && let Some(prefix_bytes) = + pattern_prefix_lookup(*pattern) + { + for (i, &b) in prefix_bytes.iter().enumerate() { + match constrained_bytes.entry(i) { + Entry::Occupied(entry) => { + if *entry.get() != b { + unsatisfiable = true; + break; + } + } + Entry::Vacant(entry) => { + entry.insert(b); + } + } + } + } + } + _ => {} + } + if unsatisfiable { + break; + } + if !matches!(expr, Expr::And { .. }) { + dfs.prune(); + } + } + + if unsatisfiable { + return HeaderConstraint::Unsatisfiable; + } + + // If the first byte in `constrained_bytes` is at offset 0, we can + // return HeaderConstraint::Constrained. + if let Some((0, _)) = constrained_bytes.first_key_value() { + HeaderConstraint::Constrained( + // Take only the bytes at consecutive offsets starting at 0. + constrained_bytes + .into_iter() + .enumerate() + .map_while( + |(i, (offset, byte))| { + if i == offset { Some(byte) } else { None } + }, + ) + .collect(), + ) + } else { + HeaderConstraint::Unconstrained + } + } + + fn extract_header_constraints_from_eq( + &self, + lhs: ExprId, + rhs: ExprId, + constrained_bytes: &mut BTreeMap, + unsatisfiable: &mut bool, + ) { + if let Some(val) = self.get(rhs).try_as_const_integer() + && self.apply_int_read_constraint( + constrained_bytes, + unsatisfiable, + lhs, + val, + ) + { + return; + } + if let Some(val) = self.get(lhs).try_as_const_integer() { + self.apply_int_read_constraint( + constrained_bytes, + unsatisfiable, + rhs, + val, + ); + } + } + + fn add_constraint( + &self, + constrained_bytes: &mut BTreeMap, + unsatisfiable: &mut bool, + offset: usize, + value: u8, + ) { + if *unsatisfiable { + return; + } + match constrained_bytes.entry(offset) { + Entry::Occupied(entry) => { + if *entry.get() != value { + *unsatisfiable = true; + } + } + Entry::Vacant(entry) => { + entry.insert(value); + } + } + } + + fn apply_int_read_constraint( + &self, + constrained_bytes: &mut BTreeMap, + unsatisfiable: &mut bool, + expr_id: ExprId, + val: i64, + ) -> bool { + let func_call = match self.get(expr_id) { + Expr::FuncCall(func_call) => func_call, + _ => return false, + }; + + if let Some(offset) = func_call + .args + .first() + .and_then(|arg| self.get(*arg).try_as_const_integer()) + && offset >= 0 + { + match func_call.plain_name() { + "uint8" | "int8" | "uint8be" | "int8be" => { + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize, + val as u8, + ); + return true; + } + "uint16" | "int16" => { + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize, + (val as u16 & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 1, + ((val as u16 >> 8) & 0xff) as u8, + ); + return true; + } + "uint16be" | "int16be" => { + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize, + ((val as u16 >> 8) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 1, + (val as u16 & 0xff) as u8, + ); + return true; + } + "uint32" | "int32" => { + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize, + (val as u32 & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 1, + ((val as u32 >> 8) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 2, + ((val as u32 >> 16) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 3, + ((val as u32 >> 24) & 0xff) as u8, + ); + return true; + } + "uint32be" | "int32be" => { + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize, + ((val as u32 >> 24) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 1, + ((val as u32 >> 16) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 2, + ((val as u32 >> 8) & 0xff) as u8, + ); + self.add_constraint( + constrained_bytes, + unsatisfiable, + offset as usize + 3, + (val as u32 & 0xff) as u8, + ); + return true; + } + _ => {} + } + } + false + } } impl IR { @@ -2367,6 +2626,12 @@ impl FuncCall { pub fn mangled_name(&self) -> &str { self.signature().mangled_name.as_str() } + + /// Returns the plain function name, without argument or return type + /// information (i.e: everything before the `@` in the name). + pub fn plain_name(&self) -> &str { + self.signature().mangled_name.plain_name() + } } /// An `of` expression with a tuple of expressions (e.g. `1 of (true, false)`). diff --git a/lib/src/compiler/mod.rs b/lib/src/compiler/mod.rs index 9f49a9c06..182a41f42 100644 --- a/lib/src/compiler/mod.rs +++ b/lib/src/compiler/mod.rs @@ -337,6 +337,19 @@ pub struct Compiler<'a> { /// `FilesizeBounds{start: Bound::Unbounded, end: Bound:Excluded(1000)}`. filesize_bounds: FxHashMap, + /// Map that associates a `PatternId` to a certain constraint on the + /// file header (e.g. magic bytes at offset 0), if any. + /// + /// A condition like `uint16(0) == 0x5A4D and $a` or `$mz at 0 and $a` + /// (were $mz = "MZ") only matches if the file starts with "MZ" (0x5A4D). + /// In this case the map will contain an entry associating `$a` to a + /// `HeaderConstraint` that requires the file to start with those two + /// bytes. + /// + /// This allows skipping pattern checks entirely if the scanned data doesn't + /// start with the expected header prefix. + header_constraints: FxHashMap, + /// A vector with all the rules that has been compiled. A [`RuleId`] is /// an index in this vector. rules: Vec, @@ -502,6 +515,7 @@ impl<'a> Compiler<'a> { banned_modules: FxHashMap::default(), ignored_rules: FxHashMap::default(), filesize_bounds: FxHashMap::default(), + header_constraints: FxHashMap::default(), root_struct: Struct::new().make_root(), report_builder: ReportBuilder::new(), lit_pool: BStringPool::new(), @@ -813,6 +827,7 @@ impl<'a> Compiler<'a> { re_code: self.re_code, warnings: self.warnings.into(), filesize_bounds: self.filesize_bounds, + header_constraints: self.header_constraints, regex_sets: self.regex_sets, fast_scan_patterns: self.fast_scan_patterns, }; @@ -1668,6 +1683,18 @@ impl Compiler<'_> { // `filesize`, if any. let filesize_bounds = self.ir.filesize_bounds(); + // Analyze the condition and determine if it imposes some constraint + // to the file header (ex: `uint16(0) == 0x5a4d`). + let header_constraints = self.ir.header_constraints(|pat_idx| { + let pat = &rule_patterns[pat_idx.as_usize()]; + match pat.pattern() { + Pattern::Text(lit) => Some(lit.text.as_bytes().to_vec()), + Pattern::Regexp(re) | Pattern::Hex(re) => { + re.hir.as_literal_bytes().map(|bytes| bytes.to_vec()) + } + } + }); + // Set the bounds to all patterns in the rule. This must be done // before assigning the PatternId to each pattern, as the filesize // bounds are taken into account when determining if the pattern @@ -1678,6 +1705,15 @@ impl Compiler<'_> { } } + // Set header constraints to all patterns in the rule. + if !header_constraints.unconstrained() { + for pattern in &mut rule_patterns { + pattern + .pattern_mut() + .set_header_constraints(&header_constraints); + } + } + if let Some(w) = &mut self.ir_writer { writeln!(w, "RULE {}", rule.identifier.name).unwrap(); writeln!(w, "{:?}", self.ir).unwrap(); @@ -1810,6 +1846,17 @@ impl Compiler<'_> { "modifying the file size bounds of an existing pattern" ) } + if !header_constraints.unconstrained() + && self + .header_constraints + .insert(*pattern_id, header_constraints.clone()) + .is_some() + { + // This should not happen. + panic!( + "modifying the header constraints of an existing pattern" + ) + } pending_patterns.remove(pattern_id); } } diff --git a/lib/src/compiler/rules.rs b/lib/src/compiler/rules.rs index 97232d6c0..16a6fafe1 100644 --- a/lib/src/compiler/rules.rs +++ b/lib/src/compiler/rules.rs @@ -119,6 +119,9 @@ pub struct Rules { pub(in crate::compiler) filesize_bounds: FxHashMap, + pub(in crate::compiler) header_constraints: + FxHashMap, + /// Vector that contains the [`SubPatternId`] for sub-patterns that can /// match only at a fixed offset within the scanned data. These sub-patterns /// are not added to the Aho-Corasick automaton. @@ -577,6 +580,14 @@ impl Rules { self.filesize_bounds.get(&pattern_id) } + #[inline] + pub(crate) fn header_constraints( + &self, + pattern_id: PatternId, + ) -> Option<&HeaderConstraint> { + self.header_constraints.get(&pattern_id) + } + #[inline] pub(crate) fn is_fast_scan(&self, pattern_id: PatternId) -> bool { *self.fast_scan_patterns.get(usize::from(pattern_id)).unwrap() @@ -851,6 +862,34 @@ impl FilesizeBounds { } } +/// Describes the requirements on the file header imposed by a rule condition. +/// +/// For example, the condition `uint32(0) == 0x464c457f` requires that the first +/// 4 bytes of the file are `0x7f, 0x45, 0x4c, 0x46`. +#[derive( + Debug, PartialEq, Serialize, Deserialize, Clone, Hash, Eq, Default, +)] +pub(crate) enum HeaderConstraint { + #[default] + Unconstrained, + Unsatisfiable, + Constrained(Vec), +} + +impl HeaderConstraint { + pub fn unconstrained(&self) -> bool { + matches!(self, Self::Unconstrained) + } + + pub fn is_satisfied(&self, data: &[u8]) -> bool { + match self { + Self::Unconstrained => true, + Self::Unsatisfiable => false, + Self::Constrained(bytes) => data.starts_with(bytes), + } + } +} + /// Represents an atom extracted from a pattern and added to the Aho-Corasick /// automata. /// diff --git a/lib/src/scanner/context.rs b/lib/src/scanner/context.rs index 7e3786502..f769cbbcb 100644 --- a/lib/src/scanner/context.rs +++ b/lib/src/scanner/context.rs @@ -171,6 +171,9 @@ pub struct ScanContext<'r, 'd> { pub(crate) console_log: Option>, /// Virtual Machines used for executing regexps. pub(crate) vm: VM<'r>, + /// Patterns that are disabled for the current scan (e.g. because they don't + /// comply with filesize bounds or header constraints). + pub(crate) disabled_patterns: FxHashSet, /// Hash map that tracks the time spend on each pattern. Keys are pattern /// PatternIds and values are the cumulative time spent on verifying each /// pattern. @@ -541,6 +544,9 @@ impl ScanContext<'_, '_> { // Free all runtime objects left around by previous scans. self.runtime_objects.clear(); + // Clear the set that tracks the disabled patterns. + self.disabled_patterns.clear(); + // Clear the array that tracks the patterns that reached the maximum // number of patterns. self.tracker.limit_reached.clear(); @@ -764,10 +770,8 @@ impl ScanContext<'_, '_> { &mut self, base: usize, data: &[u8], - block_scanning_mode: bool, ) -> Result<(), ScanError> { let ac = self.compiled_rules.ac_automaton(); - let filesize = self.get_filesize(); #[cfg(feature = "logging")] let mut atom_matches = 0_usize; @@ -792,8 +796,6 @@ impl ScanContext<'_, '_> { match_offset, base, data, - filesize, - block_scanning_mode, ); }); } @@ -814,8 +816,6 @@ impl ScanContext<'_, '_> { ac_match.start(), base, data, - filesize, - block_scanning_mode, ); } } @@ -834,8 +834,6 @@ impl ScanContext<'_, '_> { match_start: usize, base: usize, data: &[u8], - filesize: i64, - block_scanning_mode: bool, ) { let atoms = self.compiled_rules.atoms(); let atom = unsafe { atoms.get_unchecked(atom_idx) }; @@ -855,15 +853,11 @@ impl ScanContext<'_, '_> { let (pattern_id, sub_pattern) = &self.compiled_rules.get_sub_pattern(sub_pattern_id); - if self.tracker.limit_reached.contains(pattern_id) { + if self.disabled_patterns.contains(pattern_id) { return; } - if !block_scanning_mode - && let Some(bounds) = - self.compiled_rules.filesize_bounds(*pattern_id) - && !bounds.contains(filesize) - { + if self.tracker.limit_reached.contains(pattern_id) { return; } @@ -1073,6 +1067,31 @@ impl ScanContext<'_, '_> { _ => panic!(), }; + if !block_scanning_mode { + let filesize = self.get_filesize(); + for pattern_id in 0..self.compiled_rules.num_patterns() { + let pattern_id = PatternId::from(pattern_id); + if let Some(bounds) = + self.compiled_rules.filesize_bounds(pattern_id) + && !bounds.contains(filesize) + { + self.disabled_patterns.insert(pattern_id); + } + } + } + + if base == 0 { + for pattern_id in 0..self.compiled_rules.num_patterns() { + let pattern_id = PatternId::from(pattern_id); + if let Some(constraints) = + self.compiled_rules.header_constraints(pattern_id) + && !constraints.is_satisfied(data) + { + self.disabled_patterns.insert(pattern_id); + } + } + } + #[cfg(any(feature = "rules-profiling", feature = "logging"))] let scan_start = self.clock.raw(); @@ -1080,8 +1099,7 @@ impl ScanContext<'_, '_> { // match at a single known offset within the data. self.verify_anchored_patterns(base, data); - let result = match self.ac_search_loop(base, data, block_scanning_mode) - { + let result = match self.ac_search_loop(base, data) { Ok(_) => { self.scan_state = state; Ok(()) @@ -1128,6 +1146,10 @@ impl ScanContext<'_, '_> { .anchored_sub_patterns() .iter() .map(|id| (id, self.compiled_rules.get_sub_pattern(*id))) + // Disabled patterns are ignored. + .filter(|(_, (pattern_id, _))| { + !self.disabled_patterns.contains(pattern_id) + }) { match sub_pattern { SubPattern::Literal { @@ -1924,6 +1946,7 @@ pub fn create_wasm_store_and_ctx<'r>( pike_vm: PikeVM::new(rules.re_code()), fast_vm: FastVM::new(rules.re_code()), }, + disabled_patterns: FxHashSet::default(), custom_base64_engine_cache: Vec::new(), #[cfg(feature = "rules-profiling")] time_spent_in_pattern: FxHashMap::default(), diff --git a/lib/src/tests/mod.rs b/lib/src/tests/mod.rs index bc5aa2206..b5c75e1ec 100644 --- a/lib/src/tests/mod.rs +++ b/lib/src/tests/mod.rs @@ -3948,3 +3948,159 @@ fn short_circuit() { b"foobar" ); } + +#[test] +fn header_constraints_optimization() { + // ELF magic check. + rule_true!( + r#" + rule test { + strings: + $a = "ELF" + condition: + uint32(0) == 0x464c457f and $a + } + "#, + b"\x7fELF\0\0\0\0" + ); + + rule_false!( + r#" + rule test { + strings: + $a = "ELF" + condition: + uint32(0) == 0x464c457f and $a + } + "#, + b"\0\0\0\0ELF" + ); + + // PE magic check. + rule_true!( + r#" + rule test { + strings: + $a = "PE" + condition: + uint16(0) == 0x5a4d and $a + } + "#, + b"MZ\0\0PE" + ); + + rule_false!( + r#" + rule test { + strings: + $a = "PE" + condition: + uint16(0) == 0x5a4d and $a + } + "#, + b"\0\0MZPE" + ); + + // Pattern at 0 check. + rule_true!( + r#" + rule test { + strings: + $a = "MZ" + condition: + $a at 0 + } + "#, + b"MZ" + ); + + rule_false!( + r#" + rule test { + strings: + $a = "MZ" + condition: + $a at 0 + } + "#, + b"\0MZ" + ); + + // Multiple constraints combined. + rule_true!( + r#" + rule test { + strings: + $a = "ELF" + condition: + uint32(0) == 0x464c457f and uint16(4) == 0x0102 and $a + } + "#, + b"\x7fELF\x02\x01" + ); + + rule_false!( + r#" + rule test { + strings: + $a = "ELF" + condition: + uint32(0) == 0x464c457f and uint16(4) == 0x0102 and $a + } + "#, + b"\x7fELF\x99\x99" + ); + + // Deduplication test: A pattern used in both a constrained and an unconstrained rule. + // When the header does not match, only the unconstrained rule should match (exactly 1 rule). + rule_true!( + r#" + rule constrained { + strings: + $a = "foo" + condition: + uint32(0) == 0x464c457f and $a + } + rule unconstrained { + strings: + $a = "foo" + condition: + $a + } + "#, + b"\0\0\0\0foo" + ); + + // When the header matches, both rules should match (exactly 2 rules). + test_rule!( + r#" + rule constrained { + strings: + $a = "foo" + condition: + uint32(0) == 0x464c457f and $a + } + rule unconstrained { + strings: + $a = "foo" + condition: + $a + } + "#, + b"\x7fELFfoo", + 2 + ); + + // Non-contiguous offsets for uint8. + rule_true!( + r#" + rule constrained { + strings: + $a = "MZ" + condition: + uint8(0) == 0x00 and uint8(2) == 0x5a and $a + } + "#, + b"\0MZ" + ); +} diff --git a/lib/src/types/func.rs b/lib/src/types/func.rs index 100615790..a62444238 100644 --- a/lib/src/types/func.rs +++ b/lib/src/types/func.rs @@ -86,9 +86,13 @@ impl MangledFnName { pub fn as_str(&self) -> &str { self.0.as_str() } -} -impl MangledFnName { + /// Returns the plain function name, without argument or return type + /// information (i.e: everything before the `@` in the name). + pub fn plain_name(&self) -> &str { + self.0.as_str().split("@").next().unwrap() + } + /// Returns the types of arguments and return value for the function. pub fn unmangle(&self) -> (Vec<(&str, TypeValue)>, TypeValue) { let (_fn_name, arg_names_and_types, ret_type) = From ceb06cd88633129d7db99e0ae1b2ea4daffaec3c Mon Sep 17 00:00:00 2001 From: zdiff Date: Mon, 8 Jun 2026 06:00:45 -0400 Subject: [PATCH 5/5] chore: go testifylint (#674) * Addressed testifylint errors * Replaced deprecated ioutil.TempFile with os.CreateTemp * Formatted with gofumpt --------- Co-authored-by: zdiff --- go/compiler.go | 9 +++++---- go/compiler_test.go | 48 ++++++++++++++++++++++----------------------- go/example_test.go | 1 - go/main.go | 6 +++--- go/scanner.go | 19 +++++++++--------- go/scanner_test.go | 24 ++++++++++++----------- 6 files changed, 54 insertions(+), 53 deletions(-) diff --git a/go/compiler.go b/go/compiler.go index 3467f24d3..9b61b7db3 100644 --- a/go/compiler.go +++ b/go/compiler.go @@ -2,6 +2,7 @@ package yara_x // #include import "C" + import ( "encoding/json" "errors" @@ -28,7 +29,7 @@ type CompileOption func(c *Compiler) error // // Valid value types include: int, int32, int64, bool, string, float32 and // float64. -func Globals(vars map[string]interface{}) CompileOption { +func Globals(vars map[string]any) CompileOption { return func(c *Compiler) error { for ident, value := range vars { c.vars[ident] = value @@ -289,7 +290,7 @@ type Compiler struct { includesEnabled bool ignoredModules map[string]bool bannedModules map[string]bannedModule - vars map[string]interface{} + vars map[string]any features []string includeDirs []string maxWarnings *int @@ -301,7 +302,7 @@ func NewCompiler(opts ...CompileOption) (*Compiler, error) { includesEnabled: true, ignoredModules: make(map[string]bool), bannedModules: make(map[string]bannedModule), - vars: make(map[string]interface{}), + vars: make(map[string]any), features: make([]string, 0), includeDirs: make([]string, 0), } @@ -558,7 +559,7 @@ func (c *Compiler) DefineGlobal(ident string, value interface{}) error { ret = C.int(C.yrx_compiler_define_global_float(c.cCompiler, cIdent, C.double(v))) case float64: ret = C.int(C.yrx_compiler_define_global_float(c.cCompiler, cIdent, C.double(v))) - case map[string]interface{}, []interface{}: + case map[string]any, []any: jsonStr, err := json.Marshal(v) if err != nil { return fmt.Errorf("failed to marshal '%s' to json: '%v'", ident, err) diff --git a/go/compiler_test.go b/go/compiler_test.go index 3c85322b3..9f5e44de0 100644 --- a/go/compiler_test.go +++ b/go/compiler_test.go @@ -3,16 +3,16 @@ package yara_x import ( "bytes" "fmt" - "io/ioutil" "os" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNamespaces(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.NewNamespace("foo") c.AddSource("rule test { condition: true }") @@ -26,7 +26,7 @@ func TestNamespaces(t *testing.T) { func TestGlobals(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) x := map[string]any{"a": map[string]any{"a": "b"}, "b": "d"} y := []any{"z"} @@ -54,7 +54,7 @@ func TestUnsupportedModules(t *testing.T) { rule test { condition: true }`, IgnoreModule("unsupported_module")) - assert.NoError(t, err) + require.NoError(t, err) scanResults, _ := r.Scan([]byte{}) assert.Len(t, scanResults.MatchingRules(), 1) } @@ -85,8 +85,8 @@ func TestDisabledIncludes(t *testing.T) { } func TestIncludes(t *testing.T) { - file, err := ioutil.TempFile("", "prefix") - assert.NoError(t, err) + file, err := os.CreateTemp(t.TempDir(), "prefix") + require.NoError(t, err) defer os.Remove(file.Name()) @@ -94,14 +94,14 @@ func TestIncludes(t *testing.T) { fmt.Sprintf(`include "%s"`, file.Name()), IncludeDir(os.TempDir())) - assert.NoError(t, err) + require.NoError(t, err) } func TestRelaxedReSyntax(t *testing.T) { r, err := Compile(` rule test { strings: $a = /\Release/ condition: $a }`, RelaxedReSyntax(true)) - assert.NoError(t, err) + require.NoError(t, err) scanResults, _ := r.Scan([]byte("Release")) assert.Len(t, scanResults.MatchingRules(), 1) } @@ -110,7 +110,7 @@ func TestConditionOptimization(t *testing.T) { _, err := Compile(` rule test { condition: true }`, ConditionOptimization(true)) - assert.NoError(t, err) + require.NoError(t, err) } func TestErrorOnSlowPattern(t *testing.T) { @@ -129,14 +129,14 @@ func TestErrorOnSlowLoop(t *testing.T) { func TestSerialization(t *testing.T) { r, err := Compile("rule test { condition: true }") - assert.NoError(t, err) + require.NoError(t, err) var buf bytes.Buffer // Write rules into buffer n, err := r.WriteTo(&buf) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, buf.Bytes(), int(n)) // Read rules from buffer @@ -157,7 +157,7 @@ func TestVariables(t *testing.T) { assert.Len(t, scanResults.MatchingRules(), 1) c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.DefineGlobal("var", 1234) c.AddSource("rule test { condition: var == 1234 }") @@ -177,7 +177,7 @@ func TestVariables(t *testing.T) { c.DefineGlobal("var", false) c.AddSource("rule test { condition: var }") scanResults, _ = NewScanner(c.Build()).Scan([]byte{}) - assert.Len(t, scanResults.MatchingRules(), 0) + assert.Empty(t, scanResults.MatchingRules()) c.DefineGlobal("var", "foo") c.AddSource("rule test { condition: var == \"foo\" }") @@ -207,26 +207,26 @@ func TestCompilerFeatures(t *testing.T) { rules := `import "test_proto2" rule test { condition: test_proto2.requires_foo_and_bar }` _, err := Compile(rules) - assert.EqualError(t, err, `error[E100]: foo is required + require.EqualError(t, err, `error[E100]: foo is required --> line:1:57 | 1 | import "test_proto2" rule test { condition: test_proto2.requires_foo_and_bar } | ^^^^^^^^^^^^^^^^^^^^ this field was used without foo`) _, err = Compile(rules, WithFeature("foo")) - assert.EqualError(t, err, `error[E100]: bar is required + require.EqualError(t, err, `error[E100]: bar is required --> line:1:57 | 1 | import "test_proto2" rule test { condition: test_proto2.requires_foo_and_bar } | ^^^^^^^^^^^^^^^^^^^^ this field was used without bar`) _, err = Compile(rules, WithFeature("foo"), WithFeature("bar")) - assert.NoError(t, err) + require.NoError(t, err) } func TestErrors(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.AddSource("rule test_1 { condition: true }") assert.Equal(t, []CompileError{}, c.Errors()) @@ -291,13 +291,13 @@ func TestErrors(t *testing.T) { func TestRules(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.AddSource(`rule test_1 : tag1 tag2 { condition: true }`) - assert.NoError(t, err) + require.NoError(t, err) c.AddSource(`rule test_2 { meta: @@ -308,7 +308,7 @@ func TestRules(t *testing.T) { condition: true }`) - assert.NoError(t, err) + require.NoError(t, err) rules := c.Build() assert.Equal(t, 2, rules.Count()) @@ -324,7 +324,7 @@ func TestRules(t *testing.T) { assert.Equal(t, []string{"tag1", "tag2"}, slice[0].Tags()) assert.Equal(t, []string{}, slice[1].Tags()) - assert.Len(t, slice[0].Metadata(), 0) + assert.Empty(t, slice[0].Metadata()) assert.Len(t, slice[1].Metadata(), 4) assert.Equal(t, "foo", slice[1].Metadata()[0].Identifier()) @@ -342,7 +342,7 @@ func TestRules(t *testing.T) { func TestImportsIter(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.AddSource(` import "pe" @@ -351,7 +351,7 @@ func TestImportsIter(t *testing.T) { condition: true }`) - assert.NoError(t, err) + require.NoError(t, err) rules := c.Build() imports := rules.Imports() @@ -363,7 +363,7 @@ func TestImportsIter(t *testing.T) { func TestWarnings(t *testing.T) { c, err := NewCompiler() - assert.NoError(t, err) + require.NoError(t, err) c.AddSource("rule test { strings: $a = {01 [0-1][0-1] 02 } condition: $a }") diff --git a/go/example_test.go b/go/example_test.go index 735d65830..bf58e9960 100644 --- a/go/example_test.go +++ b/go/example_test.go @@ -50,7 +50,6 @@ func Example_compilerAndScanner() { condition: $bar }`) - if err != nil { panic(err) } diff --git a/go/main.go b/go/main.go index 1d77182db..20e18baa7 100644 --- a/go/main.go +++ b/go/main.go @@ -266,7 +266,7 @@ type Pattern struct { // Metadata represents a metadata in a Rule. type Metadata struct { identifier string - value interface{} + value any } // Match contains information about the offset where a match occurred and @@ -358,7 +358,7 @@ func (m *Metadata) Identifier() string { } // Value associated to the metadata. -func (m *Metadata) Value() interface{} { +func (m *Metadata) Value() any { return m.value } @@ -474,7 +474,7 @@ func metadataCallback(metadata *C.YRX_METADATA, handle C.uintptr_t) { panic("matchCallback didn't receive a *[]Metadata") } - var value interface{} + var value any switch metadata.value_type { case C.YRX_I64: diff --git a/go/scanner.go b/go/scanner.go index c1926c56e..67c08ad8f 100644 --- a/go/scanner.go +++ b/go/scanner.go @@ -31,9 +31,7 @@ import ( "runtime/cgo" "time" "unsafe" -) -import ( "google.golang.org/protobuf/proto" ) @@ -163,7 +161,7 @@ func (s *Scanner) SetGlobal(ident string, value interface{}) error { ret = C.int(C.yrx_scanner_set_global_str(s.cScanner, cIdent, cValue)) case float64: ret = C.int(C.yrx_scanner_set_global_float(s.cScanner, cIdent, C.double(v))) - case map[string]interface{}, []interface{}: + case map[string]any, []any: jsonStr, err := json.Marshal(v) if err != nil { return fmt.Errorf("failed to marshal '%s' to json: '%v'", ident, err) @@ -316,7 +314,8 @@ func slowestRulesCallback( rule *C.char, patternMatchingTime C.double, conditionExecTime C.double, - handle C.uintptr_t) { + handle C.uintptr_t, +) { h := cgo.Handle(handle) profilingInfo, ok := h.Value().(*[]ProfilingInfo) if !ok { @@ -358,17 +357,17 @@ func (s *Scanner) SlowestRules(n int) []ProfilingInfo { return profilingInfo } -/// ClearProfilingData resets the profiling data collected during rule execution -/// across scanned files. Use it to start a new profiling session, ensuring the -/// results reflect only the data gathered after this method is called. +// ClearProfilingData resets the profiling data collected during rule execution +// across scanned files. Use it to start a new profiling session, ensuring the +// results reflect only the data gathered after this method is called. // // In order to use this function, the YARA-X C library must be built with // support for rules profiling by enabling the `rules-profiling` feature. // Otherwise, calling this function will cause a panic. func (s *Scanner) ClearProfilingData() { - if C.yrx_scanner_clear_profiling_data(s.cScanner) == C.YRX_NOT_SUPPORTED { - panic("ClearProfilingData requires that the YARA-X C library is built with the `rules-profiling` feature") - } + if C.yrx_scanner_clear_profiling_data(s.cScanner) == C.YRX_NOT_SUPPORTED { + panic("ClearProfilingData requires that the YARA-X C library is built with the `rules-profiling` feature") + } } // Destroy destroys the scanner. diff --git a/go/scanner_test.go b/go/scanner_test.go index ba8aa3286..49dd363cc 100644 --- a/go/scanner_test.go +++ b/go/scanner_test.go @@ -2,11 +2,13 @@ package yara_x import ( "bytes" - "github.com/stretchr/testify/assert" "os" "runtime" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestScanner1(t *testing.T) { @@ -18,7 +20,7 @@ func TestScanner1(t *testing.T) { assert.Len(t, matchingRules, 1) assert.Equal(t, "t", matchingRules[0].Identifier()) assert.Equal(t, "default", matchingRules[0].Namespace()) - assert.Len(t, matchingRules[0].Patterns(), 0) + assert.Empty(t, matchingRules[0].Patterns()) scanResults, _ = s.Scan(nil) matchingRules = scanResults.MatchingRules() @@ -26,7 +28,7 @@ func TestScanner1(t *testing.T) { assert.Len(t, matchingRules, 1) assert.Equal(t, "t", matchingRules[0].Identifier()) assert.Equal(t, "default", matchingRules[0].Namespace()) - assert.Len(t, matchingRules[0].Patterns(), 0) + assert.Empty(t, matchingRules[0].Patterns()) } func TestScanner2(t *testing.T) { @@ -59,7 +61,7 @@ func TestScanner3(t *testing.T) { s.SetGlobal("var_bool", false) scanResults, _ = s.Scan([]byte{}) - assert.Len(t, scanResults.MatchingRules(), 0) + assert.Empty(t, scanResults.MatchingRules()) } func TestScanner4(t *testing.T) { @@ -69,17 +71,17 @@ func TestScanner4(t *testing.T) { s := NewScanner(r) scanResults, _ := s.Scan([]byte{}) - assert.Len(t, scanResults.MatchingRules(), 0) + assert.Empty(t, scanResults.MatchingRules()) - assert.NoError(t, s.SetGlobal("var_int", 1)) + require.NoError(t, s.SetGlobal("var_int", 1)) scanResults, _ = s.Scan([]byte{}) assert.Len(t, scanResults.MatchingRules(), 1) - assert.NoError(t, s.SetGlobal("var_int", int32(1))) + require.NoError(t, s.SetGlobal("var_int", int32(1))) scanResults, _ = s.Scan([]byte{}) assert.Len(t, scanResults.MatchingRules(), 1) - assert.NoError(t, s.SetGlobal("var_int", int64(1))) + require.NoError(t, s.SetGlobal("var_int", int64(1))) scanResults, _ = s.Scan([]byte{}) assert.Len(t, scanResults.MatchingRules(), 1) } @@ -89,12 +91,12 @@ func TestScanFile(t *testing.T) { s := NewScanner(r) // Create a temporary file with some content - f, err := os.CreateTemp("", "example") - assert.NoError(t, err) + f, err := os.CreateTemp(t.TempDir(), "example") + require.NoError(t, err) defer os.Remove(f.Name()) _, err = f.Write([]byte("foobar")) - assert.NoError(t, err) + require.NoError(t, err) f.Close() scanResults, _ := s.ScanFile(f.Name())