Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions tests/core/test_download_orchestrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import pytest
import asyncio
import time
from pathlib import Path
from unittest.mock import MagicMock, AsyncMock, patch

# Make sure these import paths are correct for your project structure
from forklet.core.orchestrator import DownloadOrchestrator
from forklet.models import GitHubFile, DownloadStatus

# --- Test Fixtures for Setup ---

@pytest.fixture
def mock_services():
"""Creates mock objects for services used by the orchestrator."""
github_service = MagicMock()
download_service = MagicMock()
github_service.get_repository_tree = AsyncMock()
github_service.get_file_content = AsyncMock()
download_service.save_content = AsyncMock(return_value=128)
download_service.ensure_directory = AsyncMock()
return github_service, download_service

@pytest.fixture
def orchestrator(mock_services):
"""Initializes the DownloadOrchestrator with mocked services."""
github_service, download_service = mock_services
orchestrator_instance = DownloadOrchestrator(
github_service=github_service,
download_service=download_service,
max_concurrent_downloads=5
)
orchestrator_instance.reset_state()
return orchestrator_instance

@pytest.fixture
def mock_request():
"""Creates a mock DownloadRequest object for use in tests."""
request = MagicMock()
request.repository.owner = "test-owner"
request.repository.name = "test-repo"
request.repository.display_name = "test-owner/test-repo"
request.git_ref = "main"
request.filters = MagicMock()
request.filters.include_patterns = []
request.filters.exclude_patterns = []
request.destination = Path("/fake/destination")
request.create_destination = True
request.overwrite_existing = False
request.preserve_structure = True
request.show_progress_bars = False
return request

# --- Test Cases ---

class TestDownloadOrchestrator:

def test_initialization_sets_properties_correctly(self, orchestrator):
"""Verify that max_concurrent_downloads is correctly set."""
assert orchestrator.max_concurrent_downloads == 5
assert orchestrator._semaphore._value == 5
assert not orchestrator._is_cancelled

@pytest.mark.asyncio
async def test_execute_download_success(self, orchestrator, mock_services, mock_request):
"""Simulate a successful download with mocked services."""
github_service, _ = mock_services
mock_file_list = [MagicMock(spec=GitHubFile, path="file1.txt", size=100)]
github_service.get_repository_tree.return_value = mock_file_list

with patch.object(orchestrator, '_download_files_concurrently', new_callable=AsyncMock) as mock_downloader, \
patch('forklet.core.orchestrator.FilterEngine') as mock_filter_engine:

mock_downloader.return_value = (["file1.txt"], {})
mock_filter_engine.return_value.filter_files.return_value.included_files = mock_file_list

result = await orchestrator.execute_download(request=mock_request)

mock_downloader.assert_awaited_once()
assert result.status == DownloadStatus.COMPLETED

@pytest.mark.asyncio
async def test_execute_download_repo_fetch_fails(self, orchestrator, mock_services, mock_request):
"""Test error handling when repository tree fetch fails."""
github_service, _ = mock_services
github_service.get_repository_tree.side_effect = Exception("API limit reached")

result = await orchestrator.execute_download(request=mock_request)

assert result.status == DownloadStatus.FAILED
assert "API limit reached" in result.error_message

def test_cancel_sets_flag_and_logs(self, orchestrator):
"""Test cancel() -> sets _is_cancelled=True and logs when a download is active."""
orchestrator._current_result = MagicMock()

with patch('forklet.core.orchestrator.logger') as mock_logger:
orchestrator.cancel()
assert orchestrator._is_cancelled is True
mock_logger.info.assert_called_with("Download cancelled by user")

@pytest.mark.asyncio
async def test_pause_and_resume_flow(self, orchestrator, mock_services, mock_request):
"""Tests the full pause and resume flow in a stable, controlled manner."""
github_service, _ = mock_services
# --- THIS IS THE FIX ---
# The mock file MUST have a 'size' attribute for the sum() calculation.
mock_file_list = [MagicMock(spec=GitHubFile, path="file1.txt", size=100)]
# -----------------------
github_service.get_repository_tree.return_value = mock_file_list

download_can_complete = asyncio.Event()

async def wait_for_signal_to_finish(*args, **kwargs):
await download_can_complete.wait()
return (["file1.txt"], {})

with patch.object(orchestrator, '_download_files_concurrently', side_effect=wait_for_signal_to_finish), \
patch('forklet.core.orchestrator.FilterEngine') as mock_filter_engine:

mock_filter_engine.return_value.filter_files.return_value.included_files = mock_file_list

download_task = asyncio.create_task(orchestrator.execute_download(mock_request))

await asyncio.sleep(0.01)

if download_task.done() and download_task.exception():
raise download_task.exception()

assert orchestrator._current_result is not None, "Orchestrator._current_result was not set."

await orchestrator.pause()
assert orchestrator._is_paused is True
assert orchestrator._current_result.status == DownloadStatus.PAUSED

await orchestrator.resume()
assert orchestrator._is_paused is False
assert orchestrator._current_result.status == DownloadStatus.IN_PROGRESS

download_can_complete.set()

final_result = await download_task
assert final_result.status == DownloadStatus.COMPLETED

def test_get_current_progress_returns_none_when_inactive(self, orchestrator):
"""Test get_current_progress() -> returns None when no download is active."""
assert orchestrator.get_current_progress() is None
44 changes: 17 additions & 27 deletions tests/infrastructure/test_rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
# Adjust this import path to match your project's structure
from forklet.infrastructure.rate_limiter import RateLimiter, RateLimitInfo

