Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 213 additions & 66 deletions kai-core/detect/detect.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,37 @@ import (
"kai-core/parse"
)

// functionNodeTypes maps each language to the AST node types that represent
// function-like declarations (functions, methods, etc.)
var functionNodeTypes = map[string][]string{
"js": {
"function_declaration", // function foo() {}
"method_definition", // class methods
"lexical_declaration", // const foo = () => {}
"variable_declaration", // var foo = function() {}
},
"ts": {
"function_declaration",
"method_definition",
"lexical_declaration",
"variable_declaration",
},
"py": {
"function_definition", // Both standalone functions and methods
},
"go": {
"function_declaration", // func Foo() {}
"method_declaration", // func (T) Method() {}
},
"rb": {
"method", // def foo
"singleton_method", // def self.foo
},
"rs": {
"function_item",
},
}

// ChangeCategory represents a type of change.
type ChangeCategory string

Expand Down Expand Up @@ -56,12 +87,12 @@ const (
DependencyUpdated ChangeCategory = "DEPENDENCY_UPDATED"

// Semantic config changes
FeatureFlagChanged ChangeCategory = "FEATURE_FLAG_CHANGED"
TimeoutChanged ChangeCategory = "TIMEOUT_CHANGED"
LimitChanged ChangeCategory = "LIMIT_CHANGED"
RetryConfigChanged ChangeCategory = "RETRY_CONFIG_CHANGED"
EndpointChanged ChangeCategory = "ENDPOINT_CHANGED"
CredentialChanged ChangeCategory = "CREDENTIAL_CHANGED"
FeatureFlagChanged ChangeCategory = "FEATURE_FLAG_CHANGED"
TimeoutChanged ChangeCategory = "TIMEOUT_CHANGED"
LimitChanged ChangeCategory = "LIMIT_CHANGED"
RetryConfigChanged ChangeCategory = "RETRY_CONFIG_CHANGED"
EndpointChanged ChangeCategory = "ENDPOINT_CHANGED"
CredentialChanged ChangeCategory = "CREDENTIAL_CHANGED"

// Schema/migration changes
SchemaFieldAdded ChangeCategory = "SCHEMA_FIELD_ADDED"
Expand Down Expand Up @@ -243,69 +274,57 @@ func GetAllFunctions(parsed *parse.ParsedFile, content []byte, lang ...string) m
l = lang[0]
}

