Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,33 @@ jobs:
path: coverage.xml

# ---------------------------------------------------------------
# Job 2 — Build Docker images (only when Docker files change)
# Job 0 — Detect whether Docker-relevant files changed
# ---------------------------------------------------------------
# Path-detection job: publishes a `docker` output that the docker-build job
# reads as `needs.changes.outputs.docker`.
changes:
runs-on: ubuntu-latest
outputs:
# Forwarded from the paths-filter step below (id: filter).
docker: ${{ steps.filter.outputs.docker }}
steps:
- uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: filter
with:
# A single filter named `docker`; its patterns are the paths treated as
# Docker-relevant (the workflow file itself is included in the list).
filters: |
docker:
- 'docker/**'
- 'pyproject.toml'
- 'uv.lock'
- 'requirements.txt'
- '.github/workflows/ci.yml'

# ---------------------------------------------------------------
# Job 2 — Build Docker images (only when Docker-relevant files change)
# ---------------------------------------------------------------
docker-build:
runs-on: ubuntu-latest
needs: lint-and-test
needs: [lint-and-test, changes]
if: >
github.event_name == 'pull_request' ||
needs.changes.outputs.docker == 'true' ||
contains(github.event.head_commit.message, '[docker]')

steps:
Expand Down
Binary file added assets/Langfuse.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
"plotly>=5.24",
"langgraph>=1.1.3",
"langchain-google-genai>=4.2.1",
"langchain-openai>=1.1.0",
"langchain-core>=1.2.22",
"langfuse>=4.3.1",
"ragas>=0.4.3",
Expand Down
6 changes: 6 additions & 0 deletions src/pharmagraphrag/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ def query(req: QueryRequest) -> QueryResponse:
has_graph_context=result.context.has_graph,
has_vector_context=result.context.has_vector,
sources=sources,
graph_context=(result.context.graph_context or "")
if req.include_full_context
else None,
vector_context=(result.context.vector_context or "")
if req.include_full_context
else None,
llm_model=llm_model,
llm_provider=llm_provider,
error=error,
Expand Down
15 changes: 15 additions & 0 deletions src/pharmagraphrag/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ class QueryRequest(BaseModel):
default=None,
description="LLM model to use (e.g. 'gemini-2.5-flash'). Uses server default if not set.",
)
include_full_context: bool = Field(
False,
description="Include full graph_context and vector_context in the response. "
"Disabled by default to keep payloads small; enable for evaluation or debugging.",
)


# ---------------------------------------------------------------------------
Expand All @@ -58,6 +63,16 @@ class QueryResponse(BaseModel):
has_graph_context: bool = False
has_vector_context: bool = False
sources: list[SourceInfo] = Field(default_factory=list)
graph_context: str | None = Field(
None,
description="Full graph context passed to the LLM. Only populated when "
"include_full_context=True in the request (for evaluation/debugging).",
)
vector_context: str | None = Field(
None,
description="Full vector context passed to the LLM. Only populated when "
"include_full_context=True in the request (for evaluation/debugging).",
)
llm_model: str = ""
llm_provider: str = ""
error: str | None = None
Expand Down
84 changes: 54 additions & 30 deletions src/pharmagraphrag/evaluation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,41 +41,47 @@ def scores(self) -> dict[str, float]:
def _get_evaluator_llm(model: str = "gemini-2.5-flash"):
    """Create a RAGAS-compatible LLM wrapper using Gemini via OpenAI compatibility.

    Uses LangChain's ChatOpenAI pointed at the Gemini OpenAI-compatible
    endpoint, wrapped in LangchainLLMWrapper. Bumps max_tokens to avoid
    truncation on multi-statement classification prompts (ContextRecall).

    Args:
        model: Model name sent to the Gemini OpenAI-compatible endpoint.

    Returns:
        A ``LangchainLLMWrapper`` around the configured ``ChatOpenAI`` client.

    Raises:
        ValueError: If the ``GEMINI_API_KEY`` environment variable is unset
            or empty.
    """
    import os

    # Validate configuration before importing the optional evaluation
    # dependencies, so a missing key is reported as a missing key.
    api_key = os.environ.get("GEMINI_API_KEY", "")
    if not api_key:
        raise ValueError("GEMINI_API_KEY env var is required for RAGAS evaluation")

    from langchain_openai import ChatOpenAI
    from ragas.llms import LangchainLLMWrapper

    chat = ChatOpenAI(
        model=model,
        api_key=api_key,
        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        max_tokens=8192,
        temperature=0.0,  # deterministic judgments from the evaluator model
    )
    return LangchainLLMWrapper(chat)
