diff --git a/pdm.lock b/pdm.lock index 6778d2c..80c9d6d 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:c63774322b7541911172368c9e2c170b3d9b5ebe35d6aacc0836d7dab3ff7ac5" +content_hash = "sha256:dbf74dae00ae8178160da6e359d91b80410eaabdb07aabb3f27270f24ad28dd6" [[metadata.targets]] requires_python = ">=3.11.1,<=3.12.8" @@ -2379,6 +2379,34 @@ files = [ {file = "psycopg2_binary-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:18c5ee682b9c6dd3696dad6e54cc7ff3a1a9020df6a5c0f861ef8bfd338c3ca0"}, ] +[[package]] +name = "pyarrow" +version = "20.0.0" +requires_python = ">=3.9" +summary = "Python library for Apache Arrow" +groups = ["default"] +files = [ + {file = "pyarrow-20.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:24ca380585444cb2a31324c546a9a56abbe87e26069189e14bdba19c86c049f0"}, + {file = "pyarrow-20.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:95b330059ddfdc591a3225f2d272123be26c8fa76e8c9ee1a77aad507361cfdb"}, + {file = "pyarrow-20.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f0fb1041267e9968c6d0d2ce3ff92e3928b243e2b6d11eeb84d9ac547308232"}, + {file = "pyarrow-20.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8ff87cc837601532cc8242d2f7e09b4e02404de1b797aee747dd4ba4bd6313f"}, + {file = "pyarrow-20.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:7a3a5dcf54286e6141d5114522cf31dd67a9e7c9133d150799f30ee302a7a1ab"}, + {file = "pyarrow-20.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a6ad3e7758ecf559900261a4df985662df54fb7fdb55e8e3b3aa99b23d526b62"}, + {file = "pyarrow-20.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6bb830757103a6cb300a04610e08d9636f0cd223d32f388418ea893a3e655f1c"}, + {file = "pyarrow-20.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:96e37f0766ecb4514a899d9a3554fadda770fb57ddf42b63d80f14bc20aa7db3"}, + {file = "pyarrow-20.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:3346babb516f4b6fd790da99b98bed9708e3f02e734c84971faccb20736848dc"}, + {file = "pyarrow-20.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:75a51a5b0eef32727a247707d4755322cb970be7e935172b6a3a9f9ae98404ba"}, + {file = "pyarrow-20.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:211d5e84cecc640c7a3ab900f930aaff5cd2702177e0d562d426fb7c4f737781"}, + {file = "pyarrow-20.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ba3cf4182828be7a896cbd232aa8dd6a31bd1f9e32776cc3796c012855e1199"}, + {file = "pyarrow-20.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c3a01f313ffe27ac4126f4c2e5ea0f36a5fc6ab51f8726cf41fee4b256680bd"}, + {file = "pyarrow-20.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:a2791f69ad72addd33510fec7bb14ee06c2a448e06b649e264c094c5b5f7ce28"}, + {file = "pyarrow-20.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4250e28a22302ce8692d3a0e8ec9d9dde54ec00d237cff4dfa9c1fbf79e472a8"}, + {file = "pyarrow-20.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:89e030dc58fc760e4010148e6ff164d2f44441490280ef1e97a542375e41058e"}, + {file = "pyarrow-20.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6102b4864d77102dbbb72965618e204e550135a940c2534711d5ffa787df2a5a"}, + {file = "pyarrow-20.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:96d6a0a37d9c98be08f5ed6a10831d88d52cac7b13f5287f1e0f625a0de8062b"}, + {file = "pyarrow-20.0.0.tar.gz", hash = "sha256:febc4a913592573c8d5805091a6c2b5064c8bd6e002131f01061797d91c783c1"}, +] + [[package]] name = "pyasn1" version = "0.6.1" diff --git a/pyproject.toml b/pyproject.toml index 84b7377..e782fa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "torch==2.6.0+cpu", "road-core @ git+https://github.com/road-core/service.git", "matplotlib>=3.10.1", + "pyarrow>=20.0.0", ] requires-python = ">=3.11.1,<=3.12.8" readme = "README.md" diff --git a/src/road_core_eval/evaluate.py b/src/road_core_eval/evaluate.py index 0367947..48e4ffa 100644 --- a/src/road_core_eval/evaluate.py +++ b/src/road_core_eval/evaluate.py @@ -1,6 +1,6 @@ """Driver for evaluation.""" -import argparse +from argparse import Namespace, ArgumentParser import os from httpx import Client from road_core_eval.response_evaluation import ResponseEvaluation @@ -11,9 +11,10 @@ ) -def main(): - """Evaluate response.""" - parser = argparse.ArgumentParser(description="Response validation module.") +def parse_args() -> Namespace: + """Parse CLI arguments for response evaluation tool.""" + + parser = ArgumentParser(description="Response validation module.") parser.add_argument( "--eval_provider_model_id", nargs="+", @@ -95,7 +96,13 @@ def main(): type=str, help="Path to text file with API token (applicable when deployed on cluster)", ) - args = parser.parse_args() + return parser.parse_args() + + +def main(): + """Evaluate response.""" + args = parse_args() + client = Client(base_url=args.eval_api_url, verify=False) # noqa: S501 if "localhost" not in args.eval_api_url: diff --git a/tests/test_response_evaluation.py b/tests/test_response_evaluation.py new file mode 100644 index 0000000..7ea779f --- /dev/null +++ b/tests/test_response_evaluation.py @@ -0,0 +1,35 @@ +"""Tests for response_evaluation module""" + +from argparse import Namespace +from unittest.mock import patch + +from httpx import Client + +from road_core_eval.response_evaluation import ResponseEvaluation + + +def test_response_evaluation_init(tmpdir): + """Test initialization of ResponseEvaluation object with default + arguments from road_core_eval.evaluate module. + """ + out_dir = tmpdir.mkdir("out_dir") + args = Namespace( + eval_provider_model_id=["watsonx+ibm/granite-3-8b-instruct"], + judge_provider="ollama", + judge_model="llama3.1:latest", + eval_data_src="eval_data/question_answer_pair.json", + eval_out_dir=out_dir, + eval_query_ids=None, + eval_scenario="with_rag", + qna_pool_file=None, + eval_type="model", + eval_metrics=["cos_score"], + eval_modes=["ols"], + eval_api_url="http://localhost:8080", + eval_api_token_file="ols_api_key.txt", + ) + + client = Client(base_url=args.eval_api_url, verify=False) + # Mock HF class to prevent model download + with patch("llama_index.embeddings.huggingface.HuggingFaceEmbedding"): + ResponseEvaluation(args, client)