import math
from pathlib import Path

import polars as pl

# This file includes LLM and human annotations for analysis
SAMPLE_PATH = Path(
    "public-sample_prognostic_annotated_ingested_vafs_2026-04-28-run2-temp0.0.tsv"
)


def confusion_matrix_with_collapse(
    df: pl.DataFrame,
    gt_col: str,
    pred_col: str,
    collapse: bool = False,
) -> pl.DataFrame:
    """Compute a confusion matrix for
    ['poor outcome', 'better outcome', 'unclear', 'null'],
    with optional merging of 'unclear' and 'null' into 'indeterminate'.

    :param df: Input DataFrame with evidence lines, associated comments,
        curator annotations, and LLM annotations
    :param gt_col: ground truth/human curator values column name
        (e.g., 'Curator Annotation')
    :param pred_col: predicted column name (e.g., 'LLM Annotation')
    :param collapse: whether to collapse 'unclear' and 'null' into 'indeterminate'
    :return: confusion matrix as a DataFrame, with a 'Ground Truth' label column
        followed by one count column per class
    """

    def _normalize(col: str) -> pl.Expr:
        # Both label columns get identical normalization: missing values become
        # the literal string "null", then trim whitespace and lowercase.
        return (
            pl.col(col)
            .fill_null("null")
            .cast(pl.Utf8)
            .str.strip_chars()
            .str.to_lowercase()
        )

    working_df = df.with_columns(
        [
            _normalize(gt_col).alias("gt"),
            _normalize(pred_col).alias("pred"),
        ]
    )

    if collapse:
        # Fold the two "no usable direction" labels into a single class.
        collapse_map = {
            "unclear": "indeterminate",
            "null": "indeterminate",
        }
        working_df = working_df.with_columns(
            [
                pl.col("gt").replace(collapse_map),
                pl.col("pred").replace(collapse_map),
            ]
        )
        classes = [
            "poor outcome",
            "better outcome",
            "indeterminate",
        ]
    else:
        classes = [
            "poor outcome",
            "better outcome",
            "unclear",
            "null",
        ]

    # Pivot (gt, pred) pair counts into a wide matrix; absent pairs become 0.
    cm = (
        working_df.group_by(["gt", "pred"])
        .len()
        .pivot(
            values="len",
            index="gt",
            on="pred",
        )
        .fill_null(0)
    )

    # Ensure every expected ground-truth class has a row, even with zero counts.
    missing_rows = [cls for cls in classes if cls not in cm["gt"].to_list()]
    if missing_rows:
        cm = cm.with_columns(pl.col("gt").cast(pl.Utf8))
        cm = pl.concat(
            [
                cm,
                pl.DataFrame(
                    {
                        "gt": pl.Series(missing_rows, dtype=pl.Utf8),
                        **{
                            c: pl.Series(
                                [0] * len(missing_rows),
                                # Match the existing column dtype when present so
                                # the diagonal concat does not conflict.
                                dtype=cm.schema.get(c, pl.UInt32),
                            )
                            for c in classes
                        },
                    }
                ),
            ],
            how="diagonal",
        )

    # Ensure every expected prediction class has a column.
    for cls in classes:
        if cls not in cm.columns:
            cm = cm.with_columns(pl.lit(0).alias(cls))

    # Reorder columns into canonical class order. NOTE(review): any normalized
    # label outside `classes` is silently dropped here — confirm upstream data
    # can only contain the expected labels.
    cm = cm.select(["gt", *classes])

    # Sort rows into the same canonical order via a temporary rank column.
    cm = (
        cm.with_columns(
            pl.col("gt").replace({v: i for i, v in enumerate(classes)}).alias("_order")
        )
        .sort("_order")
        .drop("_order")
    )

    return cm.rename({"gt": "Ground Truth"})


def create_analysis_summary(cm: pl.DataFrame) -> pl.DataFrame:
    """Compute per-class recall-style summary from a confusion matrix.

    :param cm: Confusion matrix (square DataFrame) whose first column is
        'Ground Truth' and whose remaining columns are per-class counts
    :return: Summary DataFrame with one row per class: the diagonal count
        (numerator), the row total (denominator), and their percentage
    """
    rows = []
    classes = cm.columns[1:]

    for cls in classes:
        row = cm.filter(pl.col("Ground Truth") == cls)

        if row.height == 0:
            # Class absent from the matrix: report 0/0 rather than failing.
            numerator = 0
            denominator = 0
        else:
            numerator = row.select(pl.col(cls)).item()
            denominator = row.select(pl.sum_horizontal(classes)).item()

        rows.append(
            {
                "consensus_w_curator": cls,
                "numerator": numerator,
                "denominator": denominator,
                # Guard against division by zero for empty classes.
                "percentage": numerator / denominator * 100 if denominator else 0.0,
            }
        )

    return pl.DataFrame(rows)


