diff --git a/README.md b/README.md index a9c130e..69f8ef8 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,8 @@ DVUploader provides several environment variables that allow you to control retr - `DVUPLOADER_MIN_RETRY_TIME`: Minimum wait time between retries in seconds (default: 1) - `DVUPLOADER_RETRY_MULTIPLIER`: Multiplier for exponential backoff (default: 0.1) - `DVUPLOADER_MAX_PKG_SIZE`: Maximum package size in bytes (default: 2GB) +- `DVUPLOADER_LOCK_WAIT_TIME`: Time to wait between checks for dataset lock (default: 10 seconds) +- `DVUPLOADER_LOCK_TIMEOUT`: Timeout for dataset lock check in seconds (default: 300 seconds) **Setting via environment:** ```bash @@ -135,6 +137,8 @@ export DVUPLOADER_MAX_RETRY_TIME=300 export DVUPLOADER_MIN_RETRY_TIME=2 export DVUPLOADER_RETRY_MULTIPLIER=0.2 export DVUPLOADER_MAX_PKG_SIZE=3221225472 # 3GB +export DVUPLOADER_LOCK_WAIT_TIME=5 +export DVUPLOADER_LOCK_TIMEOUT=300 ``` **Setting programmatically:** @@ -148,6 +152,8 @@ dv.config( min_retry_time=2, retry_multiplier=0.2, max_package_size=3 * 1024**3 # 3GB + lock_wait_time=5, + lock_timeout=300, ) # Continue with your upload as normal diff --git a/dvuploader/cli.py b/dvuploader/cli.py index 9afb361..c7ab6ef 100644 --- a/dvuploader/cli.py +++ b/dvuploader/cli.py @@ -1,9 +1,10 @@ -import yaml -import typer - from pathlib import Path -from pydantic import BaseModel from typing import List, Optional + +import typer +import yaml +from pydantic import BaseModel + from dvuploader import DVUploader, File from dvuploader.utils import add_directory @@ -29,6 +30,7 @@ class CliInput(BaseModel): app = typer.Typer() + def _enumerate_filepaths(filepaths: List[str], recurse: bool) -> List[File]: """ Take a list of filepaths and transform it into a list of File objects, optionally recursing into each of them. @@ -39,7 +41,7 @@ def _enumerate_filepaths(filepaths: List[str], recurse: bool) -> List[File]: Returns: List[File]: A list of File objects representing the files extracted from all filepaths. - + Raises: FileNotFoundError: If a filepath does not exist. IsADirectoryError: If recurse is False and a filepath points to a directory instead of a file. 
@@ -183,6 +185,9 @@ def main( if filepaths is None: filepaths = [] + if recurse is None: + recurse = False + _validate_inputs( filepaths=filepaths, pid=pid, @@ -200,7 +205,10 @@ def main( api_token=api_token, dataverse_url=dataverse_url, persistent_id=pid, - files=_enumerate_filepaths(filepaths=filepaths, recurse=recurse), + files=_enumerate_filepaths( + filepaths=filepaths, + recurse=recurse, + ), ) uploader = DVUploader(files=cli_input.files) diff --git a/dvuploader/config.py b/dvuploader/config.py index 4203782..3336aee 100644 --- a/dvuploader/config.py +++ b/dvuploader/config.py @@ -2,6 +2,8 @@ def config( + lock_wait_time: int = 10, + lock_timeout: int = 300, max_retries: int = 15, max_retry_time: int = 240, min_retry_time: int = 1, @@ -54,3 +56,5 @@ def config( os.environ["DVUPLOADER_MIN_RETRY_TIME"] = str(min_retry_time) os.environ["DVUPLOADER_RETRY_MULTIPLIER"] = str(retry_multiplier) os.environ["DVUPLOADER_MAX_PKG_SIZE"] = str(max_package_size) + os.environ["DVUPLOADER_LOCK_WAIT_TIME"] = str(lock_wait_time) + os.environ["DVUPLOADER_LOCK_TIMEOUT"] = str(lock_timeout) diff --git a/dvuploader/directupload.py b/dvuploader/directupload.py index a8298e9..4f57b0b 100644 --- a/dvuploader/directupload.py +++ b/dvuploader/directupload.py @@ -1,21 +1,27 @@ import asyncio -import httpx -from io import BytesIO import json import os -from typing import Dict, List, Optional, Tuple +from io import BytesIO +from typing import AsyncGenerator, Dict, List, Optional, Tuple from urllib.parse import urljoin + import aiofiles -from typing import AsyncGenerator +import httpx from rich.progress import Progress, TaskID from dvuploader.file import File -from dvuploader.utils import build_url +from dvuploader.utils import build_url, init_logging, wait_for_dataset_unlock TESTING = bool(os.environ.get("DVUPLOADER_TESTING", False)) MAX_FILE_DISPLAY = int(os.environ.get("DVUPLOADER_MAX_FILE_DISPLAY", 50)) MAX_RETRIES = int(os.environ.get("DVUPLOADER_MAX_RETRIES", 10)) +LOCK_WAIT_TIME = int(os.environ.get("DVUPLOADER_LOCK_WAIT_TIME", 1.5)) +LOCK_TIMEOUT = int(os.environ.get("DVUPLOADER_LOCK_TIMEOUT", 300)) + +assert isinstance(LOCK_WAIT_TIME, int), "DVUPLOADER_LOCK_WAIT_TIME must be an integer" +assert isinstance(LOCK_TIMEOUT, int), "DVUPLOADER_LOCK_TIMEOUT must be an integer" + assert isinstance(MAX_FILE_DISPLAY, int), ( "DVUPLOADER_MAX_FILE_DISPLAY must be an integer" ) @@ -27,6 +33,9 @@ UPLOAD_ENDPOINT = "/api/datasets/:persistentId/addFiles?persistentId=" REPLACE_ENDPOINT = "/api/datasets/:persistentId/replaceFiles?persistentId=" +# Initialize logging +init_logging() + async def direct_upload( files: List[File], @@ -250,7 +259,7 @@ async def _upload_singlepart( "headers": headers, "url": ticket["url"], "content": upload_bytes( - file=file.handler, # type: ignore + file=file.get_handler(), # type: ignore progress=progress, pbar=pbar, hash_func=file.checksum._hash_fun, @@ -549,6 +558,13 @@ async def _add_files_to_ds( pbar: Progress bar for registration. 
""" + await wait_for_dataset_unlock( + session=session, + persistent_id=pid, + sleep_time=LOCK_WAIT_TIME, + timeout=LOCK_TIMEOUT, + ) + novel_url = urljoin(dataverse_url, UPLOAD_ENDPOINT + pid) replace_url = urljoin(dataverse_url, REPLACE_ENDPOINT + pid) diff --git a/dvuploader/dvuploader.py b/dvuploader/dvuploader.py index 5d71e93..76b18ec 100644 --- a/dvuploader/dvuploader.py +++ b/dvuploader/dvuploader.py @@ -1,15 +1,15 @@ import asyncio -from urllib.parse import urljoin -import httpx import os -import rich from typing import Dict, List, Optional +from urllib.parse import urljoin +import httpx +import rich from pydantic import BaseModel -from rich.progress import Progress -from rich.table import Table from rich.console import Console from rich.panel import Panel +from rich.progress import Progress +from rich.table import Table from dvuploader.directupload import ( TICKET_ENDPOINT, @@ -239,7 +239,13 @@ def _check_duplicates( to_skip.append(file.file_id) if replace_existing: + assert file.file_id is not None, "File ID is required" + assert isinstance(file.file_id, int), "File ID must be an integer" + ds_file = self._get_dsfile_by_id(file.file_id, ds_files) + + assert ds_file is not None, "Dataset file not found" + if not self._check_size(file, ds_file): file._unchanged_data = False else: @@ -359,10 +365,12 @@ def _check_hashes(file: File, dsFile: Dict): dsFile.get("directoryLabel", ""), dsFile["dataFile"]["filename"] ) + directory_label = file.directory_label if file.directory_label else "" + return ( file.checksum.value == hash_value and file.checksum.type == hash_algo - and path == os.path.join(file.directory_label, file.file_name) # type: ignore + and path == os.path.join(directory_label, file.file_name) # type: ignore ) @staticmethod diff --git a/dvuploader/file.py b/dvuploader/file.py index 3d759e8..b7e5de5 100644 --- a/dvuploader/file.py +++ b/dvuploader/file.py @@ -130,7 +130,6 @@ def extract_file_name(self): if self.handler is None: self._validate_filepath(self.filepath) - self.handler = open(self.filepath, "rb") self._size = os.path.getsize(self.filepath) else: self._size = len(self.handler.read()) @@ -147,6 +146,15 @@ def extract_file_name(self): return self + def get_handler(self) -> IO: + """ + Opens the file and initializes the file handler. + """ + if self.handler is not None: + return self.handler + + return open(self.filepath, "rb") + @staticmethod def _validate_filepath(path): """ @@ -190,12 +198,13 @@ def update_checksum_chunked(self, blocksize=2**20): Note: This method resets the file position to the start after reading. """ - assert self.handler is not None, "File handler is not initialized." assert self.checksum is not None, "Checksum is not initialized." assert self.checksum._hash_fun is not None, "Checksum hash function is not set." + handler = self.get_handler() + while True: - buf = self.handler.read(blocksize) + buf = handler.read(blocksize) if not isinstance(buf, bytes): buf = buf.encode() @@ -203,8 +212,13 @@ def update_checksum_chunked(self, blocksize=2**20): if not buf: break self.checksum._hash_fun.update(buf) - - self.handler.seek(0) + + if self.handler is not None: # type: ignore + # In case of passed handler, we need to seek the handler to the start after reading. + self.handler.seek(0) + else: + # Path-based handlers will be opened just-in-time, so we can close it. 
+ handler.close() def __del__(self): if self.handler is not None: diff --git a/dvuploader/nativeupload.py b/dvuploader/nativeupload.py index ea4389d..2614fbc 100644 --- a/dvuploader/nativeupload.py +++ b/dvuploader/nativeupload.py @@ -1,19 +1,24 @@ import asyncio -from io import BytesIO -from pathlib import Path -import httpx import json import os import tempfile +from io import BytesIO +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import httpx import rich import tenacity -from typing import List, Optional, Tuple, Dict - from rich.progress import Progress, TaskID from dvuploader.file import File from dvuploader.packaging import distribute_files, zip_files -from dvuploader.utils import build_url, retrieve_dataset_files +from dvuploader.utils import ( + build_url, + init_logging, + retrieve_dataset_files, + wait_for_dataset_unlock, +) ##### CONFIGURATION ##### @@ -22,6 +27,8 @@ # # This will exponentially increase the wait time between retries. # The max wait time is 240 seconds per retry though. +LOCK_WAIT_TIME = int(os.environ.get("DVUPLOADER_LOCK_WAIT_TIME", 1.5)) +LOCK_TIMEOUT = int(os.environ.get("DVUPLOADER_LOCK_TIMEOUT", 300)) MAX_RETRIES = int(os.environ.get("DVUPLOADER_MAX_RETRIES", 15)) MAX_RETRY_TIME = int(os.environ.get("DVUPLOADER_MAX_RETRY_TIME", 60)) MIN_RETRY_TIME = int(os.environ.get("DVUPLOADER_MIN_RETRY_TIME", 1)) @@ -32,6 +39,9 @@ max=MAX_RETRY_TIME, ) + +assert isinstance(LOCK_WAIT_TIME, int), "DVUPLOADER_LOCK_WAIT_TIME must be an integer" +assert isinstance(LOCK_TIMEOUT, int), "DVUPLOADER_LOCK_TIMEOUT must be an integer" assert isinstance(MAX_RETRIES, int), "DVUPLOADER_MAX_RETRIES must be an integer" assert isinstance(MAX_RETRY_TIME, int), "DVUPLOADER_MAX_RETRY_TIME must be an integer" assert isinstance(MIN_RETRY_TIME, int), "DVUPLOADER_MIN_RETRY_TIME must be an integer" @@ -55,6 +65,9 @@ ZIP_LIMIT_MESSAGE = "The number of files in the zip archive is over the limit" +init_logging() + + async def native_upload( files: List[File], dataverse_url: str, @@ -86,7 +99,12 @@ async def native_upload( session_params = { "base_url": dataverse_url, "headers": {"X-Dataverse-key": api_token}, - "timeout": None, + "timeout": httpx.Timeout( + None, + read=None, + write=None, + connect=None, + ), "limits": httpx.Limits(max_connections=n_parallel_uploads), "proxy": proxy, } @@ -262,6 +280,14 @@ async def _single_native_upload( - dict: JSON response from the upload request """ + # Check if the dataset is locked + await wait_for_dataset_unlock( + session=session, + persistent_id=persistent_id, + sleep_time=LOCK_WAIT_TIME, + timeout=LOCK_TIMEOUT, + ) + if not file.to_replace: endpoint = build_url( endpoint=NATIVE_UPLOAD_ENDPOINT, @@ -273,11 +299,12 @@ async def _single_native_upload( ) json_data = _get_json_data(file) + handler = file.get_handler() files = { "file": ( file.file_name, - file.handler, + handler, file.mimeType, ), "jsonData": ( @@ -405,6 +432,7 @@ async def _update_metadata( session=session, url=NATIVE_METADATA_ENDPOINT.format(FILE_ID=file_id), file=file, + persistent_id=persistent_id, ) tasks.append(task) @@ -420,6 +448,7 @@ async def _update_single_metadata( session: httpx.AsyncClient, url: str, file: File, + persistent_id: str, ) -> None: """ Updates the metadata of a single file in a Dataverse repository. @@ -433,6 +462,13 @@ async def _update_single_metadata( ValueError: If metadata update fails. 
""" + await wait_for_dataset_unlock( + session=session, + persistent_id=persistent_id, + sleep_time=LOCK_WAIT_TIME, + timeout=LOCK_TIMEOUT, + ) + json_data = _get_json_data(file) # Send metadata as a readable byte stream @@ -453,7 +489,16 @@ async def _update_single_metadata( else: await asyncio.sleep(1.0) - raise ValueError(f"Failed to update metadata for file {file.file_name}.") + if "message" in response.json(): + # If the response is a JSON object, we can get the error message from the "message" key. + error_message = response.json()["message"] + else: + # If the response is not a JSON object, we can get the error message from the response text. + error_message = response.text + + raise ValueError( + f"Failed to update metadata for file {file.file_name}: {error_message}" + ) def _retrieve_file_ids( diff --git a/dvuploader/packaging.py b/dvuploader/packaging.py index 6d40f41..3e7f49c 100644 --- a/dvuploader/packaging.py +++ b/dvuploader/packaging.py @@ -1,7 +1,7 @@ import os import zipfile - from typing import List, Tuple + from dvuploader.file import File MAXIMUM_PACKAGE_SIZE = int( @@ -101,7 +101,7 @@ def zip_files( with zipfile.ZipFile(path, "w") as zip_file: for file in files: zip_file.writestr( - data=file.handler.read(), # type: ignore + data=file.get_handler().read(), # type: ignore zinfo_or_arcname=_create_arcname(file), ) file._is_inside_zip = True @@ -123,4 +123,5 @@ def _create_arcname(file: File) -> str: if file.directory_label is not None: return os.path.join(file.directory_label, file.file_name) # type: ignore else: + assert file.file_name is not None, "File name is required" return file.file_name diff --git a/dvuploader/utils.py b/dvuploader/utils.py index b06337b..e53220d 100644 --- a/dvuploader/utils.py +++ b/dvuploader/utils.py @@ -1,14 +1,33 @@ +import asyncio +import logging import os import pathlib import re +import time from typing import List from urllib.parse import urljoin + import httpx from rich.progress import Progress from dvuploader.file import File +def init_logging(): + level = ( + logging.DEBUG + if os.environ.get("DVUPLOADER_DEBUG", "false").lower() == "true" + else None + ) + + if level is not None: + logging.basicConfig( + format="%(levelname)s [%(asctime)s] %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=level, + ) + + def build_url( endpoint: str, **kwargs, @@ -173,3 +192,70 @@ def setup_pbar( start=True, total=file_size, ) + + +async def wait_for_dataset_unlock( + session: httpx.AsyncClient, + persistent_id: str, + sleep_time: float = 1.0, + timeout: float = 300.0, # 5 minutes +) -> None: + """ + Wait for a dataset to be unlocked. + + Args: + session (httpx.AsyncClient): The httpx client. + persistent_id (str): The persistent identifier of the dataset. + sleep_time (float): The time to sleep between checks. + timeout (float): The timeout in seconds. + """ + dataset_id = await _get_dataset_id( + session=session, + persistent_id=persistent_id, + ) + start_time = time.monotonic() + while await check_dataset_lock(session=session, id=dataset_id): + if time.monotonic() - start_time > timeout: + raise TimeoutError(f"Dataset {id} did not unlock after {timeout} seconds") + await asyncio.sleep(sleep_time) + + +async def check_dataset_lock( + session: httpx.AsyncClient, + id: int, +) -> bool: + """ + Check if a dataset is locked. + + Args: + session (httpx.AsyncClient): The httpx client. + id (int): The ID of the dataset. + + Returns: + bool: True if the dataset is locked, False otherwise. 
+ """ + response = await session.get(f"/api/datasets/{id}/locks") + response.raise_for_status() + + body = response.json() + if len(body["data"]) == 0: + return False + return True + + +async def _get_dataset_id( + session: httpx.AsyncClient, + persistent_id: str, +) -> int: + """ + Get the ID of a dataset. + + Args: + session (httpx.AsyncClient): The httpx client. + persistent_id (str): The persistent identifier of the dataset. + """ + response = await session.get( + f"/api/datasets/:persistentId/?persistentId={persistent_id}" + ) + response.raise_for_status() + return response.json()["data"]["id"] diff --git a/tests/conftest.py b/tests/conftest.py index 84e5cb2..6437e4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,9 @@ import os -import pytest -import httpx import random +from typing import Literal, Tuple, Union, overload + +import httpx +import pytest @pytest.fixture @@ -16,11 +18,30 @@ def credentials(): return BASE_URL, API_TOKEN +@overload def create_dataset( parent: str, server_url: str, api_token: str, -): + return_id: Literal[False] = False, +) -> str: ... + + +@overload +def create_dataset( + parent: str, + server_url: str, + api_token: str, + return_id: Literal[True], +) -> Tuple[str, int]: ... + + +def create_dataset( + parent: str, + server_url: str, + api_token: str, + return_id: bool = False, +) -> Union[str, Tuple[str, int]]: """ Creates a dataset in a Dataverse. @@ -30,7 +51,7 @@ def create_dataset( api_token (str): The API token for authentication. Returns: - str: The persistent identifier of the created dataset. + Dict: The response from the Dataverse API. """ if server_url.endswith("/"): server_url = server_url[:-1] @@ -39,12 +60,15 @@ def create_dataset( response = httpx.post( url=url, headers={"X-Dataverse-key": api_token}, - data=open("./tests/fixtures/create_dataset.json", "rb"), + data=open("./tests/fixtures/create_dataset.json", "rb"), # type: ignore ) response.raise_for_status() - return response.json()["data"]["persistentId"] + if return_id: + return response.json()["data"]["persistentId"], response.json()["data"]["id"] + else: + return response.json()["data"]["persistentId"] def create_mock_file( diff --git a/tests/integration/test_native_upload.py b/tests/integration/test_native_upload.py index 0c23260..4eea36a 100644 --- a/tests/integration/test_native_upload.py +++ b/tests/integration/test_native_upload.py @@ -1,13 +1,12 @@ -from io import BytesIO import json import os import tempfile +from io import BytesIO import pytest from dvuploader.dvuploader import DVUploader from dvuploader.file import File - from dvuploader.utils import add_directory, retrieve_dataset_files from tests.conftest import create_dataset, create_mock_file, create_mock_tabular_file @@ -472,15 +471,17 @@ def test_metadata_with_zip_files_in_package(self, credentials): # Arrange files = [ - File(filepath="tests/fixtures/archive.zip", - dv_dir="subdir2", - description="This file should not be unzipped", - categories=["Test file"] + File( + filepath="tests/fixtures/archive.zip", + directoryLabel="subdir2", + description="This file should not be unzipped", + categories=["Test file"], ), - File(filepath="tests/fixtures/add_dir_files/somefile.txt", - dv_dir="subdir", - description="A simple text file", - categories=["Test file"] + File( + filepath="tests/fixtures/add_dir_files/somefile.txt", + directoryLabel="subdir", + description="A simple text file", + categories=["Test file"], ), ] @@ -506,30 +507,26 @@ def test_metadata_with_zip_files_in_package(self, credentials): { 
"label": "archive.zip", "description": "This file should not be unzipped", - "categories": ["Test file"] + "categories": ["Test file"], }, { "label": "somefile.txt", "description": "A simple text file", - "categories": ["Test file"] + "categories": ["Test file"], }, ] - files_as_expected = sorted( + files_as_expected = sorted( # pyright: ignore[reportCallIssue] [ - { - k: (f[k] if k in f else None) - for k in expected_files[0].keys() - } + {k: (f[k] if k in f else None) for k in expected_files[0].keys()} for f in files ], - key=lambda x: x["label"] + key=lambda x: x["label"], # pyright: ignore[reportArgumentType] ) assert files_as_expected == expected_files, ( f"File metadata not as expected: {json.dumps(files, indent=2)}" ) - def test_too_many_zip_files( self, credentials, diff --git a/tests/unit/test_directupload.py b/tests/unit/test_directupload.py index 2136832..cef0e0d 100644 --- a/tests/unit/test_directupload.py +++ b/tests/unit/test_directupload.py @@ -1,27 +1,36 @@ import httpx import pytest from rich.progress import Progress + from dvuploader.directupload import ( _add_files_to_ds, - _validate_ticket_response, _prepare_registration, + _validate_ticket_response, ) - from dvuploader.file import File class Test_AddFileToDs: - # Should successfully add files to a Dataverse dataset with a valid file path @pytest.mark.asyncio async def test_successfully_add_file_with_valid_filepath(self, httpx_mock): - # Mock the session.post method to return a response with status code 200 + httpx_mock.add_response( + method="get", + url="https://example.com/api/datasets/:persistentId/?persistentId=pid", + json={"status": "OK", "data": {"id": 123}}, + ) + + httpx_mock.add_response( + method="get", + url="https://example.com/api/datasets/123/locks", + json={"status": "OK", "data": []}, + ) + httpx_mock.add_response( method="post", url="https://example.com/api/datasets/:persistentId/addFiles?persistentId=pid", ) - # Initialize the necessary variables - session = httpx.AsyncClient() + session = httpx.AsyncClient(base_url="https://example.com") dataverse_url = "https://example.com" pid = "pid" fpath = "tests/fixtures/add_dir_files/somefile.txt" @@ -29,7 +38,6 @@ async def test_successfully_add_file_with_valid_filepath(self, httpx_mock): progress = Progress() pbar = progress.add_task("Uploading", total=1) - # Invoke the function await _add_files_to_ds( session=session, dataverse_url=dataverse_url, @@ -41,14 +49,24 @@ async def test_successfully_add_file_with_valid_filepath(self, httpx_mock): @pytest.mark.asyncio async def test_successfully_replace_file_with_valid_filepath(self, httpx_mock): - # Mock the session.post method to return a response with status code 200 + httpx_mock.add_response( + method="get", + url="https://example.com/api/datasets/:persistentId/?persistentId=pid", + json={"status": "OK", "data": {"id": 123}}, + ) + + httpx_mock.add_response( + method="get", + url="https://example.com/api/datasets/123/locks", + json={"status": "OK", "data": []}, + ) + httpx_mock.add_response( method="post", url="https://example.com/api/datasets/:persistentId/replaceFiles?persistentId=pid", ) - # Initialize the necessary variables - session = httpx.AsyncClient() + session = httpx.AsyncClient(base_url="https://example.com") dataverse_url = "https://example.com" pid = "pid" fpath = "tests/fixtures/add_dir_files/somefile.txt" @@ -56,7 +74,6 @@ async def test_successfully_replace_file_with_valid_filepath(self, httpx_mock): progress = Progress() pbar = progress.add_task("Uploading", total=1) - # Invoke the function 
await _add_files_to_ds( session=session, dataverse_url=dataverse_url, @@ -70,10 +87,16 @@ async def test_successfully_replace_file_with_valid_filepath(self, httpx_mock): async def test_successfully_add_and_replace_file_with_valid_filepath( self, httpx_mock ): - # Mock the session.post method to return a response with status code 200 httpx_mock.add_response( - method="post", - url="https://example.com/api/datasets/:persistentId/replaceFiles?persistentId=pid", + method="get", + url="https://example.com/api/datasets/:persistentId/?persistentId=pid", + json={"status": "OK", "data": {"id": 123}}, + ) + + httpx_mock.add_response( + method="get", + url="https://example.com/api/datasets/123/locks", + json={"status": "OK", "data": []}, ) httpx_mock.add_response( @@ -81,8 +104,12 @@ async def test_successfully_add_and_replace_file_with_valid_filepath( url="https://example.com/api/datasets/:persistentId/addFiles?persistentId=pid", ) - # Initialize the necessary variables - session = httpx.AsyncClient() + httpx_mock.add_response( + method="post", + url="https://example.com/api/datasets/:persistentId/replaceFiles?persistentId=pid", + ) + + session = httpx.AsyncClient(base_url="https://example.com") dataverse_url = "https://example.com" pid = "pid" fpath = "tests/fixtures/add_dir_files/somefile.txt" @@ -93,7 +120,6 @@ async def test_successfully_add_and_replace_file_with_valid_filepath( progress = Progress() pbar = progress.add_task("Uploading", total=1) - # Invoke the function await _add_files_to_ds( session=session, dataverse_url=dataverse_url, diff --git a/tests/unit/test_file.py b/tests/unit/test_file.py index 5232cff..8475822 100644 --- a/tests/unit/test_file.py +++ b/tests/unit/test_file.py @@ -1,4 +1,5 @@ import pytest + from dvuploader.file import File @@ -10,7 +11,7 @@ def test_read_file(self): # Act file = File( filepath=fpath, - directory_label="", + directoryLabel="", ) file.extract_file_name() @@ -26,7 +27,7 @@ def test_read_non_existent_file(self): with pytest.raises(FileNotFoundError): file = File( filepath=fpath, - directory_label="", + directoryLabel="", ) file.extract_file_name() @@ -39,7 +40,7 @@ def test_read_non_file(self): with pytest.raises(IsADirectoryError): file = File( filepath=fpath, - directory_label="", + directoryLabel="", ) file.extract_file_name() diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 45d8a1c..a68291d 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,15 +1,21 @@ +import asyncio from io import BytesIO -import pytest -import httpx +import httpx +import pytest from rich.progress import Progress + from dvuploader.file import File from dvuploader.utils import ( + _get_dataset_id, add_directory, build_url, + check_dataset_lock, retrieve_dataset_files, setup_pbar, + wait_for_dataset_unlock, ) +from tests.conftest import create_dataset class TestAddDirectory: @@ -188,3 +194,214 @@ def test_returns_progress_bar_object(self): # Assert assert isinstance(result, int) assert result == 0 + + +class TestDatasetId: + @pytest.mark.asyncio + async def test_get_dataset_id(self, credentials): + # Arrange + BASE_URL, API_TOKEN = credentials + dataset_pid, dataset_id = create_dataset( + parent="Root", + server_url=BASE_URL, + api_token=API_TOKEN, + return_id=True, + ) + + print(dataset_pid, dataset_id) + + # Act + async with httpx.AsyncClient( + base_url=BASE_URL, + headers={"X-Dataverse-key": API_TOKEN}, + ) as session: + result = await _get_dataset_id(session=session, persistent_id=dataset_pid) + + # Assert + assert result == 
dataset_id + + +class TestCheckDatasetLock: + @pytest.mark.asyncio + async def test_check_dataset_lock(self, credentials): + # Create a dataset and apply a lock, then verify the lock is detected + BASE_URL, API_TOKEN = credentials + _, dataset_id = create_dataset( + parent="Root", + server_url=BASE_URL, + api_token=API_TOKEN, + return_id=True, + ) + async with httpx.AsyncClient( + base_url=BASE_URL, + headers={"X-Dataverse-key": API_TOKEN}, + ) as session: + response = await session.post(f"/api/datasets/{dataset_id}/lock/Ingest") + response.raise_for_status() + result = await check_dataset_lock(session=session, id=dataset_id) + assert result is True + + @pytest.mark.asyncio + async def test_wait_for_dataset_unlock(self, credentials): + # Test that the unlock wait function completes when a dataset lock is released + BASE_URL, API_TOKEN = credentials + dataset_pid, dataset_id = create_dataset( + parent="Root", + server_url=BASE_URL, + api_token=API_TOKEN, + return_id=True, + ) + async with httpx.AsyncClient( + base_url=BASE_URL, + headers={"X-Dataverse-key": API_TOKEN}, + ) as session: + response = await session.post(f"/api/datasets/{dataset_id}/lock/Ingest") + response.raise_for_status() + + async def release_lock(): + # Simulate background unlock after a brief pause + await asyncio.sleep(1.5) + unlock_resp = await session.delete( + f"/api/datasets/{dataset_id}/locks", + params={"type": "Ingest"}, + ) + unlock_resp.raise_for_status() + + release_task = asyncio.create_task(release_lock()) + await wait_for_dataset_unlock( + session=session, + persistent_id=dataset_pid, + timeout=4, + ) + await release_task # Ensure unlock task completes + + @pytest.mark.asyncio + async def test_wait_for_dataset_unlock_timeout(self, credentials): + # Should raise a timeout error if dataset is not unlocked within the given window + BASE_URL, API_TOKEN = credentials + dataset_pid, dataset_id = create_dataset( + parent="Root", + server_url=BASE_URL, + api_token=API_TOKEN, + return_id=True, + ) + async with httpx.AsyncClient( + base_url=BASE_URL, + headers={"X-Dataverse-key": API_TOKEN}, + ) as session: + response = await session.post(f"/api/datasets/{dataset_id}/lock/Ingest") + response.raise_for_status() + + with pytest.raises(TimeoutError): + await wait_for_dataset_unlock( + session=session, + persistent_id=dataset_pid, + timeout=0.2, + ) + + @pytest.mark.asyncio + async def test_check_dataset_lock_when_unlocked(self, credentials): + # Confirm that check_dataset_lock returns False for unlocked datasets + BASE_URL, API_TOKEN = credentials + dataset_pid, dataset_id = create_dataset( + parent="Root", + server_url=BASE_URL, + api_token=API_TOKEN, + return_id=True, + ) + async with httpx.AsyncClient( + base_url=BASE_URL, + headers={"X-Dataverse-key": API_TOKEN}, + ) as session: + result = await check_dataset_lock(session=session, id=dataset_id) + assert result is False + + @pytest.mark.asyncio + async def test_wait_for_dataset_unlock_already_unlocked(self, credentials): + # Wait should return promptly when there is no lock present + BASE_URL, API_TOKEN = credentials + dataset_pid, dataset_id = create_dataset( + parent="Root", + server_url=BASE_URL, + api_token=API_TOKEN, + return_id=True, + ) + async with httpx.AsyncClient( + base_url=BASE_URL, + headers={"X-Dataverse-key": API_TOKEN}, + ) as session: + import time + + start = time.monotonic() + await wait_for_dataset_unlock( + session=session, + persistent_id=dataset_pid, + timeout=5, + ) + elapsed = time.monotonic() - start + assert elapsed < 0.5 # Operation 
should be quick + + @pytest.mark.asyncio + async def test_check_dataset_lock_invalid_id(self, credentials): + # Using a likely-invalid ID should cause an HTTP error from the API + BASE_URL, API_TOKEN = credentials + invalid_dataset_id = 999999999 + + async with httpx.AsyncClient( + base_url=BASE_URL, + headers={"X-Dataverse-key": API_TOKEN}, + ) as session: + with pytest.raises(httpx.HTTPStatusError): + await check_dataset_lock(session=session, id=invalid_dataset_id) + + @pytest.mark.asyncio + async def test_wait_for_dataset_unlock_invalid_id(self, credentials): + # Waiting on an invalid dataset should raise an HTTP error + BASE_URL, API_TOKEN = credentials + invalid_dataset_pid = "999999999" + + async with httpx.AsyncClient( + base_url=BASE_URL, + headers={"X-Dataverse-key": API_TOKEN}, + ) as session: + with pytest.raises(httpx.HTTPStatusError): + await wait_for_dataset_unlock( + session=session, + persistent_id=invalid_dataset_pid, + timeout=1, + ) + + @pytest.mark.asyncio + async def test_wait_for_dataset_unlock_race_condition_at_timeout(self, credentials): + # Test the case where unlocking occurs just before timeout + BASE_URL, API_TOKEN = credentials + dataset_pid, dataset_id = create_dataset( + parent="Root", + server_url=BASE_URL, + api_token=API_TOKEN, + return_id=True, + ) + async with httpx.AsyncClient( + base_url=BASE_URL, + headers={"X-Dataverse-key": API_TOKEN}, + ) as session: + response = await session.post(f"/api/datasets/{dataset_id}/lock/Ingest") + response.raise_for_status() + + async def release_lock(): + # Unlock just before the test timeout + await asyncio.sleep(1.8) + unlock_resp = await session.delete( + f"/api/datasets/{dataset_id}/locks", + params={"type": "Ingest"}, + ) + unlock_resp.raise_for_status() + + release_task = asyncio.create_task(release_lock()) + await wait_for_dataset_unlock( + session=session, + persistent_id=dataset_pid, + timeout=2.5, + sleep_time=0.1, + ) + await release_task # Clean up after test unlock
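
The new lock helpers in `dvuploader/utils.py` can also be driven directly, outside the uploaders. Below is a minimal sketch that mirrors the integration tests above; the installation URL, API token, and persistent identifier are placeholders rather than values taken from this patch.

```python
# Illustrative only: poll a dataset's locks before doing follow-up work.
# URL, token, and PID are placeholders; the client setup mirrors the tests above.
import asyncio

import httpx

from dvuploader.utils import wait_for_dataset_unlock


async def main() -> None:
    async with httpx.AsyncClient(
        base_url="https://demo.dataverse.org",        # placeholder installation
        headers={"X-Dataverse-key": "my-api-token"},  # placeholder token
    ) as session:
        # Resolves the numeric dataset id from the PID, then polls
        # /api/datasets/{id}/locks every `sleep_time` seconds until the lock
        # list is empty, raising TimeoutError after `timeout` seconds.
        await wait_for_dataset_unlock(
            session=session,
            persistent_id="doi:10.70122/FK2/XXXXXX",  # placeholder PID
            sleep_time=5,
            timeout=300,
        )


asyncio.run(main())
```

When the uploaders run, the same helper is called before `addFiles`/`replaceFiles` and before each metadata update, with `sleep_time` and `timeout` coming from the `DVUPLOADER_LOCK_WAIT_TIME` and `DVUPLOADER_LOCK_TIMEOUT` settings described in the README section above.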
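
`File` handling also changed: constructing a path-based `File` no longer opens the file eagerly, and `get_handler()` opens it on demand. A rough sketch of the new behaviour, using one of the test fixture paths from this patch as an example:

```python
# Illustrative only: path-based File objects are opened just-in-time now.
from dvuploader.file import File

file = File(filepath="tests/fixtures/add_dir_files/somefile.txt")
file.extract_file_name()      # validates the path and records the size; no handle is kept open

handler = file.get_handler()  # opened here on demand; a user-supplied handler would be returned as-is
try:
    first_bytes = handler.read(16)
finally:
    handler.close()           # safe to close: this handle was opened from the path, not passed in
```

`update_checksum_chunked()` follows the same pattern: it seeks a passed-in handler back to the start after reading, and closes a handle it opened itself from the path.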
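
The new `init_logging()` helper is controlled by a `DVUPLOADER_DEBUG` environment variable. A small sketch of switching it on; since the uploader modules call `init_logging()` at import time, the variable needs to be set before `dvuploader` is imported:

```python
# Illustrative only: enable the debug logging added in this patch.
import os

os.environ["DVUPLOADER_DEBUG"] = "true"  # only the string "true" (case-insensitive) turns it on

import dvuploader  # the upload modules call init_logging() on import and pick up the flag
```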