diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7df414b..53f0415 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -50,6 +50,4 @@ jobs: run: poetry install --with dev - name: Run unit tests (mocked, no API keys required) - run: | - poetry run pytest tests/test_cache.py tests/test_unit.py tests/test_client_mock.py tests/test_persistence.py \ - -v --tb=short --ignore=tests/test_client.py --ignore=tests/test_revision.py --ignore=tests/test_steps.py \ No newline at end of file + run: poetry run pytest tests/ -v --tb=short -m "not requires_api" \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index e01d2fe..6928ed1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1418,6 +1418,21 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0" typing-extensions = ">=4.7.0,<5.0.0" uuid-utils = ">=0.12.0,<1.0" +[[package]] +name = "langchain-text-splitters" +version = "0.3.11" +description = "LangChain text splitting utilities" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "langchain_text_splitters-0.3.11-py3-none-any.whl", hash = "sha256:cf079131166a487f1372c8ab5d0bfaa6c0a4291733d9c43a34a16ac9bcd6a393"}, + {file = "langchain_text_splitters-0.3.11.tar.gz", hash = "sha256:7a50a04ada9a133bbabb80731df7f6ddac51bc9f1b9cab7fa09304d71d38a6cc"}, +] + +[package.dependencies] +langchain-core = ">=0.3.75,<2.0.0" + [[package]] name = "langgraph" version = "0.2.76" @@ -5736,4 +5751,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "78a1913ea9d75e2d7db2fae245c0091bf5cd9eda9ef215a5c08ba75ea69b7f9e" +content-hash = "d9aa0279a7cf8c5acb34c5392f2a0bd6669ba7f07d7375591419a4fb01cb8bfd" diff --git a/pyproject.toml b/pyproject.toml index 90fe2f1..426ea9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ pydantic = "^2.7" pydantic-settings = "^2.3" langgraph = "^0.2.0" langchain-core = "^0.3.0" +langchain-text-splitters = "^0.3.0" streamlit = "^1.36" chromadb = "^0.5.0" sentence-transformers = "^3.0.0" @@ -47,6 +48,9 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] addopts = "-q --disable-warnings --maxfail=1" testpaths = ["tests"] +markers = [ + "requires_api: tests that require API keys (OpenAI, Anthropic)" +] [tool.ruff] line-length = 100 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 7855e0d..1eafc7a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ pydantic>=2.7 pydantic-settings>=2.3 langgraph>=0.2.0 langchain-core>=0.3.0 +langchain-text-splitters>=0.3.0 streamlit>=1.36 chromadb>=0.5.0 sentence-transformers>=3.0.0 diff --git a/tests/test_client.py b/tests/test_client.py index 9859802..e67edc9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,6 +7,8 @@ from utils.retriever_utils import chroma_retriever from utils.vector_types import chroma_params +pytestmark = pytest.mark.requires_api # All tests in this file need API keys + @pytest.fixture def client(): diff --git a/tests/test_retriever.py b/tests/test_retriever.py new file mode 100644 index 0000000..262daef --- /dev/null +++ b/tests/test_retriever.py @@ -0,0 +1,373 @@ +""" +Tests for RAG retriever utilities. +All tests are mocked to run without API keys or external services. +""" + +import pytest +from unittest.mock import Mock, patch +from utils.vector_types import chroma_params +from tf_types import RetrievedChunk + + +class TestChromaParams: + """Tests for chroma_params dataclass.""" + + def test_chroma_params_defaults(self): + """chroma_params should have sensible defaults.""" + params = chroma_params() + + assert params.documents is None + assert params.collection == "traceflow-kb" + assert params.directory == "./chroma_db" + + def test_chroma_params_custom_values(self): + """chroma_params should accept custom values.""" + params = chroma_params( + documents=["doc1", "doc2"], collection="my-collection", directory="/custom/path" + ) + + assert params.documents == ["doc1", "doc2"] + assert params.collection == "my-collection" + assert params.directory == "/custom/path" + + def test_chroma_params_empty_documents(self): + """chroma_params should accept empty document list.""" + params = chroma_params(documents=[]) + assert params.documents == [] + + +class TestChromaRetriever: + """Tests for chroma_retriever class.""" + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_init_local_mode(self, mock_openai, mock_chroma): + """Retriever should initialize in local mode with params.""" + from utils.retriever_utils import chroma_retriever + + mock_collection = Mock() + mock_chroma.return_value.get_or_create_collection.return_value = mock_collection + + params = chroma_params(collection="test-coll", directory="./test_db") + retriever = chroma_retriever(local=True, params=params) + + mock_chroma.assert_called_once_with(path="./test_db") + mock_chroma.return_value.get_or_create_collection.assert_called_once_with(name="test-coll") + assert retriever.collection == mock_collection + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_init_local_mode_requires_params(self, mock_openai, mock_chroma): + """Retriever should raise error if local mode without params.""" + from utils.retriever_utils import chroma_retriever + + with pytest.raises(ValueError, match="Chroma parameters must be provided"): + chroma_retriever(local=True, params=None) + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_create_document_batches(self, mock_openai, mock_chroma): + """create_document_batches should yield correct batch sizes.""" + from utils.retriever_utils import chroma_retriever + + mock_chroma.return_value.get_or_create_collection.return_value = Mock() + + params = chroma_params(collection="test", directory="./test") + retriever = chroma_retriever(local=True, params=params) + + documents = ["doc1", "doc2", "doc3", "doc4", "doc5"] + batches = list(retriever.create_document_batches(documents, batch_size=2)) + + assert len(batches) == 3 + assert batches[0] == ["doc1", "doc2"] + assert batches[1] == ["doc3", "doc4"] + assert batches[2] == ["doc5"] + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_create_document_batches_exact_multiple(self, mock_openai, mock_chroma): + """create_document_batches should handle exact multiples.""" + from utils.retriever_utils import chroma_retriever + + mock_chroma.return_value.get_or_create_collection.return_value = Mock() + + params = chroma_params(collection="test", directory="./test") + retriever = chroma_retriever(local=True, params=params) + + documents = ["doc1", "doc2", "doc3", "doc4"] + batches = list(retriever.create_document_batches(documents, batch_size=2)) + + assert len(batches) == 2 + assert batches[0] == ["doc1", "doc2"] + assert batches[1] == ["doc3", "doc4"] + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_create_document_batches_empty(self, mock_openai, mock_chroma): + """create_document_batches should handle empty list.""" + from utils.retriever_utils import chroma_retriever + + mock_chroma.return_value.get_or_create_collection.return_value = Mock() + + params = chroma_params(collection="test", directory="./test") + retriever = chroma_retriever(local=True, params=params) + + batches = list(retriever.create_document_batches([], batch_size=2)) + assert batches == [] + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_create_vector_store(self, mock_openai, mock_chroma): + """create_vector_store should embed and store documents.""" + from utils.retriever_utils import chroma_retriever + + # Mock collection + mock_collection = Mock() + mock_chroma.return_value.get_or_create_collection.return_value = mock_collection + + # Mock OpenAI embeddings + mock_embedding = Mock() + mock_embedding.embedding = [0.1, 0.2, 0.3] + mock_openai.return_value.embeddings.create.return_value = Mock( + data=[mock_embedding, mock_embedding] + ) + + params = chroma_params(collection="test", directory="./test") + retriever = chroma_retriever(local=True, params=params) + + documents = ["doc1", "doc2"] + retriever.create_vector_store(documents, batch_size=2) + + # Verify embeddings were created + mock_openai.return_value.embeddings.create.assert_called_once_with( + model="text-embedding-3-small", input=["doc1", "doc2"] + ) + + # Verify documents were added to collection + mock_collection.add.assert_called_once() + call_args = mock_collection.add.call_args + assert call_args.kwargs["ids"] == ["0", "1"] + assert call_args.kwargs["documents"] == ["doc1", "doc2"] + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_create_vector_store_multiple_batches(self, mock_openai, mock_chroma): + """create_vector_store should handle multiple batches with correct IDs.""" + from utils.retriever_utils import chroma_retriever + + mock_collection = Mock() + mock_chroma.return_value.get_or_create_collection.return_value = mock_collection + + # Mock embeddings for each batch + mock_embedding = Mock() + mock_embedding.embedding = [0.1, 0.2, 0.3] + mock_openai.return_value.embeddings.create.return_value = Mock( + data=[mock_embedding, mock_embedding] + ) + + params = chroma_params(collection="test", directory="./test") + retriever = chroma_retriever(local=True, params=params) + + documents = ["doc1", "doc2", "doc3", "doc4", "doc5"] + retriever.create_vector_store(documents, batch_size=2) + + # Should have 3 batches + assert mock_collection.add.call_count == 3 + + # Verify IDs are sequential across batches + calls = mock_collection.add.call_args_list + assert calls[0].kwargs["ids"] == ["0", "1"] + assert calls[1].kwargs["ids"] == ["2", "3"] + assert calls[2].kwargs["ids"] == ["4"] + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_create_vector_store_progress_callback(self, mock_openai, mock_chroma): + """create_vector_store should call progress callback.""" + from utils.retriever_utils import chroma_retriever + + mock_collection = Mock() + mock_chroma.return_value.get_or_create_collection.return_value = mock_collection + + mock_embedding = Mock() + mock_embedding.embedding = [0.1, 0.2, 0.3] + mock_openai.return_value.embeddings.create.return_value = Mock( + data=[mock_embedding, mock_embedding] + ) + + params = chroma_params(collection="test", directory="./test") + retriever = chroma_retriever(local=True, params=params) + + # Track progress calls + progress_calls = [] + + def track_progress(current, total): + progress_calls.append((current, total)) + + documents = ["doc1", "doc2", "doc3", "doc4", "doc5"] + retriever.create_vector_store(documents, batch_size=2, progress_callback=track_progress) + + # Should have 3 progress updates + assert progress_calls == [(1, 3), (2, 3), (3, 3)] + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_retrieve_similar_docs(self, mock_openai, mock_chroma): + """retrieve_similar_docs should query and return chunks.""" + from utils.retriever_utils import chroma_retriever + + mock_collection = Mock() + mock_collection.query.return_value = { + "documents": [["Document 1 content", "Document 2 content"]], + "metadatas": [[{"source": "doc_0"}, {"source": "doc_1"}]], + "distances": [[0.1, 0.3]], + } + mock_chroma.return_value.get_or_create_collection.return_value = mock_collection + + # Mock query embedding + mock_embedding = Mock() + mock_embedding.embedding = [0.1, 0.2, 0.3] + mock_openai.return_value.embeddings.create.return_value = Mock(data=[mock_embedding]) + + params = chroma_params(collection="test", directory="./test") + retriever = chroma_retriever(local=True, params=params) + + results = retriever.retrieve_similar_docs("test query", n_results=2) + + # Verify query was made + mock_collection.query.assert_called_once() + + # Verify results + assert len(results) == 2 + assert isinstance(results[0], RetrievedChunk) + assert results[0].content == "Document 1 content" + assert results[0].source == "doc_0" + assert results[0].relevance_score == pytest.approx(0.95, rel=0.01) # 1 - 0.1/2 + assert results[1].relevance_score == pytest.approx(0.85, rel=0.01) # 1 - 0.3/2 + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_retrieve_similar_docs_empty_results(self, mock_openai, mock_chroma): + """retrieve_similar_docs should handle empty results.""" + from utils.retriever_utils import chroma_retriever + + mock_collection = Mock() + mock_collection.query.return_value = { + "documents": [[]], + "metadatas": [[]], + "distances": [[]], + } + mock_chroma.return_value.get_or_create_collection.return_value = mock_collection + + mock_embedding = Mock() + mock_embedding.embedding = [0.1, 0.2, 0.3] + mock_openai.return_value.embeddings.create.return_value = Mock(data=[mock_embedding]) + + params = chroma_params(collection="test", directory="./test") + retriever = chroma_retriever(local=True, params=params) + + results = retriever.retrieve_similar_docs("test query") + + assert results == [] + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_chunk_documents(self, mock_openai, mock_chroma): + """chunk_documents should split documents using text splitter.""" + from utils.retriever_utils import chroma_retriever + from langchain_core.documents import Document + + mock_chroma.return_value.get_or_create_collection.return_value = Mock() + + params = chroma_params(collection="test", directory="./test") + retriever = chroma_retriever(local=True, params=params) + + # Create a document that will be split + long_text = "This is a test. " * 200 # ~3200 chars + documents = [Document(page_content=long_text)] + + chunks = retriever.chunk_documents(documents) + + # Should be split into multiple chunks + assert len(chunks) > 1 + + +class TestRetrievedChunk: + """Tests for RetrievedChunk dataclass.""" + + def test_retrieved_chunk_creation(self): + """RetrievedChunk should be created with all fields.""" + chunk = RetrievedChunk( + chunk_id="chunk_1", content="Test content", source="test_doc.txt", relevance_score=0.95 + ) + + assert chunk.chunk_id == "chunk_1" + assert chunk.content == "Test content" + assert chunk.source == "test_doc.txt" + assert chunk.relevance_score == 0.95 + + +class TestVectorStoreIntegration: + """Integration-style tests for vector store operations.""" + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_full_index_and_retrieve_flow(self, mock_openai, mock_chroma): + """Test full flow: create vector store then retrieve.""" + from utils.retriever_utils import chroma_retriever + + mock_collection = Mock() + mock_collection.query.return_value = { + "documents": [["AI is artificial intelligence"]], + "metadatas": [[{"source": "doc_0"}]], + "distances": [[0.2]], + } + mock_chroma.return_value.get_or_create_collection.return_value = mock_collection + + # Mock embeddings + mock_embedding = Mock() + mock_embedding.embedding = [0.1] * 1536 + mock_openai.return_value.embeddings.create.return_value = Mock(data=[mock_embedding]) + + params = chroma_params(collection="test", directory="./test") + retriever = chroma_retriever(local=True, params=params) + + # Index documents + retriever.create_vector_store(["AI is artificial intelligence"]) + + # Retrieve + results = retriever.retrieve_similar_docs("What is AI?") + + assert len(results) == 1 + assert "artificial intelligence" in results[0].content + + @patch("utils.retriever_utils.chromadb.PersistentClient") + @patch("utils.retriever_utils.OpenAI") + def test_batch_size_affects_api_calls(self, mock_openai, mock_chroma): + """Smaller batch size should result in more API calls.""" + from utils.retriever_utils import chroma_retriever + + mock_collection = Mock() + mock_chroma.return_value.get_or_create_collection.return_value = mock_collection + + mock_embedding = Mock() + mock_embedding.embedding = [0.1, 0.2, 0.3] + mock_openai.return_value.embeddings.create.return_value = Mock(data=[mock_embedding]) + + params = chroma_params(collection="test", directory="./test") + retriever = chroma_retriever(local=True, params=params) + + documents = ["doc1", "doc2", "doc3", "doc4"] + + # With batch_size=1, should make 4 API calls + retriever.create_vector_store(documents, batch_size=1) + assert mock_openai.return_value.embeddings.create.call_count == 4 + + # Reset + mock_openai.return_value.embeddings.create.reset_mock() + mock_collection.reset_mock() + + # With batch_size=4, should make 1 API call + mock_openai.return_value.embeddings.create.return_value = Mock(data=[mock_embedding] * 4) + retriever.create_vector_store(documents, batch_size=4) + assert mock_openai.return_value.embeddings.create.call_count == 1 diff --git a/tests/test_revision.py b/tests/test_revision.py index c38d5ef..20ee868 100644 --- a/tests/test_revision.py +++ b/tests/test_revision.py @@ -4,6 +4,8 @@ from client import TraceFlowClient from tf_types import RunConfig, Mode, Strictness, RetrievedChunk +pytestmark = pytest.mark.requires_api # All tests in this file need API keys + @pytest.fixture def client(): diff --git a/tests/test_steps.py b/tests/test_steps.py index fa57c0f..930f932 100644 --- a/tests/test_steps.py +++ b/tests/test_steps.py @@ -2,6 +2,8 @@ from client import TraceFlowClient from tf_types import RunConfig, Mode +pytestmark = pytest.mark.requires_api # All tests in this file need API keys + @pytest.fixture def client(): diff --git a/ui/app.py b/ui/app.py index 3446ba2..43d9ab8 100644 --- a/ui/app.py +++ b/ui/app.py @@ -13,11 +13,19 @@ from client import TraceFlowClient from tf_types import Mode, RunConfig, Strictness from datetime import datetime +import os + +# Optional RAG imports +try: + from utils.retriever_utils import chroma_retriever + from utils.vector_types import chroma_params + + RAG_AVAILABLE = True +except ImportError: + RAG_AVAILABLE = False # Page config -st.set_page_config( - page_title="TraceFlow Lite", page_icon="🔍", layout="wide", initial_sidebar_state="expanded" -) +st.set_page_config(page_title="TraceFlow Lite", layout="wide", initial_sidebar_state="expanded") # Custom CSS for modern look st.markdown( @@ -181,6 +189,15 @@ def get_client(): return TraceFlowClient() +@st.cache_resource +def get_retriever(collection_name: str, db_path: str): + """Get or create a cached retriever instance.""" + if not RAG_AVAILABLE: + return None + params = chroma_params(collection=collection_name, directory=db_path) + return chroma_retriever(local=True, params=params) + + def format_timestamp(dt: datetime) -> str: """Format datetime for display.""" if dt is None: @@ -203,7 +220,7 @@ def get_mode_badge(mode: str) -> str: def render_sidebar(): """Render sidebar with navigation and new run form.""" with st.sidebar: - st.markdown('

🔍 TraceFlow

', unsafe_allow_html=True) + st.markdown('

TraceFlow

', unsafe_allow_html=True) st.markdown( '

Agent Observability Platform

', unsafe_allow_html=True ) @@ -213,7 +230,7 @@ def render_sidebar(): # Navigation page = st.radio( "Navigation", - ["🚀 New Run", "📋 Trace History", "📊 Analytics"], + ["New Run", "Trace History", "Analytics"], label_visibility="collapsed", ) @@ -235,7 +252,7 @@ def render_sidebar(): def render_new_run_page(): """Render the new run page.""" - st.markdown("## 🚀 New Run") + st.markdown("## New Run") st.markdown("Execute a new agent workflow with custom configuration.") col1, col2 = st.columns([2, 1]) @@ -252,16 +269,16 @@ def render_new_run_page(): "Mode", [Mode.GROUNDED_QA, Mode.TRIAGE_PLAN, Mode.CHANGE_SAFETY], format_func=lambda x: { - Mode.GROUNDED_QA: "🎯 Grounded QA", - Mode.TRIAGE_PLAN: "📝 Triage Plan", - Mode.CHANGE_SAFETY: "🛡️ Change Safety", + Mode.GROUNDED_QA: "Grounded QA", + Mode.TRIAGE_PLAN: "Triage Plan", + Mode.CHANGE_SAFETY: "Change Safety", }.get(x, x.value), ) provider = st.selectbox( "Provider", ["openai", "anthropic"], - format_func=lambda x: "🤖 OpenAI" if x == "openai" else "🧠 Anthropic", + format_func=lambda x: "OpenAI" if x == "openai" else "Anthropic", ) model_options = { @@ -293,11 +310,126 @@ def render_new_run_page(): help="Cache responses to save cost on repeated queries", ) - if st.button("▶️ Execute Run", use_container_width=True, type="primary"): + # RAG Configuration Section + st.divider() + st.markdown("### RAG Configuration (Optional)") + + enable_rag = st.checkbox( + "Enable RAG", + value=False, + help="Use Retrieval-Augmented Generation with your own documents", + ) + + retriever_fn = None + top_k = 5 # Default value + if enable_rag: + if not RAG_AVAILABLE: + st.error("RAG dependencies not available. Install chromadb and sentence-transformers.") + elif not os.getenv("OPENAI_API_KEY"): + st.warning("OPENAI_API_KEY not set. Required for embeddings.") + else: + # Document input + doc_input_method = st.radio( + "Document Input Method", + ["Paste Text", "Upload Files"], + horizontal=True, + ) + + documents = [] + if doc_input_method == "Paste Text": + doc_text = st.text_area( + "Paste your documents (one per line or separated by blank lines)", + height=150, + placeholder="Document 1 content here...\n\nDocument 2 content here...", + ) + if doc_text.strip(): + # Split by double newlines or treat as single doc + documents = [d.strip() for d in doc_text.split("\n\n") if d.strip()] + else: + uploaded_files = st.file_uploader( + "Upload text files", + type=["txt", "md"], + accept_multiple_files=True, + ) + for file in uploaded_files: + content = file.read().decode("utf-8") + documents.append(content) + + # Vector store settings + col1, col2 = st.columns(2) + with col1: + collection_name = st.text_input( + "Collection Name", + value="traceflow_docs", + help="Name for the vector store collection", + ) + with col2: + top_k = st.slider("Top K Results", 1, 10, 5, help="Number of chunks to retrieve") + + # Initialize/update vector store + if documents: + st.info(f"{len(documents)} document(s) ready") + + if st.button("Create/Update Vector Store"): + try: + db_path = "./chroma_db" + retriever = get_retriever(collection_name, db_path) + + # Create progress bar + progress_bar = st.progress(0, text="Creating embeddings...") + + def update_progress(current, total): + progress = current / total + progress_bar.progress( + progress, text=f"Processing batch {current}/{total}..." + ) + + retriever.create_vector_store(documents, progress_callback=update_progress) + + progress_bar.progress(1.0, text="Complete!") + st.session_state["rag_ready"] = True + st.session_state["rag_collection"] = collection_name + st.success(f"Vector store created with {len(documents)} documents!") + except Exception as e: + st.error(f"Failed to create vector store: {e}") + + # Check if RAG is ready + if ( + st.session_state.get("rag_ready") + and st.session_state.get("rag_collection") == collection_name + ): + st.success("RAG is ready! Retriever will be used in queries.") + retriever = get_retriever(collection_name, "./chroma_db") + retriever_fn = retriever.retrieve_similar_docs + elif st.session_state.get("rag_collection"): + # Collection exists from previous session + try: + retriever = get_retriever(collection_name, "./chroma_db") + # Test if collection has data + if retriever.collection.count() > 0: + st.success( + f"Using existing collection '{collection_name}' " + f"({retriever.collection.count()} chunks)" + ) + retriever_fn = retriever.retrieve_similar_docs + st.session_state["rag_ready"] = True + st.session_state["rag_collection"] = collection_name + except Exception: + pass + + if st.button("Execute Run", use_container_width=True, type="primary"): if not user_input.strip(): st.error("Please enter a user input.") return + # Validate RAG setup if enabled + if enable_rag and retriever_fn is None: + st.error( + "RAG is enabled but vector store is not ready. " + "Please create the vector store first." + ) + return + with st.spinner("Running workflow..."): client = get_client() config = RunConfig( @@ -310,6 +442,8 @@ def render_new_run_page(): max_latency_ms=max_latency, max_revisions=max_revisions, enable_cache=enable_cache, + retriever_fn=retriever_fn, + top_k=top_k if enable_rag else 5, ) result = client.run(user_input, config) @@ -318,9 +452,9 @@ def render_new_run_page(): st.divider() if result.status.value == "done": - st.success("✅ Run completed successfully!") + st.success("Run completed successfully!") else: - st.error(f"❌ Run failed: {result.err}") + st.error(f"Run failed: {result.err}") # Result card st.markdown("### Result") @@ -389,7 +523,7 @@ def render_new_run_page(): def render_trace_history_page(): """Render the trace history page.""" - st.markdown("## 📋 Trace History") + st.markdown("## Trace History") st.markdown("View and inspect previous agent runs.") client = get_client() @@ -454,12 +588,12 @@ def render_trace_history_page(): ) with col2: - if st.button("🔍 Details", key=f"view_{trace.trace_id}"): + if st.button("Details", key=f"view_{trace.trace_id}"): st.session_state["selected_trace_id"] = trace.trace_id st.rerun() with col3: - if st.button("🔄 Replay", key=f"replay_{trace.trace_id}"): + if st.button("Replay", key=f"replay_{trace.trace_id}"): with st.spinner("Replaying..."): result = client.replay(trace.trace_id) st.session_state["last_trace_id"] = result.trace_id @@ -469,7 +603,7 @@ def render_trace_history_page(): def render_trace_detail(trace_id: str): """Render detailed view of a trace.""" st.divider() - st.markdown("## 🔎 Trace Detail") + st.markdown("## Trace Detail") client = get_client() trace = client.get_trace(trace_id) @@ -492,7 +626,7 @@ def render_trace_detail(trace_id: str): unsafe_allow_html=True, ) with col2: - if st.button("❌ Close"): + if st.button("Close"): del st.session_state["selected_trace_id"] st.rerun() @@ -540,21 +674,21 @@ def render_trace_detail(trace_id: str): ) # Input/Output - st.markdown("### 📥 Input") + st.markdown("### Input") st.code(trace.user_input, language=None) - st.markdown("### 📤 Output") + st.markdown("### Output") if trace.final_answer: st.markdown(trace.final_answer) else: st.warning("No output generated.") if trace.error: - st.markdown("### ❌ Error") + st.markdown("### Error") st.error(trace.error) # Steps timeline - st.markdown("### 📊 Execution Steps") + st.markdown("### Execution Steps") steps = client.dbStore.get_steps(trace_id) if steps: @@ -575,7 +709,7 @@ def render_trace_detail(trace_id: str): # Steps list for step in steps: - cache_badge = " ⚡ Cached" if step.cache_hit else "" + cache_badge = " [Cached]" if step.cache_hit else "" with st.expander( f"**{step.step_seq + 1}. {step.node_name.upper()}**{cache_badge} — {step.latency_ms:.0f}ms" @@ -604,7 +738,7 @@ def render_trace_detail(trace_id: str): def render_analytics_page(): """Render analytics dashboard.""" - st.markdown("## 📊 Analytics") + st.markdown("## Analytics") st.markdown("Insights into your agent workflows.") client = get_client() @@ -740,11 +874,11 @@ def main(): """Main application entry point.""" page = render_sidebar() - if page == "🚀 New Run": + if page == "New Run": render_new_run_page() - elif page == "📋 Trace History": + elif page == "Trace History": render_trace_history_page() - elif page == "📊 Analytics": + elif page == "Analytics": render_analytics_page() diff --git a/utils/retriever_utils.py b/utils/retriever_utils.py index a89e781..8f81bb4 100644 --- a/utils/retriever_utils.py +++ b/utils/retriever_utils.py @@ -42,21 +42,38 @@ def create_document_batches(self, splits: list[str], batch_size: int): for i in range(0, len(splits), batch_size): yield splits[i : i + batch_size] - def create_vector_store(self, splits: list[str], batch_size: int = 1000): - for batch in tqdm( - self.create_document_batches(splits, batch_size), desc="embedding batches", unit="batch" - ): + def create_vector_store( + self, + splits: list[str], + batch_size: int = 1000, + progress_callback: callable = None, + ): + """Create vector store from document splits. + + Args: + splits: List of document chunks to embed + batch_size: Number of documents per batch + progress_callback: Optional callback(current, total) for progress updates + """ + batches = list(self.create_document_batches(splits, batch_size)) + total_batches = len(batches) + + for i, batch in enumerate(tqdm(batches, desc="embedding batches", unit="batch")): embeddings = self.client.embeddings.create(model="text-embedding-3-small", input=batch) embedding_vectors = [e.embedding for e in embeddings.data] self.collection.add( - ids=[str(i) for i in range(len(batch))], + ids=[str(i * batch_size + j) for j in range(len(batch))], embeddings=embedding_vectors, documents=batch, - metadatas=[{"source": f"doc_{i}"} for i in range(len(batch))], + metadatas=[{"source": f"doc_{i * batch_size + j}"} for j in range(len(batch))], ) + # Call progress callback if provided + if progress_callback: + progress_callback(i + 1, total_batches) + def retrieve_similar_docs(self, query: str, n_results: int = 5) -> list[RetrievedChunk]: query_response = self.client.embeddings.create( model="text-embedding-3-small", input=[query] diff --git a/utils/vector_types.py b/utils/vector_types.py index 85047b6..76988cb 100644 --- a/utils/vector_types.py +++ b/utils/vector_types.py @@ -2,6 +2,6 @@ class chroma_params(BaseModel): - documents: list[str] + documents: list[str] | None = None collection: str = "traceflow-kb" directory: str = "./chroma_db"