switch l {
case "rb":
// Ruby: method and singleton_method nodes
for _, node := range parsed.FindNodesOfType("method") {
name := getFunctionName(node, content)
if name != "" {
body := getFunctionBody(node, content)
funcs[name] = &FuncInfo{Name: name, Node: node, Body: body}
}
}
for _, node := range parsed.FindNodesOfType("singleton_method") {
name := getFunctionName(node, content)
if name != "" {
body := getFunctionBody(node, content)
funcs["self."+name] = &FuncInfo{Name: "self." + name, Node: node, Body: body}
}
}

case "py":
// Python: function_definition nodes
for _, node := range parsed.FindNodesOfType("function_definition") {
name := getFunctionName(node, content)
if name != "" {
body := getFunctionBody(node, content)
funcs[name] = &FuncInfo{Name: name, Node: node, Body: body}
}
}

default:
// JS/TS/Go and others

// Function declarations: function foo() {}
for _, node := range parsed.FindNodesOfType("function_declaration") {
name := getFunctionName(node, content)
if name != "" {
body := getFunctionBody(node, content)
funcs[name] = &FuncInfo{Name: name, Node: node, Body: body}
}
}

// Arrow functions assigned to variables: const foo = () => {}
for _, node := range parsed.FindNodesOfType("lexical_declaration") {
name, arrowNode := getArrowFunctionName(node, content)
if name != "" && arrowNode != nil {
body := getFunctionBody(arrowNode, content)
funcs[name] = &FuncInfo{Name: name, Node: node, Body: body}
}
}
nodeTypes, ok := functionNodeTypes[l]
if !ok {
// Fallback to JS if language not in map
nodeTypes = functionNodeTypes["js"]
}

// Variable declarations: var foo = function() {}
for _, node := range parsed.FindNodesOfType("variable_declaration") {
name, funcNode := getVariableFunctionName(node, content)
if name != "" && funcNode != nil {
body := getFunctionBody(funcNode, content)
funcs[name] = &FuncInfo{Name: name, Node: node, Body: body}
// Search for all node types for this language
for _, nodeType := range nodeTypes {
for _, node := range parsed.FindNodesOfType(nodeType) {
var name string
var bodyNode *sitter.Node

// Handle special cases per node type
switch nodeType {
case "lexical_declaration":
// JS/TS: const foo = () => {}
name, bodyNode = getArrowFunctionName(node, content)
case "variable_declaration":
// JS/TS: var foo = function() {}
name, bodyNode = getVariableFunctionName(node, content)
case "singleton_method":
// Ruby: def self.foo
name = getFunctionName(node, content)
if name != "" {
name = "self." + name
}
bodyNode = node
case "method_declaration":
// Go: func (T) Method() {}
name = getGoMethodName(node, content)
bodyNode = node
case "function_definition":
// Python: check if inside a class
name = getPythonFunctionName(node, content)
bodyNode = node
case "method":
// Ruby: check if inside a class
name = getRubyMethodName(node, content)
bodyNode = node
case "function_item":
// Rust: check if inside impl block
name = getRustFunctionName(node, content)
bodyNode = node
default:
// Standard function extraction
name = getFunctionName(node, content)
bodyNode = node
}
}

// Method definitions in classes/objects
for _, node := range parsed.FindNodesOfType("method_definition") {
name := getFunctionName(node, content)
if name != "" {
body := getFunctionBody(node, content)
if name != "" && bodyNode != nil {
body := getFunctionBody(bodyNode, content)
funcs[name] = &FuncInfo{Name: name, Node: node, Body: body}
}
}
Expand Down Expand Up @@ -722,6 +741,134 @@ func getFunctionName(node *sitter.Node, content []byte) string {
return ""
}

// getGoMethodName extracts the qualified name for a Go method (e.g., "User.login")
func getGoMethodName(node *sitter.Node, content []byte) string {
var receiverType string
var methodName string

for i := 0; i < int(node.ChildCount()); i++ {
child := node.Child(i)
switch child.Type() {
case "parameter_list":
// First parameter_list is the receiver
if receiverType == "" {
receiverType = getGoReceiverType(child, content)
}
case "field_identifier":
methodName = parse.GetNodeContent(child, content)
}
}

if methodName == "" {
return ""
}

if receiverType != "" {
return receiverType + "." + methodName
}
return methodName
}

// getGoReceiverType extracts the type from a Go method receiver parameter list
func getGoReceiverType(paramList *sitter.Node, content []byte) string {
for i := 0; i < int(paramList.ChildCount()); i++ {
child := paramList.Child(i)
if child.Type() == "parameter_declaration" {
for j := 0; j < int(child.ChildCount()); j++ {
typeChild := child.Child(j)
switch typeChild.Type() {
case "type_identifier":
return parse.GetNodeContent(typeChild, content)
case "pointer_type":
// Extract base type from pointer (e.g., "*User" -> "User")
for k := 0; k < int(typeChild.ChildCount()); k++ {
ptrChild := typeChild.Child(k)
if ptrChild.Type() == "type_identifier" {
return parse.GetNodeContent(ptrChild, content)
}
}
}
}
}
}
return ""
}

// getPythonFunctionName extracts qualified name for Python functions (e.g., "User.login")
func getPythonFunctionName(node *sitter.Node, content []byte) string {
funcName := getFunctionName(node, content)
if funcName == "" {
return ""
}

// Check if this function is inside a class
parent := node.Parent()
for parent != nil {
if parent.Type() == "class_definition" {
// Found parent class, get its name
for i := 0; i < int(parent.ChildCount()); i++ {
child := parent.Child(i)
if child.Type() == "identifier" {
className := parse.GetNodeContent(child, content)
return className + "." + funcName
}
}
}
parent = parent.Parent()
}
return funcName
}

// getRubyMethodName extracts qualified name for Ruby methods (e.g., "User#login")
func getRubyMethodName(node *sitter.Node, content []byte) string {
methodName := getFunctionName(node, content)
if methodName == "" {
return ""
}

// Check if this method is inside a class
parent := node.Parent()
for parent != nil {
if parent.Type() == "class" {
// Found parent class, get its name
for i := 0; i < int(parent.ChildCount()); i++ {
child := parent.Child(i)
if child.Type() == "constant" {
className := parse.GetNodeContent(child, content)
return className + "#" + methodName
}
}
}
parent = parent.Parent()
}
return methodName
}

// getRustFunctionName extracts qualified name for Rust functions (e.g., "User::login")
func getRustFunctionName(node *sitter.Node, content []byte) string {
funcName := getFunctionName(node, content)
if funcName == "" {
return ""
}

// Check if this function is inside an impl block
parent := node.Parent()
for parent != nil {
if parent.Type() == "impl_item" {
// Found impl block, get the type name
for i := 0; i < int(parent.ChildCount()); i++ {
child := parent.Child(i)
if child.Type() == "type_identifier" {
typeName := parse.GetNodeContent(child, content)
return typeName + "::" + funcName
}
}
}
parent = parent.Parent()
}
return funcName
}

func getFunctionParams(node *sitter.Node, content []byte) string {
for i := 0; i < int(node.ChildCount()); i++ {
child := node.Child(i)
Expand Down
Loading