def compute_overall_accuracy(cm: pl.DataFrame) -> float:
    """Compute overall accuracy from a confusion matrix.

    :param cm: Confusion matrix DataFrame (square matrix) whose first column
        is 'Ground Truth'
    :return: Accuracy as a float between 0 and 1, or NaN for an empty matrix
    """
    classes = cm.columns[1:]

    # Grand total of all cells across the class columns.
    total = cm.select(pl.sum_horizontal(classes)).to_series().sum()

    # Sum of the diagonal (correct predictions).
    correct = 0
    for cls in classes:
        row = cm.filter(pl.col("Ground Truth") == cls)
        if row.height > 0:
            correct += row.select(pl.col(cls)).item() or 0

    return correct / total if total > 0 else math.nan


def analyze_results(
    df: pl.DataFrame,
    collapse: bool = True,
) -> tuple[pl.DataFrame, float, pl.DataFrame]:
    """Evaluate LLM predictions against ground truth using a confusion matrix.

    Computes:
    - confusion matrix (optionally collapsed classes)
    - per-class performance summary (recall-style)
    - overall accuracy

    :param df: Input DataFrame with ground truth and predictions
    :param collapse: If True, merges 'unclear' and 'null' into 'indeterminate'
    :return:
        - match_analysis_summary: per-class recall summary
        - accuracy: overall accuracy
        - cm: confusion matrix
    """
    # No defensive clone needed: confusion_matrix_with_collapse only derives
    # new frames from `df` and never mutates it.
    cm = confusion_matrix_with_collapse(
        df,
        "Curator Annotation",
        "LLM Annotation",
        collapse=collapse,
    )

    match_analysis_summary = create_analysis_summary(cm)
    accuracy = compute_overall_accuracy(cm)

    return match_analysis_summary, accuracy, cm


def process_results(
    collapse: bool,
) -> tuple[pl.DataFrame, pl.DataFrame, float]:
    """Run analysis on a stored model output.

    Reads the annotated TSV at SAMPLE_PATH and evaluates LLM annotations
    against curator annotations.

    :param collapse: Whether to collapse categories during analysis
    :return: Tuple of (per-class summary, confusion matrix, overall accuracy).
        Note the order differs from analyze_results, which yields
        (summary, accuracy, cm).
    """
    _df = pl.read_csv(SAMPLE_PATH, separator="\t")

    summary, accuracy, cm = analyze_results(
        _df,
        collapse=collapse,
    )

    return (
        summary,
        cm,
        accuracy,
    )
