Skip to content

Commit a160ff5

Browse files
committed
feat(functions): Implement COUNT function
The count function counts the number of items in the slice or substrings in the string that is matching a wildcard pattern.
1 parent e8570c9 commit a160ff5

5 files changed

Lines changed: 153 additions & 0 deletions

File tree

pkg/filter/filter_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,13 +304,15 @@ func TestProcFilter(t *testing.T) {
304304
{`ps.modules IN ('kernel32.dll')`, true},
305305
{`evt.name = 'CreateProcess' and evt.pid != ps.ppid`, true},
306306
{`ps.parent.name = 'svchost.exe'`, true},
307+
{`count(ps.modules, '*.dll') >= 2`, true},
307308

308309
{`ps.ancestor[0] = 'svchost.exe'`, true},
309310
{`ps.ancestor[0] = 'csrss.exe'`, false},
310311
{`ps.ancestor[1] = 'services.exe'`, true},
311312
{`ps.ancestor[2] = 'csrss.exe'`, true},
312313
{`ps.ancestor[3] = ''`, true},
313314
{`ps.ancestor intersects ('csrss.exe', 'services.exe', 'svchost.exe')`, true},
315+
{`count(ps.ancestor, '*.exe') = 3`, true},
314316

315317
{`foreach(ps._ancestors, $proc, $proc.name in ('csrss.exe', 'services.exe', 'System'))`, true},
316318
{`foreach(ps._ancestors, $proc, $proc.name in ('csrss.exe', 'services.exe', 'System') and ps.is_packaged, ps.is_packaged)`, true},

pkg/filter/ql/function.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ var funcs = map[string]FunctionDef{
8181
functions.GetRegValueFn.String(): &functions.GetRegValue{},
8282
functions.YaraFn.String(): &functions.Yara{},
8383
functions.ForeachFn.String(): &Foreach{},
84+
functions.CountFn.String(): &functions.Count{},
8485
}
8586

8687
// FunctionDef is the interface that all function definitions have to satisfy.

pkg/filter/ql/functions/count.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright 2021-present by Nedim Sabic Sabic
3+
* https://www.fibratus.io
4+
* All Rights Reserved.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package functions
20+
21+
import (
22+
"strings"
23+
24+
"github.com/rabbitstack/fibratus/pkg/util/wildcard"
25+
)
26+
27+
// Count counts the number of items in the slice or substrings
28+
// in the string that is matching a wildcard pattern.
29+
type Count struct{}
30+
31+
func (f Count) Call(args []interface{}) (any, bool) {
32+
if len(args) < 2 {
33+
return false, false
34+
}
35+
36+
var count int
37+
var caseInsensitive bool = true
38+
39+
pattern := parseString(1, args)
40+
41+
if len(args) > 2 {
42+
caseInsensitive, _ = args[2].(bool)
43+
}
44+
45+
switch s := args[0].(type) {
46+
case string:
47+
substrings := strings.Fields(s)
48+
for _, ss := range substrings {
49+
switch caseInsensitive {
50+
case true:
51+
if wildcard.Match(strings.ToLower(pattern), strings.ToLower(ss)) {
52+
count++
53+
}
54+
case false:
55+
if wildcard.Match(pattern, ss) {
56+
count++
57+
}
58+
}
59+
}
60+
case []string:
61+
for _, i := range s {
62+
switch caseInsensitive {
63+
case true:
64+
if wildcard.Match(strings.ToLower(pattern), strings.ToLower(i)) {
65+
count++
66+
}
67+
case false:
68+
if wildcard.Match(pattern, i) {
69+
count++
70+
}
71+
}
72+
}
73+
}
74+
75+
return count, true
76+
}
77+
78+
func (f Count) Desc() FunctionDesc {
79+
desc := FunctionDesc{
80+
Name: CountFn,
81+
Args: []FunctionArgDesc{
82+
{Keyword: "string|slice", Types: []ArgType{Field, BoundField, BoundSegment, BareBoundVariable, Func, String, Slice}, Required: true},
83+
{Keyword: "pattern", Types: []ArgType{String}, Required: true},
84+
{Keyword: "case_insensitive", Types: []ArgType{Bool}, Required: false},
85+
},
86+
}
87+
return desc
88+
}
89+
90+
func (f Count) Name() Fn { return CountFn }
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright 2021-present by Nedim Sabic Sabic
3+
* https://www.fibratus.io
4+
* All Rights Reserved.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package functions
20+
21+
import (
22+
"fmt"
23+
"testing"
24+
25+
"github.com/stretchr/testify/assert"
26+
)
27+
28+
func TestCount(t *testing.T) {
29+
var tests = []struct {
30+
args []any
31+
expected int
32+
}{
33+
{
34+
[]any{"hello world", "?orld"},
35+
1,
36+
},
37+
{
38+
[]any{"hello world", "saturn"},
39+
0,
40+
},
41+
{
42+
[]any{[]string{"C:\\Windows\\System32\\ntdll.dll", "C:\\Windows\\System32\\NTDLL.dll"}, "*ntdll.dll"},
43+
2,
44+
},
45+
{
46+
[]any{[]string{"C:\\Windows\\System32\\ntdll.dll", "C:\\Windows\\System32\\NTDLL.dll"}, "*ntdll.dll", false},
47+
1,
48+
},
49+
}
50+
51+
for i, tt := range tests {
52+
f := Count{}
53+
res, _ := f.Call(tt.args)
54+
assert.Equal(t, tt.expected, res, fmt.Sprintf("%d. result mismatch: exp=%v got=%v", i, tt.expected, res))
55+
}
56+
}

pkg/filter/ql/functions/types.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ const (
7474
YaraFn
7575
// ForeachFn represents the FOREACH function
7676
ForeachFn
77+
// CountFn reprsents the COUNT function
78+
CountFn
7779
)
7880

7981
// ArgType is the type alias for the argument value type.
@@ -228,6 +230,8 @@ func (f Fn) String() string {
228230
return "YARA"
229231
case ForeachFn:
230232
return "FOREACH"
233+
case CountFn:
234+
return "COUNT"
231235
default:
232236
return "UNDEFINED"
233237
}

0 commit comments

Comments
 (0)