diff --git a/src/api/transcriptions.py b/src/api/transcriptions.py index 15ea115..5db8e4b 100644 --- a/src/api/transcriptions.py +++ b/src/api/transcriptions.py @@ -6,7 +6,7 @@ import os import hashlib from src.config import settings -from src.services.whisper_service import transcribe_audio +from src.services.whisper_service import transcribe_audio, get_model_status from src.services.stats_service import log_usage from src.database.db import get_db from sqlalchemy.orm import Session @@ -18,6 +18,21 @@ from src.services.stats_service import hash_api_key from src.database.db import SessionLocal from src.database.models import ApiKey + +def check_model_loaded(): + """Check if model is loaded, raise exception if not""" + status = get_model_status() + if status.get("is_downloading") and not status.get("loaded"): + raise HTTPException( + status_code=503, + detail={ + "error": "Model not ready", + "message": f"Whisper model is still downloading: {status.get('download_percentage', 0)}% complete", + "status": status + } + ) + return status + def verify_api_key(authorization: Optional[str] = Header(None)): """Verify API key from Authorization header""" if not authorization: @@ -54,6 +69,9 @@ def verify_api_key(authorization: Optional[str] = Header(None)): @router.get("/models") async def list_models(api_key: str = Depends(verify_api_key)): """List available models (OpenAI compatible)""" + # Check model status and include it in response + model_status = get_model_status() + return { "data": [ { @@ -68,10 +86,17 @@ async def list_models(api_key: str = Depends(verify_api_key)): "created": 1698796800, "owned_by": "openai" } - ] + ], + "model_status": model_status } +@router.get("/model-status") +async def model_status_endpoint(api_key: str = Depends(verify_api_key)): + """Get current model download/load status""" + return get_model_status() + + @router.post("/audio/transcriptions") async def create_transcription( file: UploadFile = File(...), @@ -94,6 +119,9 @@ async def 
create_transcription( - **timestamp_granularities**: word, segment (for verbose_json) """ + # Check if model is loaded first + check_model_loaded() + start_time = time.time() temp_path = None diff --git a/src/services/whisper_service.py b/src/services/whisper_service.py index c84a8c1..36339d5 100644 --- a/src/services/whisper_service.py +++ b/src/services/whisper_service.py @@ -4,38 +4,112 @@ import os from typing import Optional, Dict, Any import asyncio from concurrent.futures import ThreadPoolExecutor +import threading +import time from src.config import settings # Global model cache _model = None _executor = ThreadPoolExecutor(max_workers=1) +_model_lock = threading.Lock() + +# Model download status +_model_status = { + "is_downloading": False, + "download_progress": 0, + "download_total": 0, + "download_percentage": 0, + "status_message": "Not started", + "model_name": settings.whisper_model, + "is_loaded": False +} + + +def _download_hook(progress_bytes, total_bytes): + """Hook to track download progress""" + global _model_status + _model_status["is_downloading"] = True + _model_status["download_progress"] = progress_bytes + _model_status["download_total"] = total_bytes + if total_bytes > 0: + _model_status["download_percentage"] = round((progress_bytes / total_bytes) * 100, 2) + _model_status["status_message"] = f"Downloading: {_model_status['download_percentage']}%" def load_model(): """Load Whisper model""" - global _model - if _model is None: - print(f"Loading Whisper model: {settings.whisper_model}") - _model = whisper.load_model( - settings.whisper_model, - device=settings.whisper_device, - download_root=settings.models_path - ) - print(f"Model loaded on {settings.whisper_device}") + global _model, _model_status + + with _model_lock: + if _model is None: + _model_status["is_downloading"] = True + _model_status["status_message"] = "Starting download..." 
+ + print(f"Loading Whisper model: {settings.whisper_model}") + try: + # Whisper doesn't have a direct progress callback, but we can monitor the models directory + _model = whisper.load_model( + settings.whisper_model, + device=settings.whisper_device, + download_root=settings.models_path + ) + _model_status["is_downloading"] = False + _model_status["is_loaded"] = True + _model_status["download_percentage"] = 100 + _model_status["status_message"] = "Model loaded successfully" + print(f"Model loaded on {settings.whisper_device}") + except Exception as e: + _model_status["is_downloading"] = False + _model_status["status_message"] = f"Error: {str(e)}" + raise + return _model def get_model_info(): """Get model information""" - model = load_model() + global _model_status + + # Check if model files exist in the models directory + model_files = [] + models_dir = settings.models_path + if os.path.exists(models_dir): + model_files = [f for f in os.listdir(models_dir) if f.endswith('.pt') or f.endswith('.bin')] + + # Calculate approximate download progress based on file size if downloading + if _model_status["is_downloading"] and not _model_status["is_loaded"]: + # Try to estimate progress from existing files + total_size = 0 + for f in model_files: + try: + total_size += os.path.getsize(os.path.join(models_dir, f)) + except: + pass + + # large-v3 is approximately 2.9GB + expected_size = 2.9 * 1024 * 1024 * 1024 # 2.9 GB in bytes + if total_size > 0: + estimated_percentage = min(99, round((total_size / expected_size) * 100, 2)) + _model_status["download_percentage"] = estimated_percentage + _model_status["status_message"] = f"Downloading: {estimated_percentage}%" + return { "name": settings.whisper_model, "device": settings.whisper_device, - "loaded": _model is not None + "loaded": _model is not None, + "is_downloading": _model_status["is_downloading"], + "download_percentage": _model_status["download_percentage"], + "status_message": _model_status["status_message"], + 
"model_files": model_files } +def get_model_status(): + """Get current model download/load status""" + return get_model_info() + + def _transcribe_sync( audio_path: str, language: Optional[str] = None, diff --git a/src/templates/dashboard.html b/src/templates/dashboard.html index e32f320..9ce8e05 100644 --- a/src/templates/dashboard.html +++ b/src/templates/dashboard.html @@ -32,6 +32,39 @@ + +
Status: ✅ Loaded & Ready
+Model: {{ model_status.name }}
+Device: {{ model_status.device }}
+ {% elif model_status.is_downloading %} +Status: ⏳ Downloading...
+{{ model_status.status_message }}
+ ++ {{ model_status.download_percentage }}% +
+ {% else %} +Status: ⏸️ Not Started
+The model will be loaded on first transcription request.
+ {% endif %} + + {% if model_status.model_files %} +Model Files:
+