Comment on lines +50 to +64
Copy link

Copilot AI Apr 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_get_evaluator_llm() now imports langchain_openai.ChatOpenAI, but langchain-openai is not listed in pyproject.toml dependencies. A fresh install will fail at runtime when running evaluation. Add langchain-openai (and ensure compatible openai dependency) to the project dependencies or switch to an evaluator LLM wrapper that’s already included in the dependency set.

Copilot uses AI. Check for mistakes.


def _get_evaluator_embeddings(model: str = "models/gemini-embedding-001"):
    """Create RAGAS-compatible embeddings using Gemini native API.

    The Gemini OpenAI-compatibility endpoint does not support embeddings
    (returns 501 UNIMPLEMENTED), so we use the native Google Generative AI
    embeddings via langchain-google-genai.

    Args:
        model: Embedding model name passed to ``GoogleGenerativeAIEmbeddings``.

    Returns:
        A ``LangchainEmbeddingsWrapper`` around the Gemini embeddings client.

    Raises:
        ValueError: If the ``GEMINI_API_KEY`` environment variable is unset
            or empty.
    """
    import os

    # Validate configuration before importing the optional evaluation
    # dependencies, so a missing key is reported as a missing key.
    api_key = os.environ.get("GEMINI_API_KEY", "")
    if not api_key:
        raise ValueError("GEMINI_API_KEY env var is required for RAGAS evaluation")

    from langchain_google_genai import GoogleGenerativeAIEmbeddings
    from ragas.embeddings import LangchainEmbeddingsWrapper

    embeddings = GoogleGenerativeAIEmbeddings(model=model, google_api_key=api_key)
    return LangchainEmbeddingsWrapper(embeddings)


