138 changes: 130 additions & 8 deletions backend/app/services/music_service.py
@@ -83,6 +83,16 @@
SETTINGS_FILE = os.path.join(os.environ.get("HEARTMULA_DB_PATH", _default_db_dir).replace("jobs.db", ""), "settings.json")


def is_mps_available() -> bool:
"""
Check if Apple Metal Performance Shaders (MPS) is available.

Returns:
bool: True if MPS is available, False otherwise
"""
return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
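
As a side note (not part of this diff), the new helper can sit alongside the existing CUDA check to pick a device; the fallback order below is an assumption chosen for this sketch, and pick_device is a hypothetical name:

import torch

def pick_device() -> torch.device:
    # Prefer CUDA, then Apple Metal (MPS), then CPU (assumed ordering for this example).
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")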


def detect_optimal_gpu_config() -> dict:
"""
Auto-detect the optimal GPU configuration based on available VRAM.
@@ -94,6 +104,7 @@ def detect_optimal_gpu_config() -> dict:
- gpu_info: dict - info about each GPU (name, vram, compute capability)
- config_name: str - human-readable name of the selected configuration
- warning: str or None - any warnings about the configuration
- device_type: str - type of device to use ("cuda", "mps", or "cpu")
"""
result = {
"use_quantization": True, # Default to quantization for safety
@@ -102,10 +113,38 @@ def detect_optimal_gpu_config() -> dict:
"gpu_info": {},
"config_name": "CPU Only",
"warning": None,
"device_type": "cpu",
}

if not torch.cuda.is_available():
result["warning"] = "No CUDA GPU detected. Running on CPU will be very slow."
# Check for CUDA GPUs first
if torch.cuda.is_available():
result["device_type"] = "cuda"
# Continue with existing CUDA logic below
# Check for Apple Metal (MPS) on macOS
elif is_mps_available():
result["device_type"] = "mps"
result["num_gpus"] = 1
result["use_quantization"] = False # MPS works better with full precision
result["use_sequential_offload"] = False # Unified memory architecture
result["config_name"] = "Apple Metal (MPS)"
result["gpu_info"] = {
0: {
"name": "Apple Metal Performance Shaders",
"vram_gb": "Unified Memory",
"compute_capability": "MPS",
"supports_flash_attention": False,
}
}
print(f"\n[Auto-Config] Using Apple Metal (MPS) device", flush=True)
print(f"[Auto-Config] MPS uses unified memory - no VRAM limits", flush=True)
return result
# No GPU available - fall back to CPU
else:
result["warning"] = "No CUDA GPU or Metal GPU detected. Running on CPU will be very slow."
return result

# Continue with CUDA-specific logic if CUDA is available
if result["device_type"] != "cuda":
return result

num_gpus = torch.cuda.device_count()
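
A caller of detect_optimal_gpu_config() can branch on the new device_type key; this sketch uses only keys documented above and is illustrative rather than code from the PR:

config = detect_optimal_gpu_config()
if config["device_type"] == "mps":
    # Unified memory path: quantization and sequential offload stay disabled (see the MPS branch above).
    device = torch.device("mps")
elif config["device_type"] == "cuda":
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    if config["warning"]:
        print(config["warning"])
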
@@ -433,6 +472,11 @@ def configure_flash_attention_for_gpu(device_id: int):
- NVIDIA SM 6.x and older: Disables Flash Attention, uses math backend
- AMD ROCm: Conservatively disables Flash Attention (compatibility varies)
"""
# Skip if MPS is being used
if is_mps_available() and not torch.cuda.is_available():
logger.info("[GPU Config] Apple Metal (MPS) device - skipping Flash Attention configuration")
return

if not torch.cuda.is_available():
logger.info("[GPU Config] CUDA not available - skipping Flash Attention configuration")
return
@@ -662,9 +706,12 @@ def _pad_audio_token(token):
pipeline.mula.reset_caches()
pipeline._mula.to("cpu")
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
print(f"[Sequential Offload] VRAM after offload: {torch.cuda.memory_allocated()/1024**3:.2f}GB", flush=True)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
print(f"[Sequential Offload] VRAM after offload: {torch.cuda.memory_allocated()/1024**3:.2f}GB", flush=True)
elif is_mps_available():
torch.mps.empty_cache()
else:
pipeline._unload()

Expand All @@ -683,7 +730,8 @@ def _pad_audio_token(token):
device_map=pipeline.codec_device,
dtype=torch.float32,
)
print(f"[Lazy Loading] HeartCodec loaded. VRAM: {torch.cuda.memory_allocated()/1024**3:.2f}GB", flush=True)
if torch.cuda.is_available():
print(f"[Lazy Loading] HeartCodec loaded. VRAM: {torch.cuda.memory_allocated()/1024**3:.2f}GB", flush=True)
else:
raise RuntimeError("Cannot load HeartCodec: codec_path not available")

@@ -696,8 +744,11 @@ def _pad_audio_token(token):
del pipeline._codec
pipeline._codec = None
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
elif is_mps_available():
torch.mps.empty_cache()

if pipeline._sequential_offload:
# Move HeartMuLa back to GPU for next generation
@@ -726,6 +777,9 @@ def cleanup_gpu_memory():
torch.cuda.empty_cache()
torch.cuda.synchronize()
logger.info("GPU memory cleaned up")
elif is_mps_available():
torch.mps.empty_cache()
logger.info("MPS memory cleaned up")


def get_gpu_memory(device_id):
@@ -1036,18 +1090,34 @@ def _unload_all_models(self):
with torch.cuda.device(i):
torch.cuda.empty_cache()
torch.cuda.synchronize()
elif is_mps_available():
# MPS memory cleanup
torch.mps.empty_cache()

logger.info("All models unloaded")

def get_gpu_info(self) -> dict:
"""Get GPU hardware information."""
result = {
"cuda_available": torch.cuda.is_available(),
"mps_available": is_mps_available(),
"num_gpus": 0,
"gpus": [],
"total_vram_gb": 0
}

# Check for MPS (Apple Metal) first
if result["mps_available"] and not result["cuda_available"]:
result["num_gpus"] = 1
result["gpus"].append({
"index": 0,
"name": "Apple Metal Performance Shaders",
"vram_gb": "Unified Memory",
"compute_capability": "MPS",
"supports_flash_attention": False
})
return result

if not torch.cuda.is_available():
return result

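For illustration only, on an Apple Silicon machine without CUDA the MPS branch above returns a dict of this shape (values copied from the code; total_vram_gb keeps its default of 0):

{
    "cuda_available": False,
    "mps_available": True,
    "num_gpus": 1,
    "gpus": [
        {
            "index": 0,
            "name": "Apple Metal Performance Shaders",
            "vram_gb": "Unified Memory",
            "compute_capability": "MPS",
            "supports_flash_attention": False,
        }
    ],
    "total_vram_gb": 0,
}
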
@@ -1157,8 +1227,58 @@ def _load_pipeline_multi_gpu(self, model_path: str, version: str):
# Store the detected config for reference
self.gpu_config = auto_config

device_type = auto_config.get("device_type", "cpu")
num_gpus = auto_config["num_gpus"]

# Handle Apple Metal (MPS) devices
if device_type == "mps":
logger.info("Using Apple Metal (MPS) for GPU acceleration")
self.gpu_mode = "single"
print("[Apple Metal] Using MPS device for inference", flush=True)
print("[Apple Metal] Note: MPS uses unified memory architecture", flush=True)

# Check if quantization is manually enabled
if use_quantization:
logger.warning("4-bit quantization is not supported on MPS. Using full precision instead.")
print("[Apple Metal] WARNING: 4-bit quantization not supported on MPS, using full precision", flush=True)

# MPS doesn't support bfloat16, use float32 instead
pipeline = HeartMuLaGenPipeline.from_pretrained(
model_path,
device={
"mula": torch.device("mps"),
"codec": torch.device("mps"),
},
dtype={
"mula": torch.float32,
"codec": torch.float32,
},
version=version,
)
return patch_pipeline_with_callback(pipeline, sequential_offload=False)

# Handle CPU-only mode (no CUDA or MPS available)
if device_type == "cpu":
logger.warning("No GPU detected - running on CPU will be very slow")
self.gpu_mode = "cpu"
print("[CPU Mode] No GPU detected, using CPU for inference", flush=True)
print("[CPU Mode] WARNING: This will be extremely slow. Consider using a system with GPU support.", flush=True)

pipeline = HeartMuLaGenPipeline.from_pretrained(
model_path,
device={
"mula": torch.device("cpu"),
"codec": torch.device("cpu"),
},
dtype={
"mula": torch.float32,
"codec": torch.float32,
},
version=version,
)
return patch_pipeline_with_callback(pipeline, sequential_offload=False)

# At this point, device_type must be "cuda"
if use_quantization:
print(f"[Quantization] 4-bit quantization ENABLED - model will use ~3GB instead of ~11GB", flush=True)
else:
@@ -1640,6 +1760,8 @@ async def update_compile_progress():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
elif is_mps_available():
torch.mps.empty_cache()
logger.info("GPU memory cleaned up after generation")
except Exception as cleanup_err:
logger.warning(f"Memory cleanup warning: {cleanup_err}")