# Mark all tests in this file as asyncio
# pytestmark = pytest.mark.asyncio


## 1. RateLimitInfo Helper Class Tests
Expand All @@ -33,20 +31,16 @@ def test_rate_limit_info_reset_in_seconds():
"""Test the calculation of seconds until reset."""
info = RateLimitInfo()

# Mock datetime.now() to control the current time for the test
mock_now = datetime(2025, 10, 2, 12, 0, 0)
with patch('forklet.infrastructure.rate_limiter.datetime', autospec=True) as mock_datetime:
mock_datetime.now.return_value = mock_now

# Set reset_time 30 seconds into the future
info.reset_time = mock_now + timedelta(seconds=30)
assert info.reset_in_seconds == 30.0

# Set reset_time in the past
info.reset_time = mock_now - timedelta(seconds=30)
assert info.reset_in_seconds == 0.0, "Should not return negative time"

# No reset time set
info.reset_time = None
assert info.reset_in_seconds == 0.0

Expand Down Expand Up @@ -86,7 +80,6 @@ async def test_update_rate_limit_info_sets_values_correctly():
assert info.reset_time == datetime.fromtimestamp(reset_timestamp)
assert not info.is_exhausted


@pytest.mark.asyncio
async def test_update_rate_limit_increments_consecutive_limits():
"""Test that _consecutive_limits is handled correctly."""
Expand All @@ -108,51 +101,53 @@ async def test_update_rate_limit_increments_consecutive_limits():
# ------------------------------------------
@pytest.mark.asyncio
@patch('asyncio.sleep', new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_acquire_waits_when_primary_rate_limit_exhausted(mock_sleep):
"""Test that acquire() waits for reset_in_seconds when exhausted."""
rl = RateLimiter()

# Mock datetime.now() to control time
mock_now = datetime(2025, 10, 2, 12, 0, 0)
with patch('forklet.infrastructure.rate_limiter.datetime', autospec=True) as mock_datetime:
mock_datetime.now.return_value = mock_now

# Set state to exhausted, with reset 15 seconds in the future
rl.rate_limit_info.remaining = 5
rl.rate_limit_info.reset_time = mock_now + timedelta(seconds=15)

await rl.acquire()

# Check that it slept for the primary rate limit duration
mock_sleep.assert_any_call(15.0)

@pytest.mark.asyncio
@patch('asyncio.sleep', new_callable=AsyncMock)
@pytest.mark.asyncio
async def test_acquire_uses_adaptive_delay(mock_sleep):
"""Test that acquire() uses the calculated adaptive delay."""
"""Test that acquire() uses the calculated adaptive delay on the second call."""
rl = RateLimiter(default_delay=1.0)

# Mock time.time() for delay calculation
with patch('time.time', side_effect=[1000.0, 1000.1]):
# Mock time.time() to simulate time passing
with patch('time.time', side_effect=[1000.0, 1000.1, 1000.2, 1000.3]) as mock_time:
# Ensure rate limit is not exhausted
rl.rate_limit_info.remaining = 2000
rl.rate_limit_info.remaining = 2000

# FIRST call: This sets _last_request, but calculates a delay of 0.
await rl.acquire()
mock_sleep.assert_not_called() # No sleep on the first call
assert rl._last_request == 1000.1

# SECOND call: This call is close to the first one, triggering the delay.
await rl.acquire()

# Check that sleep was called. The exact value has jitter, so we check if it was called.
# mock_sleep.assert_called()
# The first call to time.time() is at the start of acquire(),
# the second is for _last_request. The delay calculation uses the first one.
# Expected delay is around 1.0 seconds.
# assert mock_sleep.call_args[0][0] > 0.5
# Assert that sleep was finally called on the second run
mock_sleep.assert_called()
# The delay should be > 0 because elapsed time (0.1s) < default_delay (1.0s)
assert mock_sleep.call_args[0][0] > 0

@pytest.mark.asyncio
async def test_acquire_updates_last_request_time():
"""Test that acquire() correctly updates the _last_request timestamp."""
rl = RateLimiter()

with patch('time.time', return_value=12345.0):
# Patch sleep to make the test run instantly
with patch('time.time', return_value=12345.0) as mock_time:
with patch('asyncio.sleep'):
await rl.acquire()
assert rl._last_request == 12345.0
Expand All @@ -167,11 +162,9 @@ async def test_update_rate_limit_info_is_task_safe():
num_tasks = 50

async def worker(headers):
# Add a small, random delay to increase the chance of race conditions if unlocked
await asyncio.sleep(0.01 * random.random())
await rl.update_rate_limit_info(headers)

# Create many concurrent tasks
all_headers = []
for i in range(num_tasks):
headers = {
Expand All @@ -183,12 +176,9 @@ async def worker(headers):
tasks = [asyncio.create_task(worker(h)) for h in all_headers]
await asyncio.gather(*tasks)

# The final state should be internally consistent, belonging to one of the updates.
# If limit is 5000+i, remaining must be 4000+i.
final_limit = rl.rate_limit_info.limit
final_remaining = rl.rate_limit_info.remaining

# Calculate what 'i' must have been based on the final limit
i = final_limit - 5000
expected_remaining = 4000 + i
assert final_remaining == expected_remaining, "Inconsistent state suggests a race condition"