Add model management: switch, download, delete models via admin panel and API

This commit is contained in:
Dominic Ballenthin
2026-01-29 01:45:10 +01:00
parent 0f336428a0
commit ee9465f661
6 changed files with 443 additions and 12 deletions

View File

@@ -6,7 +6,14 @@ import os
import hashlib import hashlib
from src.config import settings from src.config import settings
from src.services.whisper_service import transcribe_audio, get_model_status from src.services.whisper_service import (
transcribe_audio,
get_model_status,
get_available_models,
switch_model,
reload_model,
delete_model
)
from src.services.stats_service import log_usage from src.services.stats_service import log_usage
from src.database.db import get_db from src.database.db import get_db
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -97,6 +104,55 @@ async def model_status_endpoint(api_key: str = Depends(verify_api_key)):
return get_model_status() return get_model_status()
@router.get("/available-models")
async def list_available_models(api_key: str = Depends(verify_api_key)):
"""List all available Whisper models with download status"""
return {
"models": get_available_models()
}
@router.post("/switch-model")
async def switch_model_endpoint(
model: str = Form(...),
api_key: str = Depends(verify_api_key)
):
"""Switch to a different Whisper model"""
try:
result = switch_model(model)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/reload-model")
async def reload_model_endpoint(api_key: str = Depends(verify_api_key)):
"""Reload current model (re-download)"""
try:
# This will run in background to not block the API
result = reload_model()
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/delete-model/{model_name}")
async def delete_model_endpoint(
model_name: str,
api_key: str = Depends(verify_api_key)
):
"""Delete a downloaded model"""
try:
result = delete_model(model_name)
return result
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/audio/transcriptions") @router.post("/audio/transcriptions")
async def create_transcription( async def create_transcription(
file: UploadFile = File(...), file: UploadFile = File(...),

View File

@@ -1,16 +1,18 @@
import whisper import whisper
import torch import torch
import os import os
from typing import Optional, Dict, Any from typing import Optional, Dict, Any, List
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import threading import threading
import time import time
import shutil
from src.config import settings from src.config import settings
# Global model cache # Global model cache
_model = None _model = None
_current_model_name = settings.whisper_model
_executor = ThreadPoolExecutor(max_workers=1) _executor = ThreadPoolExecutor(max_workers=1)
_model_lock = threading.Lock() _model_lock = threading.Lock()
@@ -25,6 +27,18 @@ _model_status = {
"is_loaded": False "is_loaded": False
} }
# Available Whisper models with their sizes (approximate)
AVAILABLE_MODELS = {
"tiny": {"size": "39 MB", "description": "Fastest, lowest accuracy", "english_only": False},
"base": {"size": "74 MB", "description": "Fast, good for testing", "english_only": False},
"small": {"size": "244 MB", "description": "Balanced speed/accuracy", "english_only": False},
"medium": {"size": "769 MB", "description": "Good accuracy", "english_only": False},
"large-v1": {"size": "1.55 GB", "description": "High accuracy (legacy)", "english_only": False},
"large-v2": {"size": "2.87 GB", "description": "Higher accuracy", "english_only": False},
"large-v3": {"size": "2.88 GB", "description": "Best accuracy", "english_only": False},
"large": {"size": "2.88 GB", "description": "Alias for large-v3", "english_only": False},
}
def _download_hook(progress_bytes, total_bytes): def _download_hook(progress_bytes, total_bytes):
"""Hook to track download progress""" """Hook to track download progress"""
@@ -37,28 +51,39 @@ def _download_hook(progress_bytes, total_bytes):
_model_status["status_message"] = f"Downloading: {_model_status['download_percentage']}%" _model_status["status_message"] = f"Downloading: {_model_status['download_percentage']}%"
def load_model(): def load_model(model_name: str = None):
"""Load Whisper model""" """Load Whisper model"""
global _model, _model_status global _model, _model_status, _current_model_name
if model_name is None:
model_name = settings.whisper_model
with _model_lock: with _model_lock:
# If a different model is loaded, unload it first
if _model is not None and _current_model_name != model_name:
print(f"Switching from {_current_model_name} to {model_name}")
_model = None
torch.cuda.empty_cache() if torch.cuda.is_available() else None
if _model is None: if _model is None:
_model_status["is_downloading"] = True _model_status["is_downloading"] = True
_model_status["status_message"] = "Starting download..." _model_status["status_message"] = "Starting download..."
_model_status["model_name"] = model_name
print(f"Loading Whisper model: {settings.whisper_model}") print(f"Loading Whisper model: {model_name}")
try: try:
# Whisper doesn't have a direct progress callback, but we can monitor the models directory # Whisper doesn't have a direct progress callback, but we can monitor the models directory
_model = whisper.load_model( _model = whisper.load_model(
settings.whisper_model, model_name,
device=settings.whisper_device, device=settings.whisper_device,
download_root=settings.models_path download_root=settings.models_path
) )
_current_model_name = model_name
_model_status["is_downloading"] = False _model_status["is_downloading"] = False
_model_status["is_loaded"] = True _model_status["is_loaded"] = True
_model_status["download_percentage"] = 100 _model_status["download_percentage"] = 100
_model_status["status_message"] = "Model loaded successfully" _model_status["status_message"] = "Model loaded successfully"
print(f"Model loaded on {settings.whisper_device}") print(f"Model {model_name} loaded on {settings.whisper_device}")
except Exception as e: except Exception as e:
_model_status["is_downloading"] = False _model_status["is_downloading"] = False
_model_status["status_message"] = f"Error: {str(e)}" _model_status["status_message"] = f"Error: {str(e)}"
@@ -69,7 +94,7 @@ def load_model():
def get_model_info(): def get_model_info():
"""Get model information""" """Get model information"""
global _model_status global _model_status, _current_model_name
# Check if model files exist in the models directory # Check if model files exist in the models directory
model_files = [] model_files = []
@@ -87,15 +112,23 @@ def get_model_info():
except: except:
pass pass
# large-v3 is approximately 2.9GB # Get expected size for current model
expected_size = 2.9 * 1024 * 1024 * 1024 # 2.9 GB in bytes expected_size_gb = 2.9 # default to large-v3
if _current_model_name in AVAILABLE_MODELS:
size_str = AVAILABLE_MODELS[_current_model_name]["size"]
if "GB" in size_str:
expected_size_gb = float(size_str.replace(" GB", ""))
elif "MB" in size_str:
expected_size_gb = float(size_str.replace(" MB", "")) / 1024
expected_size = expected_size_gb * 1024 * 1024 * 1024
if total_size > 0: if total_size > 0:
estimated_percentage = min(99, round((total_size / expected_size) * 100, 2)) estimated_percentage = min(99, round((total_size / expected_size) * 100, 2))
_model_status["download_percentage"] = estimated_percentage _model_status["download_percentage"] = estimated_percentage
_model_status["status_message"] = f"Downloading: {estimated_percentage}%" _model_status["status_message"] = f"Downloading: {estimated_percentage}%"
return { return {
"name": settings.whisper_model, "name": _current_model_name,
"device": settings.whisper_device, "device": settings.whisper_device,
"loaded": _model is not None, "loaded": _model is not None,
"is_downloading": _model_status["is_downloading"], "is_downloading": _model_status["is_downloading"],
@@ -110,6 +143,94 @@ def get_model_status():
return get_model_info() return get_model_info()
def get_available_models():
    """Get list of available models with their download status"""
    models_dir = settings.models_path

    # Collect the names of every model whose weight file is already on disk
    # (whisper stores each model as "<name>.pt" under the download root).
    downloaded = set()
    if os.path.exists(models_dir):
        downloaded = {
            fname[:-3]  # strip the ".pt" suffix to recover the model name
            for fname in os.listdir(models_dir)
            if fname.endswith('.pt')
        }

    return [
        {
            "name": name,
            "size": info["size"],
            "description": info["description"],
            "english_only": info["english_only"],
            "is_downloaded": name in downloaded,
            "is_active": name == _current_model_name and _model is not None,
        }
        for name, info in AVAILABLE_MODELS.items()
    ]
def switch_model(model_name: str):
    """Activate a different Whisper model, loading (and if needed downloading) it.

    Raises:
        ValueError: If model_name is not in the known-model catalogue.
    """
    global _model, _current_model_name

    if model_name not in AVAILABLE_MODELS:
        raise ValueError(f"Unknown model: {model_name}")

    already_active = _model is not None and model_name == _current_model_name
    if already_active:
        return {"status": "already_active", "message": f"Model {model_name} is already active"}

    # load_model() takes care of unloading the previously active model.
    load_model(model_name)
    return {"status": "success", "message": f"Switched to model {model_name}"}
def delete_model(model_name: str):
    """Delete a downloaded model's weight file from the models directory.

    If the model being deleted is the one currently loaded, it is unloaded
    first (and its GPU memory released) before the file is removed.

    Args:
        model_name: Name of the model; must be a key of AVAILABLE_MODELS.

    Returns:
        dict with "status" ("success" / "not_found") and a human-readable
        "message".

    Raises:
        ValueError: If model_name is not a known model.
    """
    global _model, _current_model_name

    if model_name not in AVAILABLE_MODELS:
        raise ValueError(f"Unknown model: {model_name}")

    models_dir = settings.models_path
    model_file = os.path.join(models_dir, f"{model_name}.pt")

    if not os.path.exists(model_file):
        return {"status": "not_found", "message": f"Model {model_name} is not downloaded"}

    # Hold the model lock while unloading/deleting so we cannot race a
    # concurrent load_model() that is loading this same model. (The
    # original code mutated _model and _current_model_name with no lock.)
    with _model_lock:
        # If this model is currently loaded, unload it first
        if model_name == _current_model_name and _model is not None:
            _model = None
            # Plain statement instead of the side-effect-only conditional
            # expression used before.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            _current_model_name = None

        # Delete the file
        os.remove(model_file)

    return {"status": "success", "message": f"Model {model_name} deleted"}
def reload_model():
    """Reload current model (useful for re-downloading).

    Unloads the current model and deletes its weight file while holding
    _model_lock, then calls load_model() AFTER releasing the lock:
    load_model() acquires the same non-reentrant threading.Lock internally,
    so calling it while still holding the lock (as the original code did)
    deadlocks the whole service.

    Returns:
        dict with "status" and a human-readable "message".
    """
    global _model

    with _model_lock:
        model_name = _current_model_name if _current_model_name else settings.whisper_model

        # Unload current model
        if _model is not None:
            _model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Delete the weight file (if present) so load_model() re-downloads it
        model_file = os.path.join(settings.models_path, f"{model_name}.pt")
        if os.path.exists(model_file):
            os.remove(model_file)

    # Reload model (will trigger download). Must happen outside the lock —
    # load_model() takes _model_lock itself.
    load_model(model_name)

    return {"status": "success", "message": f"Model {model_name} reloaded"}
def _transcribe_sync( def _transcribe_sync(
audio_path: str, audio_path: str,
language: Optional[str] = None, language: Optional[str] = None,

View File

@@ -9,6 +9,7 @@
<div class="nav"> <div class="nav">
<a href="/admin" class="active">Dashboard</a> <a href="/admin" class="active">Dashboard</a>
<a href="/admin/keys">API Keys</a> <a href="/admin/keys">API Keys</a>
<a href="/admin/models">Models</a>
<a href="/admin/logout">Logout</a> <a href="/admin/logout">Logout</a>
</div> </div>
</div> </div>

View File

@@ -9,6 +9,7 @@
<div class="nav"> <div class="nav">
<a href="/admin">Dashboard</a> <a href="/admin">Dashboard</a>
<a href="/admin/keys" class="active">API Keys</a> <a href="/admin/keys" class="active">API Keys</a>
<a href="/admin/models">Models</a>
<a href="/admin/logout">Logout</a> <a href="/admin/logout">Logout</a>
</div> </div>
</div> </div>

124
src/templates/models.html Normal file
View File

@@ -0,0 +1,124 @@
{% extends "base.html" %}
{% block title %}Models - Whisper API Admin{% endblock %}
{% block content %}
<div class="container">
<div class="header">
<h1>🤖 Model Management</h1>
<div class="nav">
<a href="/admin">Dashboard</a>
<a href="/admin/keys">API Keys</a>
<a href="/admin/models" class="active">Models</a>
<a href="/admin/logout">Logout</a>
</div>
</div>
<!-- Current Model Status -->
<div class="card">
<h2>Current Model Status</h2>
<div id="current-model-info">
<p><strong>Active Model:</strong> {{ current_status.name }}</p>
<p><strong>Device:</strong> {{ current_status.device }}</p>
<p><strong>Status:</strong>
{% if current_status.loaded %}
<span style="color: #48bb78;">✅ Loaded & Ready</span>
{% elif current_status.is_downloading %}
<span style="color: #ed8936;">⏳ Downloading ({{ current_status.download_percentage }}%)</span>
{% else %}
<span style="color: #718096;">⏸️ Not Loaded</span>
{% endif %}
</p>
{% if current_status.is_downloading %}
<div style="background: #e2e8f0; border-radius: 10px; height: 20px; overflow: hidden; margin-top: 10px;">
<div id="download-progress" style="background: linear-gradient(90deg, #667eea, #764ba2); height: 100%; width: {{ current_status.download_percentage }}%; transition: width 0.5s ease;"></div>
</div>
{% endif %}
</div>
</div>
<!-- Available Models -->
<div class="card">
<h2>Available Models</h2>
<table>
<thead>
<tr>
<th>Model</th>
<th>Size</th>
<th>Description</th>
<th>Status</th>
<th>Actions</th>
</tr>
</thead>
<tbody>
{% for model in models %}
<tr>
<td><strong>{{ model.name }}</strong></td>
<td>{{ model.size }}</td>
<td>{{ model.description }}</td>
<td>
{% if model.is_active %}
<span class="badge badge-success">Active</span>
{% elif model.is_downloaded %}
<span class="badge" style="background: #bee3f8; color: #2c5282;">Downloaded</span>
{% else %}
<span class="badge badge-danger">Not Downloaded</span>
{% endif %}
</td>
<td>
{% if not model.is_active %}
<form method="POST" action="/admin/models/switch" style="display: inline;">
<input type="hidden" name="model_name" value="{{ model.name }}">
<button type="submit" class="btn" style="padding: 5px 10px; font-size: 12px;"
{% if not model.is_downloaded %}disabled{% endif %}>
Activate
</button>
</form>
{% endif %}
{% if model.is_downloaded %}
<form method="POST" action="/admin/models/delete" style="display: inline; margin-left: 5px;"
onsubmit="return confirm('Are you sure you want to delete this model?');">
<input type="hidden" name="model_name" value="{{ model.name }}">
<button type="submit" class="btn btn-danger" style="padding: 5px 10px; font-size: 12px;"
{% if model.is_active %}disabled{% endif %}>
Delete
</button>
</form>
{% else %}
<form method="POST" action="/admin/models/download" style="display: inline; margin-left: 5px;">
<input type="hidden" name="model_name" value="{{ model.name }}">
<button type="submit" class="btn btn-success" style="padding: 5px 10px; font-size: 12px;">
Download
</button>
</form>
{% endif %}
</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
<!-- Reload Current Model -->
<div class="card">
<h2>Reload Current Model</h2>
<p>If you experience issues with the current model, you can reload it. This will delete and re-download the model files.</p>
<form method="POST" action="/admin/models/reload" onsubmit="return confirm('This will delete and re-download the current model. Continue?');">
<button type="submit" class="btn" style="background: #ed8936;">
🔄 Reload Model
</button>
</form>
</div>
</div>
<script>
// Auto-refresh page every 5 seconds if downloading
{% if current_status.is_downloading %}
setTimeout(function() {
window.location.reload();
}, 5000);
{% endif %}
</script>
{% endblock %}

View File

@@ -11,7 +11,7 @@ from src.config import settings
from src.database.db import get_db from src.database.db import get_db
from src.database.models import ApiKey, UsageLog from src.database.models import ApiKey, UsageLog
from src.services.stats_service import get_usage_stats, hash_api_key from src.services.stats_service import get_usage_stats, hash_api_key
from src.services.whisper_service import get_model_status from src.services.whisper_service import get_model_status, get_available_models
router = APIRouter() router = APIRouter()
templates = Jinja2Templates(directory="src/templates") templates = Jinja2Templates(directory="src/templates")
@@ -188,3 +188,131 @@ async def delete_key(
db.commit() db.commit()
return RedirectResponse(url="/admin/keys", status_code=302) return RedirectResponse(url="/admin/keys", status_code=302)
@router.get("/models", response_class=HTMLResponse)
async def manage_models(request: Request, message: Optional[str] = None, error: Optional[str] = None):
"""Model management page"""
try:
check_admin_auth(request)
except HTTPException as e:
return RedirectResponse(url="/admin/login", status_code=302)
models = get_available_models()
current_status = get_model_status()
return templates.TemplateResponse("models.html", {
"request": request,
"models": models,
"current_status": current_status,
"message": message,
"error": error
})
@router.post("/models/switch")
async def switch_model_admin(
request: Request,
model_name: str = Form(...)
):
"""Switch to a different model"""
try:
check_admin_auth(request)
except HTTPException as e:
return RedirectResponse(url="/admin/login", status_code=302)
try:
from src.services.whisper_service import switch_model
result = switch_model(model_name)
return RedirectResponse(
url=f"/admin/models?message={result['message']}",
status_code=302
)
except Exception as e:
return RedirectResponse(
url=f"/admin/models?error={str(e)}",
status_code=302
)
@router.post("/models/delete")
async def delete_model_admin(
request: Request,
model_name: str = Form(...)
):
"""Delete a model"""
try:
check_admin_auth(request)
except HTTPException as e:
return RedirectResponse(url="/admin/login", status_code=302)
try:
from src.services.whisper_service import delete_model
result = delete_model(model_name)
return RedirectResponse(
url=f"/admin/models?message={result['message']}",
status_code=302
)
except Exception as e:
return RedirectResponse(
url=f"/admin/models?error={str(e)}",
status_code=302
)
@router.post("/models/download")
async def download_model_admin(
request: Request,
model_name: str = Form(...)
):
"""Download a model"""
try:
check_admin_auth(request)
except HTTPException as e:
return RedirectResponse(url="/admin/login", status_code=302)
try:
from src.services.whisper_service import load_model
# Start download in background (non-blocking for API)
import threading
def download():
try:
load_model(model_name)
except Exception as e:
print(f"Error downloading model {model_name}: {e}")
thread = threading.Thread(target=download)
thread.daemon = True
thread.start()
return RedirectResponse(
url=f"/admin/models?message=Started downloading model {model_name}",
status_code=302
)
except Exception as e:
return RedirectResponse(
url=f"/admin/models?error={str(e)}",
status_code=302
)
@router.post("/models/reload")
async def reload_model_admin(request: Request):
"""Reload current model"""
try:
check_admin_auth(request)
except HTTPException as e:
return RedirectResponse(url="/admin/login", status_code=302)
try:
from src.services.whisper_service import reload_model
result = reload_model()
return RedirectResponse(
url=f"/admin/models?message={result['message']}",
status_code=302
)
except Exception as e:
return RedirectResponse(
url=f"/admin/models?error={str(e)}",
status_code=302
)