Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions benchmarks/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
def test_rag_pedantic(benchmark):
"""
Benchmark test for RAG performance using pytest-benchmark's pedantic mode.

This test:
1. Initializes the RAG system with nilDB configuration
2. Sets up test parameters (prompt, number of chunks, clusters)
3. Performs a warm-up phase to stabilize measurements
4. Runs the benchmark with multiple iterations and rounds
5. Verifies the result is a list

Args:
benchmark: pytest-benchmark fixture for performance testing
"""
Expand All @@ -39,13 +39,15 @@ def test_rag_pedantic(benchmark):
subtract_query_id = os.getenv("QUERY_ID")

# Setup RAG instance
rag = asyncio.run(RAGVault.create(
ORG_CONFIG["nodes"],
ORG_CONFIG["org_credentials"],
schema_id=schema_id,
clusters_schema_id=clusters_schema_id,
subtract_query_id=subtract_query_id,
))
rag = asyncio.run(
RAGVault.create(
ORG_CONFIG["nodes"],
ORG_CONFIG["org_credentials"],
schema_id=schema_id,
clusters_schema_id=clusters_schema_id,
subtract_query_id=subtract_query_id,
)
)

prompt = "Who is Michelle Ross?"
num_chunks = 2
Expand All @@ -54,23 +56,23 @@ def test_rag_pedantic(benchmark):
def sync_runner():
"""
Synchronous wrapper for the async RAG execution.

This function:
1. Wraps the async top_num_chunks_execute in a synchronous context
2. Executes the RAG query with the configured parameters
3. Returns the retrieved chunks

Returns:
list: Retrieved chunks from the RAG system
"""
return asyncio.run(rag.top_num_chunks_execute(
prompt, num_chunks, False, num_clusters
))
return asyncio.run(
rag.top_num_chunks_execute(prompt, num_chunks, False, num_clusters)
)

# Warm up
for _ in range(10):
sync_runner()
#Actual benchmark
# Actual benchmark
result = benchmark.pedantic(sync_runner, iterations=10, rounds=5)

assert isinstance(result, list)
assert isinstance(result, str)
56 changes: 55 additions & 1 deletion src/nilrag/rag_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,51 @@ async def create(
await self.init()
return self

@classmethod
async def create_from_dict(
    cls,
    config: dict,
    *args,
    **kwargs,
) -> "RAGVault":
    """
    Build and asynchronously initialize a RAGVault from a configuration dict.

    The dict must contain 'nodes' (each entry providing 'url' and 'did'),
    'org_secret_key', and 'org_did'. The optional keys 'with_clustering',
    'schema_id', 'clusters_schema_id', and 'subtract_query_id' default to
    None when absent.

    Raises:
        ValueError: if any of the required keys is missing.
    """
    # Validate the mandatory configuration keys up front.
    if "nodes" not in config:
        raise ValueError("Missing required 'nodes' field in configuration")
    if "org_secret_key" not in config or "org_did" not in config:
        raise ValueError("Missing 'org_secret_key' or 'org_did' in configuration")

    # Keep only the url/did pair from every node entry.
    node_list = [
        {"url": entry["url"], "did": entry["did"]} for entry in config["nodes"]
    ]
    org_credentials = {
        "secret_key": config["org_secret_key"],
        "org_did": config["org_did"],
    }

    # Construct the instance synchronously; optional settings fall back to
    # None when not supplied in the config.
    instance = cls(
        node_list,
        org_credentials,
        *args,
        schema_id=config.get("schema_id"),
        with_clustering=config.get("with_clustering"),
        clusters_schema_id=config.get("clusters_schema_id"),
        subtract_query_id=config.get("subtract_query_id"),
        **kwargs,
    )
    # Complete the async setup inherited from SecretVaultWrapper.
    await instance.init()
    return instance

@classmethod
async def bootstrap(
cls,
Expand Down Expand Up @@ -337,6 +382,14 @@ async def top_num_chunks_execute(
{"_id": id, "chunks": nilql.decrypt(xor_key, chunk_shares)}
for id, chunk_shares in chunk_shares_by_id.items()
]
# 5: Format top results
formatted_results, formatted_results_time_sec = benchmark_time(
lambda: "\n".join(
f"- {str(result['chunks'])}" for result in top_num_chunks
),
enable=enable_benchmark,
)
relevant_context = f"\n\nRelevant Context:\n{formatted_results}"

# Print benchmarks, if enabled
if enable_benchmark:
Expand All @@ -352,9 +405,10 @@ async def top_num_chunks_execute(
\n decrypt: {decrypt_time_sec:.2f} seconds\
\n get top chunks ids: {top_num_chunks_ids_time_sec:.2f} seconds\
\n query top chunks: {query_top_chunks_time_sec:.2f} seconds\
\n format top num chunks: {formatted_results_time_sec:.2f} seconds\
"""
)
return top_num_chunks
return relevant_context

def nilai_chat_completion(
self,
Expand Down
43 changes: 37 additions & 6 deletions test/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def test_rag_with_nilql(self):
self.check_top_results(top_results, case.expected_results)

@unittest.skipUnless(RUN_OPTIONAL_TESTS, "Skipping optional test.")
async def test_top_num_chunks_execute(self):
async def test_1_top_num_chunks_execute(self):
"""
Test the RAG method with nilDB.
"""
Expand Down Expand Up @@ -308,12 +308,43 @@ async def test_top_num_chunks_execute(self):
print(json.dumps(top_chunks, indent=4))
print(f"Query took {end_time - start_time:.2f} seconds")

# Format top results as nilAI
formatted_results = "\n".join(
f"- {str(result['chunks'])}" for result in top_chunks
)
print(f"Relevant Context:\n{top_chunks}")

@unittest.skipUnless(RUN_OPTIONAL_TESTS, "Skipping optional test.")
async def test_2_top_num_chunks_execute(self):
"""
Test the RAG method with nilDB.
RAGVault is created from a dictionary
"""

# Load environment variables
load_dotenv(override=True)

schema_id = os.getenv("SCHEMA_ID")
clusters_schema_id = os.getenv("CLUSTERS_SCHEMA_ID")
subtract_query_id = os.getenv("QUERY_ID")

# Build config dictionary for create_from_dict
config = {
"nodes": ORG_CONFIG["nodes"],
"org_secret_key": ORG_CONFIG["org_credentials"]["secret_key"],
"org_did": ORG_CONFIG["org_credentials"]["org_did"],
"schema_id": schema_id,
"clusters_schema_id": clusters_schema_id,
"subtract_query_id": subtract_query_id,
}
# Initialize vault with clustering enabled
rag = await RAGVault.create_from_dict(config)

print("Perform nilRAG...")
start_time = time.time()
query = DEFAULT_PROMPT
top_chunks = await rag.top_num_chunks_execute(query, 2)
end_time = time.time()
print(json.dumps(top_chunks, indent=4))
print(f"Query took {end_time - start_time:.2f} seconds")

print(f"Relevant Context:\n{formatted_results}")
print(f"Relevant Context:\n{top_chunks}")


if __name__ == "__main__":
Expand Down