Skip to content
This repository was archived by the owner on Jan 19, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 110 additions & 60 deletions lyrics_transcriber/transcribers/audioshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ class AudioShakeConfig:
"""Configuration for AudioShake transcription service."""

api_token: Optional[str] = None
base_url: str = "https://groovy.audioshake.ai"
base_url: str = "https://api.audioshake.ai"
output_prefix: Optional[str] = None
timeout_minutes: int = 10 # Added timeout configuration
timeout_minutes: int = 20 # Added timeout configuration


class AudioShakeAPI:
Expand All @@ -34,46 +34,52 @@ def _validate_config(self) -> None:
def _get_headers(self) -> Dict[str, str]:
"""Get headers for API requests."""
self._validate_config() # Validate before making any API calls
return {"Authorization": f"Bearer {self.config.api_token}", "Content-Type": "application/json"}
return {"x-api-key": self.config.api_token, "Content-Type": "application/json"}

def upload_file(self, filepath: str) -> str:
"""Upload audio file and return asset ID."""
"""Upload audio file and return file URL."""
self.logger.info(f"Uploading {filepath} to AudioShake")
self._validate_config() # Validate before making API call

url = f"{self.config.base_url}/upload/"
with open(filepath, "rb") as file:
files = {"file": (os.path.basename(filepath), file)}
response = requests.post(url, headers={"Authorization": self._get_headers()["Authorization"]}, files=files)
response = requests.post(url, headers={"x-api-key": self.config.api_token}, files=files)

self.logger.debug(f"Upload response: {response.status_code} - {response.text}")
response.raise_for_status()
return response.json()["id"]
return response.json()["link"]

def create_job(self, asset_id: str) -> str:
"""Create transcription job and return job ID."""
self.logger.info(f"Creating job for asset {asset_id}")
def create_task(self, file_url: str) -> str:
"""Create transcription task and return task ID."""
self.logger.info(f"Creating task for file {file_url}")

url = f"{self.config.base_url}/job/"
url = f"{self.config.base_url}/tasks"
data = {
"metadata": {"format": "json", "name": "alignment", "language": "en"},
"callbackUrl": "https://example.com/webhook/alignment",
"assetId": asset_id,
"url": file_url,
"targets": [
{
"model": "alignment",
"formats": ["json"],
"language": "en"
}
],
}
response = requests.post(url, headers=self._get_headers(), json=data)
response.raise_for_status()
return response.json()["job"]["id"]
return response.json()["id"]

def wait_for_job_result(self, job_id: str) -> Dict[str, Any]:
"""Poll for job completion and return results."""
self.logger.info(f"Getting job result for job {job_id}")
def wait_for_task_result(self, task_id: str) -> Dict[str, Any]:
"""Poll for task completion and return results."""
self.logger.info(f"Getting task result for task {task_id}")

url = f"{self.config.base_url}/job/{job_id}"
# Use the list endpoint which has fresh data, not the individual task endpoint which caches
url = f"{self.config.base_url}/tasks"
start_time = time.time()
last_status_log = start_time
timeout_seconds = self.config.timeout_minutes * 60

# Add initial retry logic for 404 errors (job ID not yet available)
# Add initial retry logic for when task is not found yet
initial_retry_count = 0
max_initial_retries = 5
initial_retry_delay = 2 # seconds
Expand All @@ -94,28 +100,57 @@ def wait_for_job_result(self, job_id: str) -> Dict[str, Any]:
try:
response = requests.get(url, headers=self._get_headers())
response.raise_for_status()
job_data = response.json()["job"]
tasks_list = response.json()

# Find our specific task in the list
task_data = None
for task in tasks_list:
if task.get("id") == task_id:
task_data = task
break

if not task_data:
# Task not found in list yet
if initial_retry_count < max_initial_retries:
initial_retry_count += 1
self.logger.info(f"Task not found in list yet (attempt {initial_retry_count}/{max_initial_retries}), retrying in {initial_retry_delay} seconds...")
time.sleep(initial_retry_delay)
continue
else:
raise TranscriptionError(f"Task {task_id} not found in task list after {max_initial_retries} retries")

# Log the full response for debugging
self.logger.debug(f"Task status response: {task_data}")

if job_data["status"] == "completed":
return job_data
elif job_data["status"] == "failed":
raise TranscriptionError(f"Job failed: {job_data.get('error', 'Unknown error')}")
# Check status of targets (not the task itself)
targets = task_data.get("targets", [])
if not targets:
raise TranscriptionError("No targets found in task response")

# Check if all targets are completed or if any failed
all_completed = True
for target in targets:
target_status = target.get("status")
target_model = target.get("model")
self.logger.debug(f"Target {target_model} status: {target_status}")

if target_status == "failed":
error_msg = target.get("error", "Unknown error")
raise TranscriptionError(f"Target {target_model} failed: {error_msg}")
elif target_status != "completed":
all_completed = False

