Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions streamrip/client/downloadable.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,36 @@ async def fast_async_download(path, url, headers, callback):
Using aiofiles/aiohttp resulted in a yield to the event loop for every 1KB,
which made file downloads CPU-bound. This resulted in a ~10MB max total download
speed. This fixes the issue by only yielding to the event loop for every 1MB read.

Supports resuming interrupted downloads via the Range header.
"""
chunk_size: int = 2**17 # 131 KB
counter = 0
yield_every = 8 # 1 MB
with open(path, "wb") as file: # noqa: ASYNC101
with requests.get( # noqa: ASYNC100
url,
headers=headers,
allow_redirects=True,
stream=True,
) as resp:

resume_pos = os.path.getsize(path) if os.path.exists(path) else 0
req_headers = dict(headers)
if resume_pos > 0:
req_headers["Range"] = f"bytes={resume_pos}-"

with requests.get( # noqa: ASYNC100
url,
headers=req_headers,
allow_redirects=True,
stream=True,
) as resp:
# 416 = Range Not Satisfiable → file already complete
if resume_pos > 0 and resp.status_code == 416:
return
resp.raise_for_status()
# Only append if the server honored our Range request (206). If it
# returned 200, it's sending the whole file from the start, so we must
# overwrite — otherwise the existing partial bytes would be duplicated.
if resume_pos > 0 and resp.status_code == 206:
open_mode = "ab"
else:
open_mode = "wb"
with open(path, open_mode) as file: # noqa: ASYNC101
for chunk in resp.iter_content(chunk_size=chunk_size):
file.write(chunk)
callback(len(chunk))
Expand Down
59 changes: 28 additions & 31 deletions streamrip/media/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,37 +40,34 @@ async def preprocess(self):
async def download(self):
# TODO: progress bar description
async with global_download_semaphore(self.config.session.downloads):
with get_progress_callback(
self.config.session.cli.progress_bars,
await self.downloadable.size(),
f"Track {self.meta.tracknumber}",
) as callback:
try:
await self.downloadable.download(self.download_path, callback)
retry = False
except Exception as e:
logger.error(
f"Error downloading track '{self.meta.title}', retrying: {e}"
)
retry = True

if not retry:
return

with get_progress_callback(
self.config.session.cli.progress_bars,
await self.downloadable.size(),
f"Track {self.meta.tracknumber} (retry)",
) as callback:
try:
await self.downloadable.download(self.download_path, callback)
except Exception as e:
logger.error(
f"Persistent error downloading track '{self.meta.title}', skipping: {e}"
)
self.db.set_failed(
self.downloadable.source, "track", self.meta.info.id
)
max_attempts = 4
for attempt in range(1, max_attempts + 1):
label = (
f"Track {self.meta.tracknumber}"
if attempt == 1
else f"Track {self.meta.tracknumber} (retry {attempt - 1})"
)
with get_progress_callback(
self.config.session.cli.progress_bars,
await self.downloadable.size(),
label,
) as callback:
try:
await self.downloadable.download(self.download_path, callback)
break
except Exception as e:
if attempt < max_attempts:
logger.error(
f"Error downloading track '{self.meta.title}', retrying ({attempt}/{max_attempts - 1}): {e}"
)
await asyncio.sleep(2**attempt) # 2s, 4s, 8s
else:
logger.error(
f"Persistent error downloading track '{self.meta.title}', skipping: {e}"
)
self.db.set_failed(
self.downloadable.source, "track", self.meta.info.id
)

async def postprocess(self):
if self.is_single:
Expand Down
129 changes: 129 additions & 0 deletions tests/test_resumable_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Tests for resumable downloads in fast_async_download.

These spin up a local HTTP server that interrupts the first response
mid-stream (reproducing the IncompleteRead errors seen on large tracks)
and verify that a subsequent attempt resumes and produces a byte-for-byte
correct file.
"""
import hashlib
import http.server
import os
import threading

import pytest

from streamrip.client.downloadable import fast_async_download


def _payload(n: int) -> bytes:
return bytes((i * 31 + 7) % 256 for i in range(n))


def _start_server(handler_cls):
srv = http.server.HTTPServer(("127.0.0.1", 0), handler_cls)
port = srv.server_address[1]
threading.Thread(target=srv.serve_forever, daemon=True).start()
return srv, port


