diff --git a/benchmarks/test_rag.py b/benchmarks/test_rag.py
index aa67048..193ee74 100644
--- a/benchmarks/test_rag.py
+++ b/benchmarks/test_rag.py
@@ -21,14 +21,14 @@
 def test_rag_pedantic(benchmark):
     """
     Benchmark test for RAG performance using pytest-benchmark's pedantic mode.
-    
+
     This test:
     1. Initializes the RAG system with nilDB configuration
     2. Sets up test parameters (prompt, number of chunks, clusters)
     3. Performs a warm-up phase to stabilize measurements
     4. Runs the benchmark with multiple iterations and rounds
-    5. Verifies the result is a list
-    
+    5. Verifies the result is a string
+
     Args:
         benchmark: pytest-benchmark fixture for performance testing
     """
@@ -39,13 +39,15 @@ def test_rag_pedantic(benchmark):
     subtract_query_id = os.getenv("QUERY_ID")
 
     # Setup RAG instance
-    rag = asyncio.run(RAGVault.create(
-        ORG_CONFIG["nodes"],
-        ORG_CONFIG["org_credentials"],
-        schema_id=schema_id,
-        clusters_schema_id=clusters_schema_id,
-        subtract_query_id=subtract_query_id,
-    ))
+    rag = asyncio.run(
+        RAGVault.create(
+            ORG_CONFIG["nodes"],
+            ORG_CONFIG["org_credentials"],
+            schema_id=schema_id,
+            clusters_schema_id=clusters_schema_id,
+            subtract_query_id=subtract_query_id,
+        )
+    )
 
     prompt = "Who is Michelle Ross?"
     num_chunks = 2
@@ -54,23 +56,23 @@
     def sync_runner():
         """
         Synchronous wrapper for the async RAG execution.
-        
+
         This function:
         1. Wraps the async top_num_chunks_execute in a synchronous context
         2. Executes the RAG query with the configured parameters
         3. Returns the retrieved chunks
-        
+
         Returns:
-            list: Retrieved chunks from the RAG system
+            str: Relevant context retrieved from the RAG system
         """
-        return asyncio.run(rag.top_num_chunks_execute(
-            prompt, num_chunks, False, num_clusters
-        ))
+        return asyncio.run(
+            rag.top_num_chunks_execute(prompt, num_chunks, False, num_clusters)
+        )
 
     # Warm up
     for _ in range(10):
         sync_runner()
 
-    #Actual benchmark
+    # Actual benchmark
     result = benchmark.pedantic(sync_runner, iterations=10, rounds=5)
-    assert isinstance(result, list)
\ No newline at end of file
+    assert isinstance(result, str)
diff --git a/src/nilrag/rag_vault.py b/src/nilrag/rag_vault.py
index 63ed15c..67a9a50 100644
--- a/src/nilrag/rag_vault.py
+++ b/src/nilrag/rag_vault.py
@@ -93,6 +93,51 @@
         await self.init()
         return self
 
+    @classmethod
+    async def create_from_dict(
+        cls,
+        config: dict,
+        *args,
+        **kwargs,
+    ) -> "RAGVault":
+        """
+        Create a RAGVault instance from a dictionary
+        """
+        # Check required keys explicitly
+        if "nodes" not in config:
+            raise ValueError("Missing required 'nodes' field in configuration")
+        if "org_secret_key" not in config or "org_did" not in config:
+            raise ValueError("Missing 'org_secret_key' or 'org_did' in configuration")
+
+        nodes = []
+        for node_data in config["nodes"]:
+            nodes.append({"url": node_data["url"], "did": node_data["did"]})
+        credentials = {
+            "secret_key": config["org_secret_key"],
+            "org_did": config["org_did"],
+        }
+
+        # Extract optional fields
+        with_clustering = config.get("with_clustering", None)
+        schema_id = config.get("schema_id", None)
+        clusters_schema_id = config.get("clusters_schema_id", None)
+        subtract_query_id = config.get("subtract_query_id", None)
+
+        # Construct object synchronously
+        self = cls(
+            nodes,
+            credentials,
+            *args,
+            schema_id=schema_id,
+            with_clustering=with_clustering,
+            clusters_schema_id=clusters_schema_id,
+            subtract_query_id=subtract_query_id,
+            **kwargs,
+        )
+        # Perform async initialization from SecretVaultWrapper (await SecretVaultWrapper.init())
+        await self.init()
+        return self
+
     @classmethod
     async def bootstrap(
         cls,
@@ -337,6 +382,14 @@
             {"_id": id, "chunks": nilql.decrypt(xor_key, chunk_shares)}
             for id, chunk_shares in chunk_shares_by_id.items()
         ]