if all_completed:
self.logger.info("All targets completed successfully")
return task_data

# Reset retry count on successful response
initial_retry_count = 0

except requests.exceptions.HTTPError as e:
if e.response.status_code == 404 and initial_retry_count < max_initial_retries:
# Job ID not yet available, retry with delay
initial_retry_count += 1
self.logger.info(f"Job ID not yet available (attempt {initial_retry_count}/{max_initial_retries}), retrying in {initial_retry_delay} seconds...")
time.sleep(initial_retry_delay)
continue
else:
# Re-raise the error if it's not a 404 or we've exceeded retries
raise
raise

time.sleep(5) # Wait before next poll
time.sleep(30) # Wait before next poll


class AudioShakeTranscriber(BaseTranscriber):
Expand All @@ -142,13 +177,13 @@ def _perform_transcription(self, audio_filepath: str) -> TranscriptionData:
self.logger.info(f"Starting transcription for {audio_filepath}")

try:
# Start job and get results
# Start task and get results
self.logger.debug("Calling start_transcription()")
job_id = self.start_transcription(audio_filepath)
self.logger.debug(f"Got job_id: {job_id}")
task_id = self.start_transcription(audio_filepath)
self.logger.debug(f"Got task_id: {task_id}")

self.logger.debug("Calling get_transcription_result()")
result = self.get_transcription_result(job_id)
result = self.get_transcription_result(task_id)
self.logger.debug("Got transcription result")

return result
Expand All @@ -157,46 +192,61 @@ def _perform_transcription(self, audio_filepath: str) -> TranscriptionData:
raise

def start_transcription(self, audio_filepath: str) -> str:
"""Starts the transcription job and returns the job ID."""
"""Starts the transcription task and returns the task ID."""
self.logger.debug(f"Entering start_transcription() for {audio_filepath}")

# Upload file and create job
asset_id = self.api.upload_file(audio_filepath)
self.logger.debug(f"File uploaded successfully. Asset ID: {asset_id}")
# Upload file and create task
file_url = self.api.upload_file(audio_filepath)
self.logger.debug(f"File uploaded successfully. File URL: {file_url}")

job_id = self.api.create_job(asset_id)
self.logger.debug(f"Job created successfully. Job ID: {job_id}")
task_id = self.api.create_task(file_url)
self.logger.debug(f"Task created successfully. Task ID: {task_id}")

return job_id
return task_id

def get_transcription_result(self, job_id: str) -> Dict[str, Any]:
"""Gets the raw results for a previously started job."""
self.logger.debug(f"Entering get_transcription_result() for job ID: {job_id}")
def get_transcription_result(self, task_id: str) -> Dict[str, Any]:
"""Gets the raw results for a previously started task."""
self.logger.debug(f"Entering get_transcription_result() for task ID: {task_id}")

# Wait for job completion
job_data = self.api.wait_for_job_result(job_id)
self.logger.debug("Job completed. Getting results...")
# Wait for task completion
task_data = self.api.wait_for_task_result(task_id)
self.logger.debug("Task completed. Getting results...")

output_asset = next((asset for asset in job_data.get("outputAssets", []) if asset["name"] == "alignment.json"), None)
if not output_asset:
raise TranscriptionError("Required output not found in job results")
# Find the alignment target output
alignment_target = None
for target in task_data.get("targets", []):
if target.get("model") == "alignment":
alignment_target = target
break

if not alignment_target:
raise TranscriptionError("Required output not found in task results")

# Get the output file URL
output = alignment_target.get("output", [])
if not output:
raise TranscriptionError("No output found in alignment target")

output_url = output[0].get("link")
if not output_url:
raise TranscriptionError("Output link not found in alignment target")

# Fetch transcription data
response = requests.get(output_asset["link"])
response = requests.get(output_url)
response.raise_for_status()

# Return combined raw data
raw_data = {"job_data": job_data, "transcription": response.json()}
raw_data = {"task_data": task_data, "transcription": response.json()}

self.logger.debug("Raw results retrieved successfully")
return raw_data

def _convert_result_format(self, raw_data: Dict[str, Any]) -> TranscriptionData:
"""Process raw Audioshake API response into standard format."""
self.logger.debug(f"Processing result for job {raw_data['job_data']['id']}")
self.logger.debug(f"Processing result for task {raw_data['task_data']['id']}")

transcription_data = raw_data["transcription"]
job_data = raw_data["job_data"]
task_data = raw_data["task_data"]

segments = []
all_words = [] # Collect all words across segments
Expand Down Expand Up @@ -230,8 +280,8 @@ def _convert_result_format(self, raw_data: Dict[str, Any]) -> TranscriptionData:
source=self.get_name(),
metadata={
"language": transcription_data.get("metadata", {}).get("language"),
"job_id": job_data["id"],
"duration": job_data.get("statusInfo", {}).get("duration"),
"task_id": task_data["id"],
"duration": task_data.get("duration"),
},
)

Expand Down
Loading
Loading