# src/api/transcriptions.py -- model-management API endpoints


async def _run_blocking(func, *args):
    """Run a synchronous service call in the default thread pool and await it.

    The whisper-service helpers used below are plain ``def`` functions.
    Calling them directly inside an ``async def`` handler would execute them
    on the event loop thread, so every other request would wait until the
    call (which may load or re-download a model) has finished.
    """
    import asyncio  # local import: the file-level import block is not fully visible here

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, *args)


@router.get("/available-models")
async def list_available_models(api_key: str = Depends(verify_api_key)):
    """List all available Whisper models with download status"""
    return {"models": get_available_models()}


@router.post("/switch-model")
async def switch_model_endpoint(
    model: str = Form(...),
    api_key: str = Depends(verify_api_key),
):
    """Switch to a different Whisper model.

    Returns whatever ``switch_model`` returns. A ``ValueError`` raised by the
    service is reported as HTTP 400, any other exception as HTTP 500.
    """
    try:
        return await _run_blocking(switch_model, model)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/reload-model")
async def reload_model_endpoint(api_key: str = Depends(verify_api_key)):
    """Reload current model (re-download).

    The response is only sent once the reload has finished; the work itself
    runs in a worker thread so the event loop stays responsive meanwhile.
    Any failure is reported as HTTP 500.
    """
    try:
        return await _run_blocking(reload_model)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.delete("/delete-model/{model_name}")
async def delete_model_endpoint(
    model_name: str,
    api_key: str = Depends(verify_api_key),
):
    """Delete a downloaded model.

    Returns whatever ``delete_model`` returns. A ``ValueError`` raised by the
    service is reported as HTTP 400, any other exception as HTTP 500.
    """
    try:
        return await _run_blocking(delete_model, model_name)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# src/services/whisper_service.py -- model catalogue and model-management helpers

# Available Whisper models with their sizes (approximate)
# NOTE(review): the figures for tiny..large-v1 look like Whisper's published
# parameter counts (39M, 74M, 244M, 769M, 1550M) rather than checkpoint file
# sizes -- confirm before relying on them for download-size estimates.
AVAILABLE_MODELS: Dict[str, Dict[str, Any]] = {
    "tiny": {"size": "39 MB", "description": "Fastest, lowest accuracy", "english_only": False},
    "base": {"size": "74 MB", "description": "Fast, good for testing", "english_only": False},
    "small": {"size": "244 MB", "description": "Balanced speed/accuracy", "english_only": False},
    "medium": {"size": "769 MB", "description": "Good accuracy", "english_only": False},
    "large-v1": {"size": "1.55 GB", "description": "High accuracy (legacy)", "english_only": False},
    "large-v2": {"size": "2.87 GB", "description": "Higher accuracy", "english_only": False},
    "large-v3": {"size": "2.88 GB", "description": "Best accuracy", "english_only": False},
    "large": {"size": "2.88 GB", "description": "Alias for large-v3", "english_only": False},
}


def _model_file_path(model_name: str) -> str:
    """Return the path ``<settings.models_path>/<model_name>.pt``."""
    return os.path.join(settings.models_path, f"{model_name}.pt")


def _unload_model() -> None:
    """Drop the cached model object and release cached CUDA memory.

    The caller must hold ``_model_lock``.
    """
    global _model
    _model = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_available_models() -> List[Dict[str, Any]]:
    """Get list of available models with their download status.

    Returns:
        One dict per entry of ``AVAILABLE_MODELS`` with the keys ``name``,
        ``size``, ``description``, ``english_only``, ``is_downloaded`` (a
        ``<name>.pt`` file exists in ``settings.models_path``) and
        ``is_active`` (it is the currently loaded model).
    """
    models_dir = settings.models_path
    downloaded = set()
    if os.path.isdir(models_dir):
        # Strip only the trailing extension ("large-v3.pt" -> "large-v3").
        downloaded = {
            os.path.splitext(entry)[0]
            for entry in os.listdir(models_dir)
            if entry.endswith(".pt")
        }

    return [
        {
            "name": name,
            "size": info["size"],
            "description": info["description"],
            "english_only": info["english_only"],
            "is_downloaded": name in downloaded,
            "is_active": name == _current_model_name and _model is not None,
        }
        for name, info in AVAILABLE_MODELS.items()
    ]