\n", + "shape: (3, 4)
Ground Truthpoor outcomebetter outcomeindeterminate
stru32u32u32
"poor outcome"4912
"better outcome"0500
"indeterminate"25726
" + ], + "text/plain": [ + "shape: (3, 4)\n", + "┌────────────────┬──────────────┬────────────────┬───────────────┐\n", + "│ Ground Truth ┆ poor outcome ┆ better outcome ┆ indeterminate │\n", + "│ --- ┆ --- ┆ --- ┆ --- │\n", + "│ str ┆ u32 ┆ u32 ┆ u32 │\n", + "╞════════════════╪══════════════╪════════════════╪═══════════════╡\n", + "│ poor outcome ┆ 49 ┆ 1 ┆ 2 │\n", + "│ better outcome ┆ 0 ┆ 50 ┆ 0 │\n", + "│ indeterminate ┆ 25 ┆ 7 ┆ 26 │\n", + "└────────────────┴──────────────┴────────────────┴───────────────┘" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cm" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c0f933dd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (3, 4)
consensus_w_curatornumeratordenominatorpercentage
stri64i64f64
"poor outcome"495294.230769
"better outcome"5050100.0
"indeterminate"265844.827586
" + ], + "text/plain": [ + "shape: (3, 4)\n", + "┌─────────────────────┬───────────┬─────────────┬────────────┐\n", + "│ consensus_w_curator ┆ numerator ┆ denominator ┆ percentage │\n", + "│ --- ┆ --- ┆ --- ┆ --- │\n", + "│ str ┆ i64 ┆ i64 ┆ f64 │\n", + "╞═════════════════════╪═══════════╪═════════════╪════════════╡\n", + "│ poor outcome ┆ 49 ┆ 52 ┆ 94.230769 │\n", + "│ better outcome ┆ 50 ┆ 50 ┆ 100.0 │\n", + "│ indeterminate ┆ 26 ┆ 58 ┆ 44.827586 │\n", + "└─────────────────────┴───────────┴─────────────┴────────────┘" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summary" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "f6e0016a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.78125" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "accuracy" + ] + }, + { + "cell_type": "markdown", + "id": "34f6132d", + "metadata": {}, + "source": [ + "### Supplemental" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "cd844131", + "metadata": {}, + "outputs": [], + "source": [ + "sup_summary, sup_cm, sup_accuracy = process_results(collapse=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "dd43f4e5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (4, 5)
Ground Truthpoor outcomebetter outcomeunclearnull
stru32u32u32u32
"poor outcome"49111
"better outcome"05000
"unclear"8550
"null"172147
" + ], + "text/plain": [ + "shape: (4, 5)\n", + "┌────────────────┬──────────────┬────────────────┬─────────┬──────┐\n", + "│ Ground Truth ┆ poor outcome ┆ better outcome ┆ unclear ┆ null │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ str ┆ u32 ┆ u32 ┆ u32 ┆ u32 │\n", + "╞════════════════╪══════════════╪════════════════╪═════════╪══════╡\n", + "│ poor outcome ┆ 49 ┆ 1 ┆ 1 ┆ 1 │\n", + "│ better outcome ┆ 0 ┆ 50 ┆ 0 ┆ 0 │\n", + "│ unclear ┆ 8 ┆ 5 ┆ 5 ┆ 0 │\n", + "│ null ┆ 17 ┆ 2 ┆ 14 ┆ 7 │\n", + "└────────────────┴──────────────┴────────────────┴─────────┴──────┘" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sup_cm" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "3d841074", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (4, 4)
consensus_w_curatornumeratordenominatorpercentage
stri64i64f64
"poor outcome"495294.230769
"better outcome"5050100.0
"unclear"51827.777778
"null"74017.5
" + ], + "text/plain": [ + "shape: (4, 4)\n", + "┌─────────────────────┬───────────┬─────────────┬────────────┐\n", + "│ consensus_w_curator ┆ numerator ┆ denominator ┆ percentage │\n", + "│ --- ┆ --- ┆ --- ┆ --- │\n", + "│ str ┆ i64 ┆ i64 ┆ f64 │\n", + "╞═════════════════════╪═══════════╪═════════════╪════════════╡\n", + "│ poor outcome ┆ 49 ┆ 52 ┆ 94.230769 │\n", + "│ better outcome ┆ 50 ┆ 50 ┆ 100.0 │\n", + "│ unclear ┆ 5 ┆ 18 ┆ 27.777778 │\n", + "│ null ┆ 7 ┆ 40 ┆ 17.5 │\n", + "└─────────────────────┴───────────┴─────────────┴────────────┘" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sup_summary" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "4d5a5ef4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.69375" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sup_accuracy" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv (3.13.5)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index ff92af3..cba6895 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,8 @@ description = "Transform semi-structured somatic cancer variant classification k license = "MIT" license-files = ["LICENSE"] dependencies = [ - "ipykernel" + "ipykernel", + "polars" ] [project.optional-dependencies] From 7e4faef326b61ab2dbe4dbafc10b2fbbd1c1a359 Mon Sep 17 00:00:00 2001 From: Kori Kuzma Date: Fri, 15 May 2026 10:18:10 -0400 Subject: [PATCH 2/2] wip: initial work for using wags_llm --- ...nnotation_prognostic_sample_analysis.ipynb | 182 
from collections.abc import Mapping
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, ConfigDict, field_validator
from wags_llm.cache import InMemoryCache
from wags_llm.client import BedrockClaudeJsonClient
from wags_llm.prompts import BasePromptTemplate, build_empty_registry
from wags_llm.services import StructuredTaskRunner

PROMPT_NAME = "evidence_direction_annotation:prognostic"
PROMPT_VERSION = "v1"


class PrognosticDirectionPromptV1(BasePromptTemplate):
    """Version 1 prompt for interpreting prognostic evidence direction from comments."""

    version = PROMPT_VERSION
    name = PROMPT_NAME

    def build_system_prompt(self) -> str:
        """Build the system prompt for extracting prognostic evidence direction from a
        VAF evidence description and comment.

        :returns: System prompt text.
        """
        return (
            "Return ONLY a valid JSON object (no markdown code fences).\n\n"
            "Schema:\n"
            "evidence_line_type_direction: better outcome | poor outcome | unclear | null\n"
            "Task:\n"
            "Given an evidence description and comments, determine the clinical impact direction for prognosis for ClinVar AMP/ASCO/CAP assertionTypeForClinicalImpact submission. Comments may include identifiers, analyst, and/or director (in order; may be missing).\n\n"
            "Rules:\n"
            "- Ignore identifiers\n"
            "- Use the last comment; if minimal, use the nearest preceding substantive comment(s). These are the governing comment(s)\n"
            "- Use description for prognostic applicability; the governing comment(s) may confirm or override it\n"
            "- If the governing comment(s) indicate non-applicability, exclusion, or insufficient support, return null\n"
            "- unclear = prognosis mentioned but direction is ambiguous or uncertain\n"
            "- null = prognosis does not apply or is not mentioned\n"
        )

    def build_user_prompt(
        self,
        payload: Mapping[str, Any],
    ) -> str:
        """Build the user prompt for a single comment.

        :param payload: Evidence direction and comments within VAF evidence line;
            must contain 'description' and 'comments' keys
        :returns: User prompt text
        """
        return (
            f"Evidence Description: {payload['description']}\n"
            f"Comments: {payload['comments']}\n"
        )


class PrognosticDirection(StrEnum):
    """Define directionality for prognostic evidence"""

    BETTER = "better outcome"
    POOR = "poor outcome"
    UNCLEAR = "unclear"


class PrognosticEvidenceDirectionResult(BaseModel):
    """Model for LLM and human curator result for extracting prognostic direction
    from VAF sheet comments.

    `evidence_line_type_direction` is None when prognosis does not apply
    (the model returned the string "null").
    """

    model_config = ConfigDict(extra="forbid", use_enum_values=True)

    evidence_line_type_direction: PrognosticDirection | None

    @field_validator("evidence_line_type_direction", mode="before")
    @classmethod
    def handle_null(cls, value: str | None) -> str | None:
        """Convert the literal string 'null' to `None` before enum coercion.

        :param value: Raw value from the LLM JSON payload (may already be None)
        :returns: `None` for the string 'null', otherwise the value unchanged
        """
        if value == "null":
            return None
        return value


MODEL_ID = "us.anthropic.claude-sonnet-4-6"
REGION_NAME = "us-east-1"
PROFILE_NAME = "dev-account"
MAX_TOKENS = 150
TEMPERATURE = 0.0


def build_llm_task_runner(
    model_id: str,
    region_name: str,
    profile_name: str,
    max_tokens: int,
    temperature: float,
) -> StructuredTaskRunner:
    """Build LLM evidence direction annotator.

    :param model_id: Bedrock model identifier.
    :param region_name: AWS region for the Bedrock runtime client.
    :param profile_name: AWS profile name.
    :param max_tokens: Maximum number of tokens to request from the model.
    :param temperature: Sampling temperature.
    :return: Structured task runner wired with the prognostic-direction prompt,
        a Bedrock Claude JSON client, and an in-memory response cache.
    """
    registry = build_empty_registry()
    registry.register(PrognosticDirectionPromptV1())
    llm_client = BedrockClaudeJsonClient(
        model_id=model_id,
        region_name=region_name,
        profile_name=profile_name,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    cache = InMemoryCache()
    return StructuredTaskRunner(
        client=llm_client, prompt_registry=registry, cache=cache
    )


task_runner = build_llm_task_runner(
    MODEL_ID, REGION_NAME, PROFILE_NAME, MAX_TOKENS, TEMPERATURE
)
"cell_type": "code", + "execution_count": null, + "id": "31a648bd", + "metadata": {}, + "outputs": [], + "source": [ + "# ANASTASIA TO ITERATE OVER THE TSV PROVIDING THE DESCRIPTION AND COMMENT FOR EACH ROW\n", + "response_model = PrognosticEvidenceDirectionResult\n", + "try:\n", + " task_result = task_runner.execute(\n", + " prompt_name=PROMPT_NAME,\n", + " prompt_version=PROMPT_VERSION,\n", + " payload={\n", + " \"description\": description,\n", + " \"comments\": comments,\n", + " },\n", + " response_model=response_model,\n", + " )\n", + "except Exception as e:\n", + " annotation = None\n", + " error_msg = str(e)\n", + "else:\n", + " annotation = response_model.model_validate(task_result)\n", + " error_msg = None" + ] + }, { "cell_type": "markdown", "id": "36992c8d", @@ -575,7 +753,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv (3.13.5)", + "display_name": "mci-knowledge-pilot (3.13.5)", "language": "python", "name": "python3" }, diff --git a/pyproject.toml b/pyproject.toml index cba6895..a1a6734 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ license = "MIT" license-files = ["LICENSE"] dependencies = [ "ipykernel", - "polars" + "polars", + "wags_llm" ] [project.optional-dependencies]