diff --git a/src/aind_data_access_api/helpers/data_schema.py b/src/aind_data_access_api/helpers/data_schema.py
index a5fc244..80bcaef 100644
--- a/src/aind_data_access_api/helpers/data_schema.py
+++ b/src/aind_data_access_api/helpers/data_schema.py
@@ -5,7 +5,7 @@
 from typing import List, Optional, Union
 
 import pandas as pd
-from aind_data_schema.core.quality_control import QualityControl
+from aind_data_schema.core.quality_control import QCEvaluation, QualityControl
 
 from aind_data_access_api.document_db import MetadataDbClient
 from aind_data_access_api.helpers.docdb import (
@@ -182,3 +182,40 @@
         data.append(qc_metrics_flat)
 
     return pd.DataFrame(data)
+
+
+def add_qc_evaluations_to_docdb(
+    client: MetadataDbClient,
+    data_asset_id: str,
+    evaluations: Union[QCEvaluation, List[QCEvaluation]],
+):
+    """Serialize QCEvaluation object(s) and add them to DocDB.
+
+    Parameters
+    ----------
+    client : MetadataDbClient
+        A connected DocumentDB client.
+    data_asset_id : str
+        ID of the data asset record the evaluations are attached to.
+    evaluations : QCEvaluation or list[QCEvaluation]
+        One or more QCEvaluation objects from aind-data-schema.
+
+    Returns
+    -------
+    dict
+        The response from MetadataDbClient.add_qc_evaluation.
+    """
+
+    if isinstance(evaluations, QCEvaluation):
+        evaluations = [evaluations]
+
+    serialized = [
+        e.model_dump(mode="json", exclude_none=True) for e in evaluations
+    ]
+
+    qc_contents = {"qc_evaluation": serialized}
+    response = client.add_qc_evaluation(
+        data_asset_id=data_asset_id, qc_contents=qc_contents
+    )
+
+    return response
diff --git a/tests/helpers/test_data_schema.py b/tests/helpers/test_data_schema.py
index 90edf83..5d1b324 100644
--- a/tests/helpers/test_data_schema.py
+++ b/tests/helpers/test_data_schema.py
@@ -17,8 +17,10 @@
     Status,
 )
 from aind_data_schema_models.modalities import Modality
+from requests import HTTPError
 
 from aind_data_access_api.helpers.data_schema import (
+    add_qc_evaluations_to_docdb,
     get_quality_control_by_id,
     get_quality_control_by_name,
     get_quality_control_by_names,
@@ -313,6 +315,158 @@
             projection={"quality_control": 1},
         )
 
+    def test_serialize_qc_single_success(self):
+        """Test add_qc_evaluations_to_docdb succeeds for a single
+        QCEvaluation."""
+        mock_client = MagicMock()
+        # mock a response that add_qc_evaluation would return
+        mock_client.add_qc_evaluation.return_value = {"acknowledged": True}
+
+        modality = {
+            "name": "Extracellular electrophysiology",
+            "abbreviation": "ecephys",
+        }
+        qc_eval = QCEvaluation(
+            modality=modality,
+            stage="Raw data",
+            name="Test QC Single",
+            metrics=[
+                QCMetric(
+                    name="Metric 1",
+                    value="Pass",
+                    status_history=[
+                        QCStatus(
+                            evaluator="Automated test",
+                            status=Status.PASS,
+                            timestamp=datetime(2025, 10, 6),
+                        )
+                    ],
+                )
+            ],
+            notes="Single test",
+        )
+
+        response = add_qc_evaluations_to_docdb(
+            mock_client, "valid_id", qc_eval
+        )
+
+        self.assertIsInstance(response, dict)
+        self.assertTrue(response["acknowledged"])
+        mock_client.add_qc_evaluation.assert_called_once_with(
+            data_asset_id="valid_id",
+            qc_contents={
+                "qc_evaluation": [
+                    qc_eval.model_dump(mode="json", exclude_none=True)
+                ]
+            },
+        )
+
+    def test_serialize_qc_list_success(self):
+        """Test add_qc_evaluations_to_docdb succeeds for a list of
+        QCEvaluations."""
+        mock_client = MagicMock()
+        mock_client.add_qc_evaluation.return_value = {"acknowledged": True}
+
+        modality = {
+            "name": "Extracellular electrophysiology",
+            "abbreviation": "ecephys",
+        }
+        qc_eval1 = QCEvaluation(
+            modality=modality,
+            stage="Raw data",
+            name="Test QC 1",
+            metrics=[
+                QCMetric(
+                    name="Metric 1",
+                    value="Pass",
+                    status_history=[
+                        QCStatus(
+                            evaluator="Automated test",
+                            status=Status.PASS,
+                            timestamp=datetime(2025, 10, 6),
+                        )
+                    ],
+                )
+            ],
+            notes="First test",
+        )
+
+        qc_eval2 = QCEvaluation(
+            modality=modality,
+            stage="Raw data",
+            name="Test QC 2",
+            metrics=[
+                QCMetric(
+                    name="Metric 2",
+                    value="Fail",
+                    status_history=[
+                        QCStatus(
+                            evaluator="Automated test",
+                            status=Status.FAIL,
+                            timestamp=datetime(2025, 10, 6),
+                        )
+                    ],
+                )
+            ],
+            notes="Second test",
+        )
+
+        response = add_qc_evaluations_to_docdb(
+            mock_client, "valid_id", [qc_eval1, qc_eval2]
+        )
+
+        self.assertIsInstance(response, dict)
+        self.assertTrue(response["acknowledged"])
+        mock_client.add_qc_evaluation.assert_called_once_with(
+            data_asset_id="valid_id",
+            qc_contents={
+                "qc_evaluation": [
+                    qc_eval1.model_dump(mode="json", exclude_none=True),
+                    qc_eval2.model_dump(mode="json", exclude_none=True),
+                ]
+            },
+        )
+
+    def test_serialize_qc_failure(self):
+        """Test error when data_asset_id is invalid."""
+        mock_client = MagicMock()
+        mock_client.add_qc_evaluation.side_effect = HTTPError(
+            "404 Client Error"
+        )
+
+        modality = {
+            "name": "Extracellular electrophysiology",
+            "abbreviation": "ecephys",
+        }
+        qc_eval = QCEvaluation(
+            modality=modality,
+            stage="Raw data",
+            name="Test QC Invalid",
+            metrics=[
+                QCMetric(
+                    name="Metric 1",
+                    value="Pass",
+                    status_history=[
+                        QCStatus(
+                            evaluator="Automated test",
+                            status=Status.PASS,
+                            timestamp=datetime(2025, 10, 6),
+                        )
+                    ],
+                )
+            ],
+            notes="Invalid test",
+        )
+
+        with self.assertRaises(HTTPError) as e:
+            add_qc_evaluations_to_docdb(
+                client=mock_client,
+                data_asset_id="bad_id",
+                evaluations=qc_eval,
+            )
+
+        self.assertIn("404 Client Error", str(e.exception))
+
 
 if __name__ == "__main__":
     unittest.main()