def switch_model(model_name: str) -> Dict[str, str]:
    """Switch to a different model.

    Args:
        model_name: Key of ``AVAILABLE_MODELS``.

    Returns:
        A dict with ``status`` (``"already_active"`` or ``"success"``) and a
        human-readable ``message``.

    Raises:
        ValueError: If ``model_name`` is not in ``AVAILABLE_MODELS``.
    """
    if model_name not in AVAILABLE_MODELS:
        raise ValueError(f"Unknown model: {model_name}")

    if _model is not None and model_name == _current_model_name:
        return {"status": "already_active", "message": f"Model {model_name} is already active"}

    # load_model unloads whatever is currently cached before loading the new one.
    load_model(model_name)
    return {"status": "success", "message": f"Switched to model {model_name}"}


def delete_model(model_name: str) -> Dict[str, str]:
    """Delete a downloaded model.

    If the model is the one currently loaded it is unloaded first and the
    current model name is cleared.

    Args:
        model_name: Key of ``AVAILABLE_MODELS``.

    Returns:
        A dict with ``status`` (``"not_found"`` or ``"success"``) and a
        human-readable ``message``.

    Raises:
        ValueError: If ``model_name`` is not in ``AVAILABLE_MODELS``.
    """
    global _current_model_name

    if model_name not in AVAILABLE_MODELS:
        raise ValueError(f"Unknown model: {model_name}")

    not_found = {"status": "not_found", "message": f"Model {model_name} is not downloaded"}
    model_file = _model_file_path(model_name)
    if not os.path.exists(model_file):
        return not_found

    # The lock is only taken when there is a loaded model to drop. While
    # load_model is busy loading, ``_model`` is None, so this never waits
    # behind a long-running load.
    if _model is not None and model_name == _current_model_name:
        with _model_lock:
            # State may have changed while waiting for the lock.
            if _model is not None and model_name == _current_model_name:
                _unload_model()
                _current_model_name = None

    try:
        os.remove(model_file)
    except FileNotFoundError:
        # Removed by someone else between the existence check and here.
        return not_found
    return {"status": "success", "message": f"Model {model_name} deleted"}


def reload_model() -> Dict[str, str]:
    """Reload current model (useful for re-downloading).

    Unloads the cached model, deletes its ``.pt`` file if present and calls
    ``load_model`` again. Falls back to ``settings.whisper_model`` when no
    current model name is set.

    Returns:
        A dict with ``status`` ``"success"`` and a human-readable ``message``.
    """
    with _model_lock:
        model_name = _current_model_name if _current_model_name else settings.whisper_model

        if _model is not None:
            _unload_model()

        try:
            os.remove(_model_file_path(model_name))
        except FileNotFoundError:
            pass  # nothing on disk to delete

    # load_model acquires _model_lock itself and the lock is a plain
    # threading.Lock (not re-entrant), so it must be released before the call.
    load_model(model_name)
    return {"status": "success", "message": f"Model {model_name} reloaded"}
+{% block title %}Models - Whisper API Admin{% endblock %} + +{% block content %} +
+
+

🤖 Model Management

+ +
+ + +
+

Current Model Status

+
+

Active Model: {{ current_status.name }}

+

Device: {{ current_status.device }}

+

Status: + {% if current_status.loaded %} + ✅ Loaded & Ready + {% elif current_status.is_downloading %} + ⏳ Downloading ({{ current_status.download_percentage }}%) + {% else %} + ⏸️ Not Loaded + {% endif %} +

+ + {% if current_status.is_downloading %} +
+
+
+ {% endif %} +
+
+ + +
+

Available Models

+ + + + + + + + + + + + {% for model in models %} + + + + + + + + {% endfor %} + +
ModelSizeDescriptionStatusActions
{{ model.name }}{{ model.size }}{{ model.description }} + {% if model.is_active %} + Active + {% elif model.is_downloaded %} + Downloaded + {% else %} + Not Downloaded + {% endif %} + + {% if not model.is_active %} +
+ + +
+ {% endif %} + + {% if model.is_downloaded %} +
+ + +
+ {% else %} +
+ + +
+ {% endif %} +
+
+ + +
+

