src/aind_data_transfer/jobs/basic_job.py (55 additions, 1 deletion)

@@ -11,10 +11,20 @@
from enum import Enum
from importlib import import_module
from pathlib import Path
from typing import Dict, List

import boto3
from aind_codeocean_api.codeocean import CodeOceanClient
from aind_data_schema.data_description import Modality
from aind_data_schema.base import AindCoreModel
from aind_data_schema.data_description import (
ExperimentType,
Modality,
RawDataDescription,
)
from aind_data_schema.ephys.ephys_rig import EphysRig
from aind_data_schema.ephys.ephys_session import EphysSession
from aind_data_schema.imaging.acquisition import Acquisition
from aind_data_schema.imaging.instrument import Instrument

from aind_data_transfer import __version__
from aind_data_transfer.config_loader.base_config import BasicUploadJobConfigs
@@ -41,6 +51,21 @@ class JobTypes(Enum):
class BasicJob:
"""Class that defines a basic upload job."""

# Map of experiment types to the metadata files that must be present
# before upload. Subject, Procedures, and Processing will be generated
# if not present.
_METADATA_COMPLETENESS_CHECK: Dict[ExperimentType, List[AindCoreModel]] = {
ExperimentType.ECEPHYS: [
RawDataDescription,
EphysRig,
EphysSession,
],
ExperimentType.SMARTSPIM: [
RawDataDescription,
Instrument,
Acquisition,
],
}

def __init__(self, job_configs: BasicUploadJobConfigs):
"""Init with job_configs"""
self.job_configs = job_configs
@@ -51,6 +76,34 @@ def __init__(self, job_configs: BasicUploadJobConfigs):
)
self._instance_logger.setLevel(job_configs.log_level)

def _metadata_completeness_check(self) -> bool:
    """For the experiment types listed in _METADATA_COMPLETENESS_CHECK,
    check that the required metadata files are present in the directory
    before the data is compressed and uploaded."""

def check_dir(path_to_check: Path) -> bool:
    """Checks that the required metadata json files are present."""
    all_files = os.listdir(path_to_check)
    json_files = [m for m in all_files if str(m).endswith(".json")]
    required_files = [m.default_filename() for m in expected_metadata]
    # Raise if any required json file is missing from the directory
    if set(required_files) - set(json_files):
        raise Exception(
            f"All of {required_files} required for upload!"
        )
    else:
        return True

exp_type = self.job_configs.experiment_type
expected_metadata = self._METADATA_COMPLETENESS_CHECK.get(exp_type)
if expected_metadata is None:
check = True
elif self.job_configs.metadata_dir is not None:
check = check_dir(self.job_configs.metadata_dir)
else:
# If no metadata_dir defined, check the parent of the first source
check = check_dir(self.job_configs.modalities[0].source.parent)
return check

def _test_upload(self, temp_dir: Path):
"""Run upload command on empty directory to see if user has permissions
and aws cli installed.
@@ -335,6 +388,7 @@ def run_job(self):
uploading."""
process_start_time = datetime.now(timezone.utc)
self._check_if_s3_location_exists()
self._metadata_completeness_check()
with tempfile.TemporaryDirectory(
dir=self.job_configs.temp_directory
) as td:
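For reference, a minimal standalone sketch of the set-difference check this diff adds, using the ecephys requirements from the mapping above. The directory path here is hypothetical; the real method resolves it from `metadata_dir` or from the parent of the first modality source.

import os
from pathlib import Path

# Hypothetical directory; the real job resolves this from its configs.
metadata_dir = Path("/data/ecephys_session")

# Filenames required for an ecephys upload, per _METADATA_COMPLETENESS_CHECK.
required_files = {"data_description.json", "ephys_rig.json", "ephys_session.json"}

json_files = {f for f in os.listdir(metadata_dir) if f.endswith(".json")}
if required_files - json_files:
    raise Exception(f"All of {sorted(required_files)} required for upload!")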
tests/test_basic_job.py (58 additions, 0 deletions)

@@ -7,6 +7,7 @@
from pathlib import Path
from unittest.mock import MagicMock, call, patch

from aind_data_schema.data_description import ExperimentType
from requests import Response

from aind_data_transfer import __version__
@@ -56,6 +57,57 @@ class TestBasicJob(unittest.TestCase):
"DRY_RUN": "true",
}

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("os.listdir")
def test_metadata_completeness_check(self, mock_list_dir: MagicMock):
    """Tests the metadata completeness check for each experiment type."""
    mock_list_dir.side_effect = (
["data_description.json", "ephys_rig.json", "ephys_session.json"],
["data_description.json", "ephys_rig.json", "ephys_session.json"],
["data_description.json", "instrument.json", "acquisition.json"],
["data_description.json", "instrument.json", "acquisition.json"],
["data_description.json"], # Use this to assert error raised
)
basic_job_configs = BasicUploadJobConfigs()
basic_job = BasicJob(job_configs=basic_job_configs)
ephys_job_configs1 = BasicUploadJobConfigs(
experiment_type=ExperimentType.ECEPHYS
)
ephys_job1 = BasicJob(job_configs=ephys_job_configs1)
ephys_job_configs2 = BasicUploadJobConfigs(
experiment_type=ExperimentType.ECEPHYS, metadata_dir=METADATA_DIR
)
ephys_job2 = BasicJob(job_configs=ephys_job_configs2)
smartspim_job_configs1 = BasicUploadJobConfigs(
experiment_type=ExperimentType.SMARTSPIM
)
smartspim_job1 = BasicJob(job_configs=smartspim_job_configs1)
smartspim_job_configs2 = BasicUploadJobConfigs(
experiment_type=ExperimentType.SMARTSPIM, metadata_dir=METADATA_DIR
)
smartspim_job2 = BasicJob(job_configs=smartspim_job_configs2)
smartspim_job_configs3 = BasicUploadJobConfigs(
experiment_type=ExperimentType.SMARTSPIM, metadata_dir=METADATA_DIR
)
smartspim_job3 = BasicJob(job_configs=smartspim_job_configs3)
check1 = basic_job._metadata_completeness_check()
check2 = ephys_job1._metadata_completeness_check()
check3 = ephys_job2._metadata_completeness_check()
check4 = smartspim_job1._metadata_completeness_check()
check5 = smartspim_job2._metadata_completeness_check()
self.assertTrue(check1)
self.assertTrue(check2)
self.assertTrue(check3)
self.assertTrue(check4)
self.assertTrue(check5)
with self.assertRaises(Exception) as e:
smartspim_job3._metadata_completeness_check()
expected_error_message = (
"Exception(\"All of ['data_description.json', 'instrument.json',"
" 'acquisition.json'] required for upload!\")"
)
self.assertEqual(expected_error_message, repr(e.exception))

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("tempfile.TemporaryDirectory")
@patch("aind_data_transfer.jobs.basic_job.upload_to_s3")
@@ -434,8 +486,13 @@ def test_trigger_custom_codeocean_capsule(
)
@patch("aind_data_transfer.jobs.basic_job.datetime")
@patch("aind_data_transfer.jobs.basic_job.BasicJob._test_upload")
@patch(
"aind_data_transfer.jobs.basic_job.BasicJob."
"_metadata_completeness_check"
)
def test_run_job(
self,
mock_metadata_completeness_check: MagicMock,
mock_test_upload: MagicMock,
mock_datetime: MagicMock,
mock_trigger_pipeline: MagicMock,
Expand Down Expand Up @@ -478,6 +535,7 @@ def test_run_job(
temp_dir=(Path("some_dir") / "tmp")
)
mock_trigger_pipeline.assert_called_once()
mock_metadata_completeness_check.assert_called_once()

self.assertEqual(1, 1)

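The test above leans on `MagicMock.side_effect`: when given an iterable, each call to the mocked `os.listdir` consumes the next item in order, which is how one mock feeds five separate completeness checks. A small illustration, with hypothetical filenames:

from unittest.mock import MagicMock

# With an iterable side_effect, successive calls return successive items.
mock_listdir = MagicMock(side_effect=(["a.json"], ["b.json"]))
print(mock_listdir("dir1"))  # ['a.json']
print(mock_listdir("dir2"))  # ['b.json']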