@pytest.mark.asyncio
async def test_resume_with_range_support(tmp_path):
    """Server drops the connection mid-stream, then honors Range with 206.

    The first request (no ``Range`` header) advertises the full
    ``Content-Length`` but only sends the first half before closing the
    socket. Any request carrying a ``Range`` header gets a 206 with the
    remainder. The final file must match the payload byte for byte.
    """
    total = 1_000_000
    payload = _payload(total)
    full_sha = hashlib.sha256(payload).hexdigest()
    request_ranges = []

    class Handler(http.server.BaseHTTPRequestHandler):
        def log_message(self, *a):
            # Silence per-request logging in test output.
            pass

        def do_GET(self):
            rng = self.headers.get("Range")
            request_ranges.append(rng)
            if rng is None:
                # Promise `total` bytes, deliver half, then hang up.
                self.send_response(200)
                self.send_header("Content-Length", str(total))
                self.end_headers()
                self.wfile.write(payload[: total // 2])
                self.wfile.flush()
                self.close_connection = True
                self.connection.close()
            else:
                start = int(rng.replace("bytes=", "").split("-")[0])
                self.send_response(206)
                self.send_header("Content-Range", f"bytes {start}-{total-1}/{total}")
                self.send_header("Content-Length", str(total - start))
                self.end_headers()
                self.wfile.write(payload[start:])

    srv, port = _start_server(Handler)
    url = f"http://127.0.0.1:{port}/track.flac"
    path = str(tmp_path / "track.flac")

    last_exc = None
    try:
        for _ in range(4):
            try:
                await fast_async_download(path, url, {}, lambda n: None)
                last_exc = None
                break
            except Exception as e:  # noqa: BLE001
                last_exc = e
    finally:
        # Always stop the serve loop and release the listening socket, even
        # if the download loop is cancelled or interrupted.
        srv.shutdown()
        srv.server_close()

    assert last_exc is None
    assert os.path.getsize(path) == total
    with open(path, "rb") as f:
        assert hashlib.sha256(f.read()).hexdigest() == full_sha
    # The retry sent a Range header (resume actually happened).
    assert any(r and r.startswith("bytes=") for r in request_ranges)


@pytest.mark.asyncio
async def test_resume_falls_back_when_range_ignored(tmp_path):
    """If the server ignores Range and replies 200 with the full body,
    the partial file must be overwritten, not appended to.

    The first request is cut off after half the payload. Every later request
    gets a plain 200 with the whole payload, regardless of any ``Range``
    header that was sent.
    """
    total = 600_000
    payload = _payload(total)
    full_sha = hashlib.sha256(payload).hexdigest()
    calls = []

    class Handler(http.server.BaseHTTPRequestHandler):
        def log_message(self, *a):
            # Silence per-request logging in test output.
            pass

        def do_GET(self):
            calls.append(self.headers.get("Range"))
            if len(calls) == 1:
                # Promise `total` bytes, deliver half, then hang up.
                self.send_response(200)
                self.send_header("Content-Length", str(total))
                self.end_headers()
                self.wfile.write(payload[: total // 2])
                self.wfile.flush()
                self.close_connection = True
                self.connection.close()
            else:
                # Range ignored: send the whole file from the start.
                self.send_response(200)
                self.send_header("Content-Length", str(total))
                self.end_headers()
                self.wfile.write(payload)

    srv, port = _start_server(Handler)
    url = f"http://127.0.0.1:{port}/x.flac"
    path = str(tmp_path / "x.flac")

    last_exc = None
    try:
        for _ in range(4):
            try:
                await fast_async_download(path, url, {}, lambda n: None)
                last_exc = None
                break
            except Exception as e:  # noqa: BLE001
                # Keep the last failure so a persistent error is reported
                # directly instead of surfacing only as a size mismatch.
                last_exc = e
    finally:
        # Always stop the serve loop and release the listening socket, even
        # if the download loop is cancelled or interrupted.
        srv.shutdown()
        srv.server_close()

    assert last_exc is None
    # No duplicated bytes despite the partial file already existing.
    assert os.path.getsize(path) == total
    with open(path, "rb") as f:
        assert hashlib.sha256(f.read()).hexdigest() == full_sha