Reload Current Model

+

If you experience issues with the current model, you can reload it. This will delete and re-download the model files.

+
+ +
+
+
# src/web/routes.py -- admin pages for Whisper model management


def _admin_login_redirect(request: Request) -> Optional[RedirectResponse]:
    """Return a redirect to the login page if ``check_admin_auth`` rejects the request.

    Returns ``None`` when ``check_admin_auth`` does not raise ``HTTPException``.
    """
    try:
        check_admin_auth(request)
    except HTTPException:
        return RedirectResponse(url="/admin/login", status_code=302)
    return None


def _models_redirect(**params: str) -> RedirectResponse:
    """Redirect to the model page with ``params`` as a URL-encoded query string.

    Messages and exception texts contain spaces and may contain ``&``, ``#``
    or ``=``; interpolating them raw would corrupt the query string.
    """
    from urllib.parse import urlencode  # local import, matching the handlers' local-import style

    return RedirectResponse(url=f"/admin/models?{urlencode(params)}", status_code=302)


async def _run_blocking(func, *args):
    """Run a synchronous service call in the default thread pool and await it.

    Keeps model loading / re-downloading off the event loop thread.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, func, *args)


@router.get("/models", response_class=HTMLResponse)
async def manage_models(request: Request, message: Optional[str] = None, error: Optional[str] = None):
    """Model management page.

    ``message`` and ``error`` are optional query parameters that are passed
    through to the ``models.html`` template.
    """
    redirect = _admin_login_redirect(request)
    if redirect is not None:
        return redirect

    return templates.TemplateResponse("models.html", {
        "request": request,
        "models": get_available_models(),
        "current_status": get_model_status(),
        "message": message,
        "error": error,
    })


@router.post("/models/switch")
async def switch_model_admin(
    request: Request,
    model_name: str = Form(...)
):
    """Switch to a different model, then redirect back to the model page."""
    redirect = _admin_login_redirect(request)
    if redirect is not None:
        return redirect

    try:
        from src.services.whisper_service import switch_model
        result = await _run_blocking(switch_model, model_name)
    except Exception as e:
        return _models_redirect(error=str(e))
    return _models_redirect(message=result["message"])


@router.post("/models/delete")
async def delete_model_admin(
    request: Request,
    model_name: str = Form(...)
):
    """Delete a model, then redirect back to the model page."""
    redirect = _admin_login_redirect(request)
    if redirect is not None:
        return redirect

    try:
        from src.services.whisper_service import delete_model
        result = await _run_blocking(delete_model, model_name)
    except Exception as e:
        return _models_redirect(error=str(e))
    return _models_redirect(message=result["message"])


@router.post("/models/download")
async def download_model_admin(
    request: Request,
    model_name: str = Form(...)
):
    """Start downloading a model in a background thread.

    The download is started by calling ``load_model(model_name)``, so the
    requested model also replaces the one currently loaded.
    """
    redirect = _admin_login_redirect(request)
    if redirect is not None:
        return redirect

    try:
        from src.services.whisper_service import AVAILABLE_MODELS, load_model
        import threading

        # Reject names outside the catalogue up front; otherwise the failure
        # would only surface inside the background thread as a log line while
        # the user is told the download has started.
        if model_name not in AVAILABLE_MODELS:
            raise ValueError(f"Unknown model: {model_name}")

        def download():
            try:
                load_model(model_name)
            except Exception as e:
                print(f"Error downloading model {model_name}: {e}")

        # Daemon thread: an unfinished download does not keep the process alive.
        threading.Thread(target=download, daemon=True).start()
    except Exception as e:
        return _models_redirect(error=str(e))
    return _models_redirect(message=f"Started downloading model {model_name}")


@router.post("/models/reload")
async def reload_model_admin(request: Request):
    """Reload current model, then redirect back to the model page."""
    redirect = _admin_login_redirect(request)
    if redirect is not None:
        return redirect

    try:
        from src.services.whisper_service import reload_model
        result = await _run_blocking(reload_model)
    except Exception as e:
        return _models_redirect(error=str(e))
    return _models_redirect(message=result["message"])