def get_reference_free_metrics(
Expand Down Expand Up @@ -105,15 +111,23 @@ def get_reference_metrics(

Returns context precision, context recall, and answer correctness.
"""
from ragas.metrics import AnswerCorrectness, ContextPrecision, ContextRecall
from ragas.metrics import (
AnswerCorrectness,
AnswerSimilarity,
ContextPrecision,
ContextRecall,
)

llm = llm or _get_evaluator_llm()
embeddings = embeddings or _get_evaluator_embeddings()

# AnswerCorrectness needs an explicit AnswerSimilarity (embeddings-based)
answer_similarity = AnswerSimilarity(embeddings=embeddings)

return [
ContextPrecision(llm=llm),
ContextRecall(llm=llm),
AnswerCorrectness(llm=llm, embeddings=embeddings),
AnswerCorrectness(llm=llm, embeddings=embeddings, answer_similarity=answer_similarity),
]


Expand Down Expand Up @@ -141,23 +155,33 @@ def score_sample(
llm = llm or _get_evaluator_llm()
embeddings = embeddings or _get_evaluator_embeddings()

import asyncio

from ragas.dataset_schema import SingleTurnSample

metrics = get_reference_free_metrics(llm, embeddings)
if reference:
metrics.extend(get_reference_metrics(llm, embeddings))

results = []
for metric in metrics:
try:
score = metric.single_score(
user_input=question,
response=answer,
retrieved_contexts=contexts,
reference=reference,
)
results.append(MetricResult(name=type(metric).__name__, score=score))
except Exception as exc:
logger.warning("Metric {} failed: {}", type(metric).__name__, exc)
results.append(MetricResult(name=type(metric).__name__, score=-1.0))
sample = SingleTurnSample(
user_input=question,
response=answer,
retrieved_contexts=contexts,
reference=reference,
)

async def _score_all() -> list[MetricResult]:
out: list[MetricResult] = []
for metric in metrics:
try:
score = await metric.single_turn_ascore(sample)
out.append(MetricResult(name=type(metric).__name__, score=float(score)))
except Exception as exc:
logger.warning("Metric {} failed: {}", type(metric).__name__, exc)
out.append(MetricResult(name=type(metric).__name__, score=-1.0))
return out

results = asyncio.run(_score_all())

return EvalResult(
question=question,
Expand Down
19 changes: 14 additions & 5 deletions src/pharmagraphrag/evaluation/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class PipelineResponse:
def _call_classic(question: str, config: RunConfig) -> PipelineResponse:
"""Call the classic pipeline via POST /query."""
url = f"{config.api_url}/query"
body: dict[str, Any] = {"question": question}
body: dict[str, Any] = {"question": question, "include_full_context": True}
if config.model:
body["model"] = config.model

Expand All @@ -58,10 +58,19 @@ def _call_classic(question: str, config: RunConfig) -> PipelineResponse:
data = resp.json()

contexts = []
for src in data.get("sources", []):
snippet = src.get("snippet", "")
if snippet:
contexts.append(snippet)
# Prefer full graph/vector context (the text the LLM actually saw)
graph_ctx = data.get("graph_context", "")
if graph_ctx:
contexts.append(graph_ctx)
vector_ctx = data.get("vector_context", "")
if vector_ctx:
contexts.append(vector_ctx)
# Fallback to snippets (older API versions)
if not contexts:
for src in data.get("sources", []):
snippet = src.get("snippet", "")
if snippet:
contexts.append(snippet)

return PipelineResponse(
answer=data.get("answer", ""),
Expand Down
28 changes: 23 additions & 5 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
from pathlib import Path
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

Expand Down Expand Up @@ -207,7 +207,7 @@ class TestScoreSample:
@patch("pharmagraphrag.evaluation.metrics.get_reference_free_metrics")
def test_score_without_reference(self, mock_free, mock_llm, mock_emb):
mock_metric = MagicMock()
mock_metric.single_score.return_value = 0.85
mock_metric.single_turn_ascore = AsyncMock(return_value=0.85)
type(mock_metric).__name__ = "Faithfulness"
mock_free.return_value = [mock_metric]

Expand All @@ -226,12 +226,12 @@ def test_score_without_reference(self, mock_free, mock_llm, mock_emb):
@patch("pharmagraphrag.evaluation.metrics.get_reference_free_metrics")
def test_score_with_reference(self, mock_free, mock_ref, mock_llm, mock_emb):
mock_metric1 = MagicMock()
mock_metric1.single_score.return_value = 0.9
mock_metric1.single_turn_ascore = AsyncMock(return_value=0.9)
type(mock_metric1).__name__ = "Faithfulness"
mock_free.return_value = [mock_metric1]

mock_metric2 = MagicMock()
mock_metric2.single_score.return_value = 0.7
mock_metric2.single_turn_ascore = AsyncMock(return_value=0.7)
type(mock_metric2).__name__ = "ContextRecall"
mock_ref.return_value = [mock_metric2]

Expand All @@ -250,7 +250,7 @@ def test_score_with_reference(self, mock_free, mock_ref, mock_llm, mock_emb):
@patch("pharmagraphrag.evaluation.metrics.get_reference_free_metrics")
def test_score_handles_metric_error(self, mock_free, mock_llm, mock_emb):
mock_metric = MagicMock()
mock_metric.single_score.side_effect = RuntimeError("API error")
mock_metric.single_turn_ascore = AsyncMock(side_effect=RuntimeError("API error"))
type(mock_metric).__name__ = "Faithfulness"
mock_free.return_value = [mock_metric]

Expand Down Expand Up @@ -304,6 +304,24 @@ def test_successful_call(self, mock_post):
assert len(resp.contexts) == 2
assert resp.error is None

@patch("pharmagraphrag.evaluation.runner.httpx.post")
def test_prefers_full_contexts_over_snippets(self, mock_post):
# The mocked /query payload carries both full-context fields and a source
# snippet, so the test can tell which of them ends up in `contexts`.
mock_response = MagicMock()
mock_response.json.return_value = {
"answer": "ok",
"graph_context": "Drug: ASPIRIN\nAdverse events: DYSPNOEA 489",
"vector_context": "Full label text for aspirin...",
"sources": [{"snippet": "truncated snippet", "type": "graph"}],
}
mock_response.raise_for_status = MagicMock()
mock_post.return_value = mock_response

resp = _call_classic("q?", RunConfig(api_url="http://t"))

# Exactly two contexts are expected: the graph context first, then the
# vector context. The snippet must not be added alongside them.
assert len(resp.contexts) == 2
assert "DYSPNOEA 489" in resp.contexts[0]
assert "Full label text" in resp.contexts[1]

@patch("pharmagraphrag.evaluation.runner.httpx.post")
def test_handles_error(self, mock_post):
mock_post.side_effect = Exception("Connection refused")
Expand Down
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading