Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion client/src/pages/ProjectPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ import {
fileUploadToBackend,
fileFetchFromBackend,
} from "../services/fileService";
import { JobStatus } from "../state/types";
import { fetchTokenEstimation } from "../services/projectService";
import { JobStatus, TokenEstimation } from "../state/types";
import { ManualEvaluationModal } from "../components/ManualEvaluationModal";
import { Button } from "../components/Button";
import {
Expand All @@ -31,6 +32,7 @@ import {
CircleAlert,
CircleCheck,
CircleStop,
ChevronDown,
Download,
FileText,
Loader,
Expand Down Expand Up @@ -433,6 +435,7 @@ export const ProjectPage = () => {
Array<{ id: string; created: number; object: "model"; owned_by: string }>
>([]);


const loadingProjects = useTypedStoreState((state) => state.loading.projects);
const loadProjects = useTypedStoreActions((actions) => actions.fetchProjects);
const getProjectByUuid = useTypedStoreState(
Expand All @@ -450,6 +453,28 @@ export const ProjectPage = () => {

const project = getProjectByUuid(projectUuid);

const [tokenEstimation, setTokenEstimation] = useState<TokenEstimation | null>(null);
const [showTokenDetails, setShowTokenDetails] = useState(false);

// Fetch the token estimate for the current project. Clears the estimate when
// there is no project or no uploaded files yet (nothing to estimate).
const getEstimation = useCallback(async () => {
  if (!projectUuid || fetchedFiles.length === 0) {
    setTokenEstimation(null);
    return;
  }
  try {
    // Await instead of a floating .then() chain so callers of this async
    // function observe completion (and errors) of the request.
    const estimation = await fetchTokenEstimation(projectUuid);
    setTokenEstimation(estimation);
  } catch (error) {
    // Drop stale data so the UI falls back to its placeholder state.
    setTokenEstimation(null);
    toast.error(
      `Token estimation failed: ${error instanceof Error ? error.message : String(error)}`,
    );
  }
}, [projectUuid, fetchedFiles.length]);

useEffect(() => {
getEstimation();
}, [getEstimation]);

useEffect(() => {
if (project !== undefined) {
fetchPapers(projectUuid);
Expand Down Expand Up @@ -1148,6 +1173,47 @@ export const ProjectPage = () => {
</Button>
</div>
</Card>
<button
type="button"
onClick={() => tokenEstimation && setShowTokenDetails(!showTokenDetails)}
disabled={!tokenEstimation}
className={classNames(
"p-2 bg-white shadow-sm rounded-lg flex flex-col border border-transparent",
{ "hover:border-slate-200 hover:bg-gray-100 cursor-pointer": tokenEstimation }
)}
>
<div className="flex justify-between items-center w-full px-1">
<div className="flex items-center gap-2">
<span className="text-sm font-medium text-slate-700">Estimated tokens</span>
{tokenEstimation && (
<ChevronDown
size={14}
className={classNames("text-slate-400 transition-transform", { "rotate-180": showTokenDetails })}
/>
)}
</div>
<span className="text-sm font-mono font-bold text-slate-900">
{tokenEstimation ? `~${tokenEstimation.total_estimated_tokens}` : "?"}
</span>
</div>

{showTokenDetails && tokenEstimation && (
<div className="w-full pt-2 flex flex-col gap-1 px-1">
<div className="flex justify-between text-xs text-slate-500">
<span>Input tokens:</span>
<span className="font-mono">{tokenEstimation.estimated_input_tokens}</span>
</div>
<div className="flex justify-between text-xs text-slate-500">
<span>Output tokens:</span>
<span className="font-mono">{tokenEstimation.estimated_output_tokens}</span>
</div>
<div className="flex justify-between text-xs text-slate-500">
<span>Tasks:</span>
<span className="font-mono">{tokenEstimation.task_count}</span>
</div>
</div>
)}
</button>
</div>
</div>

Expand Down
14 changes: 14 additions & 0 deletions client/src/services/projectService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
DeletedProjectModel,
ProjectModel,
} from "../state/types/project";
import { TokenEstimation } from "../state/types";

export const fetch_projects = async (): Promise<Project[]> => {
try {
Expand Down Expand Up @@ -57,3 +58,16 @@ export const delete_project = async (uuid: string): Promise<DeletedProject> => {
throw error;
}
};

// TODO: naming in this file is inconsistent — older exports use snake_case
// (fetch_projects, delete_project) while this one is camelCase; settle on one
// convention and rename the rest in a follow-up.
/**
 * Fetches the estimated LLM token usage for screening a project's papers.
 *
 * @param uuid - UUID of the project to estimate.
 * @returns The token estimation computed by the backend.
 * @throws Re-throws any API error after logging it.
 */
export const fetchTokenEstimation = async (
  uuid: string,
): Promise<TokenEstimation> => {
  try {
    const res = await api.get(`/api/v1/project/${uuid}/estimate`);
    return res.data;
  } catch (error) {
    // Fixed copy-pasted message from fetch-by-UUID; log what actually failed.
    console.error("Fetching token estimation unsuccessful", error);
    throw error;
  }
};
25 changes: 16 additions & 9 deletions client/src/state/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ export type JobStats = {
};

export type JobWithStats = {
uuid: string;
id: string;
project_uuid: string;
prompting_config: PromptingConfig;
llm_config: LlmConfig;
created_at: Date | null;
updated_at: Date | null;
stats: JobStats;
}
uuid: string;
id: string;
project_uuid: string;
prompting_config: PromptingConfig;
llm_config: LlmConfig;
created_at: Date | null;
updated_at: Date | null;
stats: JobStats;
};

export type Paper = {
uuid: string;
Expand Down Expand Up @@ -151,6 +151,13 @@ export type Result = {
[modelName: string]: string;
};

// Mirrors server/src/services/token_estimation_service.py::TokenEstimation.
export type TokenEstimation = {
// Number of screening tasks (one per paper).
task_count: number;
// Estimated prompt-side tokens across all tasks (includes a safety buffer).
estimated_input_tokens: number;
// Estimated completion-side tokens across all tasks (includes a safety buffer).
estimated_output_tokens: number;
// Input + output estimate (buffered); the headline number shown in the UI.
total_estimated_tokens: number;
};

// Keep this up-to-date with server/src/core/llm_providers.py
const ConfigParameterSchema = z.object({
key: z.string(),
Expand Down
1 change: 1 addition & 0 deletions server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"redis>=6.4.0",
"ruff>=0.14.9",
"sqlalchemy[asyncio]>=2.0.45",
"tiktoken>=0.12.0",
"uvicorn>=0.38.0",
]
[tool.setuptools]
Expand Down
45 changes: 42 additions & 3 deletions server/src/api/controllers/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from src.db.db_context import DBContext, get_db_ctx
from src.event_queue import EventName, QueueItem, push_event
from src.schemas.project import ProjectCreate, ProjectRead
from src.services.paper_service import create_paper_service
from src.services.project_service import create_project_service
from src.services.token_estimation_service import (
TokenEstimation,
create_token_estimation_service,
)

router = APIRouter()

Expand Down Expand Up @@ -51,6 +56,42 @@ async def get_project(uuid: UUID, db_ctx: DBContext = Depends(get_db_ctx)):
)


@router.get(
    "/project/{uuid}/estimate",
    status_code=status.HTTP_200_OK,
    response_model=TokenEstimation,
    tags=["Project"],
)
async def estimate_tokens(uuid: UUID, db_ctx: DBContext = Depends(get_db_ctx)):
    """Estimate LLM token usage for screening all papers of a project.

    Returns 404 when the project does not exist or has no papers, and 500
    for any unexpected failure during estimation.
    """
    project_service = create_project_service(db_ctx)
    paper_service = create_paper_service(db_ctx)
    token_estimation_service = create_token_estimation_service()
    try:
        project = await project_service.fetch_by_uuid(uuid)
        if not project:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
            )

        papers = await paper_service.fetch_papers(project_uuid=uuid)
        if not papers:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="Papers not found"
            )

        estimation = await token_estimation_service.estimate_tokens(
            papers=papers, criteria=project.criteria
        )
        return estimation
    except HTTPException:
        # Propagate deliberate 404s unchanged.
        raise
    except Exception as e:
        # Fixed copy-pasted detail from the fetch-project handler.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to estimate tokens: {str(e)}",
        ) from e


@router.post("/project", status_code=status.HTTP_201_CREATED, tags=["Project"])
async def create_new_project(
project_data: ProjectCreate, db_ctx: DBContext = Depends(get_db_ctx)
Expand All @@ -70,9 +111,7 @@ async def create_new_project(
)


@router.delete(
"/project/{uuid}", status_code=status.HTTP_200_OK, tags=["Project"]
)
@router.delete("/project/{uuid}", status_code=status.HTTP_200_OK, tags=["Project"])
async def delete_project(uuid: UUID, db_ctx: DBContext = Depends(get_db_ctx)):
projects = create_project_service(db_ctx)
try:
Expand Down
2 changes: 2 additions & 0 deletions server/src/core/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

default_system_prompt = "You are an expert research assistant."

additional_instructions = "The paper is included, if all inclusion criteria match. If the paper matches any exclusion criteria, it is excluded."

# Openrouter recommends instructing the LLM to respond in JSON format.
# Tested to be working with Fireworks.ai provided LLaMA 4 Maverick
json_instruct_prompt = """Output **ONLY JSON**. You should include **EVERY FIELD** defined in the schema - every field in the schema is required. Respond strictly in valid JSON format, using the following schema:
Expand Down
84 changes: 84 additions & 0 deletions server/src/services/token_estimation_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import json
from typing import Any, Dict, List

import tiktoken
from pydantic import BaseModel

from src.core.prompts import (
additional_instructions,
default_system_prompt,
zero_shot_task_prompt,
)
from src.schemas.llm import StructuredResponse
from src.schemas.paper import PaperRead
from src.schemas.project import Criteria
from src.tools.llm_decision_creator import create_criteria

DEFAULT_ENCODING = "o200k_base"


class TokenEstimation(BaseModel):
    """Estimated LLM token usage for screening a set of papers."""

    # One screening task per paper.
    task_count: int
    # Prompt-side tokens across all tasks, including the safety buffer.
    estimated_input_tokens: int
    # Completion-side tokens across all tasks, including the safety buffer.
    estimated_output_tokens: int
    # Input + output, including the safety buffer.
    total_estimated_tokens: int


class TokenEstimationService:
    """Estimates LLM token usage for screening papers against criteria.

    Prompt tokens are counted with tiktoken under a fixed encoding; completion
    tokens use fixed per-task size assumptions. A safety buffer is applied to
    every reported figure.
    """

    # Assumed completion-size components per task, in tokens.
    OUTPUT_OVERHEAD_TOKENS = 50
    OUTPUT_DECISION_TOKENS = 30
    OUTPUT_PER_CRITERION_TOKENS = 15
    # Safety margin applied to all estimates (10% headroom).
    ESTIMATE_BUFFER = 1.10

    def __init__(self):
        self._encoder = tiktoken.get_encoding(DEFAULT_ENCODING)

    async def estimate_tokens(
        self,
        papers: List[PaperRead],
        criteria: Criteria,
        system_prompt: str = default_system_prompt,
        task_prompt_template: str = zero_shot_task_prompt,
        additional_instructions: str = additional_instructions,
        # NOTE: default is evaluated once at import time; acceptable because
        # the schema of StructuredResponse is static.
        response_schema: Dict[str, Any] = StructuredResponse.model_json_schema(),
    ) -> TokenEstimation:
        """Estimate input/output tokens for screening ``papers`` against ``criteria``.

        Returns a TokenEstimation with buffered input, output, and total
        token counts plus the number of tasks (one per paper).
        """
        # Tokens shared by every task: system prompt + JSON response schema.
        schema_str = json.dumps(response_schema)
        static_tokens = self._count_tokens(f"{system_prompt}\n{schema_str}")

        criteria_text = create_criteria(
            criteria.inclusion_criteria, criteria.exclusion_criteria
        )
        num_criteria = len(criteria.inclusion_criteria) + len(
            criteria.exclusion_criteria
        )

        # The assumed completion size is identical for every paper, so compute
        # it once instead of inside the loop.
        output_per_task = (
            self.OUTPUT_OVERHEAD_TOKENS
            + self.OUTPUT_DECISION_TOKENS
            + num_criteria * self.OUTPUT_PER_CRITERION_TOKENS
        )
        total_output = output_per_task * len(papers)

        # Input size varies per paper (title/abstract length).
        total_input = 0
        for paper in papers:
            task_prompt = task_prompt_template.format(
                paper.title,
                paper.abstract,
                criteria_text,
                additional_instructions,
            )
            total_input += static_tokens + self._count_tokens(task_prompt)

        return TokenEstimation(
            task_count=len(papers),
            estimated_input_tokens=int(total_input * self.ESTIMATE_BUFFER),
            estimated_output_tokens=int(total_output * self.ESTIMATE_BUFFER),
            total_estimated_tokens=int(
                (total_input + total_output) * self.ESTIMATE_BUFFER
            ),
        )

    def _count_tokens(self, text: str) -> int:
        """Number of tokens in ``text`` under the configured encoding."""
        if not text:
            return 0
        return len(self._encoder.encode(text))

def create_token_estimation_service() -> TokenEstimationService:
    """Factory for TokenEstimationService, matching the other create_* helpers."""
    return TokenEstimationService()
9 changes: 3 additions & 6 deletions server/src/tools/llm_decision_creator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from httpx import AsyncClient

from src.core.prompts import few_shot_task_prompt, zero_shot_task_prompt
from src.core.prompts import few_shot_task_prompt, zero_shot_task_prompt, additional_instructions
from src.db.models.jobtask import JobTask
from src.schemas.job import (
FewShotPromptingConfig,
Expand All @@ -27,7 +27,7 @@ def _create_few_shot_examples(papers: list[PaperRead]):
return "\n\n".join(txt_parts)


def _create_criteria(
def create_criteria(
inclusion_criteria: list[str], exclusion_criteria: list[str]
) -> str:
criteria = "\nInclusion criteria:\n\n"
Expand All @@ -47,10 +47,7 @@ async def get_structured_response(
inc_exc_criteria: Criteria,
client: AsyncClient,
) -> StructuredResponse:
# TODO: Move to another place
additional_instructions = "The paper is included, if all inclusion criteria match. If the paper matches any exclusion criteria, it is excluded."

criteria = _create_criteria(
criteria = create_criteria(
# TODO: Fix
inc_exc_criteria["inclusion_criteria"], # type: ignore
inc_exc_criteria["exclusion_criteria"], # type: ignore
Expand Down
2 changes: 2 additions & 0 deletions server/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading