forked from hackerlibs/rag-code-sorting-search
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrag_search_code.py
More file actions
95 lines (75 loc) · 3.05 KB
/
rag_search_code.py
File metadata and controls
95 lines (75 loc) · 3.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
## Python RAG: parse Python source with tree-sitter to find matching functions,
## embed the snippets with Ollama's nomic-embed-text model, then rank results
## by cosine similarity.
import os
import tree_sitter
import tree_sitter_python as tspython
from tree_sitter import Language, Parser
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import requests
import json
# Step 1: Set up Tree-sitter for Python.
# Build the compiled Python grammar once at import time; `parser` is a shared
# module-level instance reused by search_code() below.
PY_LANGUAGE = Language(tspython.language()) #Language(os.path.expanduser('~/.tree-sitter/python.so'), 'python')
parser = Parser(PY_LANGUAGE)
def search_code(code, query):
    """Find top-level function definitions whose name contains *query*.

    Args:
        code: Python source text to scan.
        query: case-insensitive substring to match against function names.

    Returns:
        List of (function_name, start_point, end_point) tuples; the points
        are tree-sitter (row, column) pairs covering the whole definition
        (including decorators, when present).
    """
    tree = parser.parse(bytes(code, "utf8"))
    results = []
    query_lower = query.lower()
    for node in tree.root_node.children:
        # Decorated functions appear as a `decorated_definition` wrapping the
        # `function_definition`; unwrap so they are searchable too.
        target = node
        if node.type == "decorated_definition":
            inner = node.child_by_field_name("definition")
            if inner is not None:
                target = inner
        if target.type != "function_definition":
            continue
        name_node = target.child_by_field_name("name")
        if name_node is None:
            # Defensive: a partial/erroneous parse can lack the name field.
            continue
        function_name = name_node.text.decode("utf8")
        if query_lower in function_name.lower():
            # Report the outer node's span so decorators are included.
            results.append((function_name, node.start_point, node.end_point))
    return results
# Step 2: Embed code snippets using Ollama Nomic Embed Text
def embed_code(code_snippets):
    """Embed text snippets via a local Ollama nomic-embed-text endpoint.

    Args:
        code_snippets: iterable of strings to embed.

    Returns:
        np.ndarray of shape (len(code_snippets), 768); snippets that fail to
        embed contribute a zero vector so the row count always matches.
    """
    embeddings = []
    for snippet in code_snippets:
        try:
            response = requests.post(
                'http://localhost:11434/api/embeddings',
                json={
                    "model": "nomic-embed-text",
                    "prompt": snippet,
                },
                # Without a timeout a down/hung Ollama server blocks forever.
                timeout=30,
            )
            response.raise_for_status()
            embeddings.append(response.json()['embedding'])
        except (requests.RequestException, KeyError, ValueError) as exc:
            # Connection errors previously crashed the whole run while HTTP
            # errors fell back to zeros; treat both the same, keeping the
            # best-effort zero-vector fallback (nomic-embed-text is 768-dim).
            print(f"Error embedding snippet: {exc}")
            embeddings.append([0.0] * 768)
    return np.array(embeddings)
# Step 3: Cosine similarity search
def cosine_search(query_embedding, code_embeddings):
    """Cosine similarity between one query vector and each embedding row.

    Implemented with plain numpy (already imported) instead of pulling in
    sklearn for a one-liner, and guarded against zero-norm vectors — which
    embed_code() produces as its error fallback — returning 0.0 for those
    rows instead of NaN.

    Args:
        query_embedding: 1-D array-like of floats.
        code_embeddings: 2-D array-like, one embedding per row.

    Returns:
        1-D np.ndarray of similarities, one per row of code_embeddings.
    """
    q = np.asarray(query_embedding, dtype=float)
    m = np.asarray(code_embeddings, dtype=float)
    denom = np.linalg.norm(q) * np.linalg.norm(m, axis=1)
    dots = m @ q
    # Avoid division by zero; zero-norm rows get similarity 0.0.
    safe = np.where(denom > 0.0, denom, 1.0)
    return np.where(denom > 0.0, dots / safe, 0.0)
# Step 4: Main RAG function
def rag_search(code, query):
    """Search *code* for functions matching *query* and rank by embedding similarity.

    Args:
        code: Python source text.
        query: search term (substring of a function name).

    Returns:
        A list of (function_name, similarity) tuples sorted best-first, or
        the string "No matching functions found." when nothing matches.
    """
    # Search for relevant functions.
    search_results = search_code(code, query)
    if not search_results:
        return "No matching functions found."
    function_names = [name for name, _, _ in search_results]
    # BUG FIX: start_point/end_point are tree-sitter (row, col) pairs. The
    # original sliced the source *string* by row number, producing garbage
    # few-character snippets. Slice by source line instead; the end row is
    # inclusive.
    lines = code.splitlines()
    code_snippets = [
        "\n".join(lines[start[0]:end[0] + 1])
        for _, start, end in search_results
    ]
    # Embed the snippets and the query, then rank by cosine similarity.
    code_embeddings = embed_code(code_snippets)
    query_embedding = embed_code([query])[0]  # single row -> single vector
    similarities = cosine_search(query_embedding, code_embeddings)
    return sorted(
        zip(function_names, similarities), key=lambda x: x[1], reverse=True
    )
# Example usage. Guarded so importing this module does not fire network calls
# to the Ollama server; the demo only runs when executed as a script.
if __name__ == "__main__":
    # NOTE(review): the scraped sample had its body indentation stripped,
    # which would make the functions unparseable; restored here.
    sample_code = """
def hello_world():
    print("Hello, World!")
def greet_user(name):
    print(f"Hello, {name}!")
def calculate_sum(a, b):
    return a + b
"""
    query = "greet"
    results = rag_search(sample_code, query)
    print(results)
# Mac run:
# @ prunp rag_search_code.py
# [('greet_user', np.float64(0.5930675716272089))]