+        # 5: Format top results
+        formatted_results, formatted_results_time_sec = benchmark_time(
+            lambda: "\n".join(
+                f"- {str(result['chunks'])}" for result in top_num_chunks
+            ),
+            enable=enable_benchmark,
+        )
+        relevant_context = f"\n\nRelevant Context:\n{formatted_results}"
 
         # Print benchmarks, if enabled
         if enable_benchmark:
@@ -352,9 +405,10 @@
                 \n decrypt: {decrypt_time_sec:.2f} seconds\
                 \n get top chunks ids: {top_num_chunks_ids_time_sec:.2f} seconds\
                 \n query top chunks: {query_top_chunks_time_sec:.2f} seconds\
+                \n format top num chunks: {formatted_results_time_sec:.2f} seconds\
                 """
             )
-        return top_num_chunks
+        return relevant_context
 
     def nilai_chat_completion(
         self,
diff --git a/test/rag.py b/test/rag.py
index d8e0b7f..1df4aa8 100644
--- a/test/rag.py
+++ b/test/rag.py
@@ -279,7 +279,7 @@ def test_rag_with_nilql(self):
             self.check_top_results(top_results, case.expected_results)
 
     @unittest.skipUnless(RUN_OPTIONAL_TESTS, "Skipping optional test.")
-    async def test_top_num_chunks_execute(self):
+    async def test_1_top_num_chunks_execute(self):
         """
         Test the RAG method with nilDB.
         """
@@ -308,12 +308,43 @@ async def test_top_num_chunks_execute(self):
         print(json.dumps(top_chunks, indent=4))
         print(f"Query took {end_time - start_time:.2f} seconds")
 
-        # Format top results as nilAI
-        formatted_results = "\n".join(
-            f"- {str(result['chunks'])}" for result in top_chunks
-        )
+        print(f"Relevant Context:\n{top_chunks}")
+
+    @unittest.skipUnless(RUN_OPTIONAL_TESTS, "Skipping optional test.")
+    async def test_2_top_num_chunks_execute(self):
+        """
+        Test the RAG method with nilDB.
+        RAGVault is created from a dictionary
+        """
+
+        # Load environment variables
+        load_dotenv(override=True)
+
+        schema_id = os.getenv("SCHEMA_ID")
+        clusters_schema_id = os.getenv("CLUSTERS_SCHEMA_ID")
+        subtract_query_id = os.getenv("QUERY_ID")
+
+        # Build config dictionary for create_from_dict
+        config = {
+            "nodes": ORG_CONFIG["nodes"],
+            "org_secret_key": ORG_CONFIG["org_credentials"]["secret_key"],
+            "org_did": ORG_CONFIG["org_credentials"]["org_did"],
+            "schema_id": schema_id,
+            "clusters_schema_id": clusters_schema_id,
+            "subtract_query_id": subtract_query_id,
+        }
+        # Initialize vault with clustering enabled
+        rag = await RAGVault.create_from_dict(config)
+
+        print("Perform nilRAG...")
+        start_time = time.time()
+        query = DEFAULT_PROMPT
+        top_chunks = await rag.top_num_chunks_execute(query, 2)
+        end_time = time.time()
+        print(json.dumps(top_chunks, indent=4))
+        print(f"Query took {end_time - start_time:.2f} seconds")
 
-        print(f"Relevant Context:\n{formatted_results}")
+        print(f"Relevant Context:\n{top_chunks}")
 
 
 if __name__ == "__main__":