Add model management: switch, download, delete models via admin panel and API
This commit is contained in:
@@ -6,7 +6,14 @@ import os
|
||||
import hashlib
|
||||
|
||||
from src.config import settings
|
||||
from src.services.whisper_service import transcribe_audio, get_model_status
|
||||
from src.services.whisper_service import (
|
||||
transcribe_audio,
|
||||
get_model_status,
|
||||
get_available_models,
|
||||
switch_model,
|
||||
reload_model,
|
||||
delete_model
|
||||
)
|
||||
from src.services.stats_service import log_usage
|
||||
from src.database.db import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -97,6 +104,55 @@ async def model_status_endpoint(api_key: str = Depends(verify_api_key)):
|
||||
return get_model_status()
|
||||
|
||||
|
||||
@router.get("/available-models")
async def list_available_models(api_key: str = Depends(verify_api_key)):
    """Return the catalogue of Whisper models and their download state.

    The ``api_key`` parameter is resolved through the ``verify_api_key``
    dependency. The response is a JSON object with a single ``models`` key
    holding the result of ``get_available_models()``.
    """
    catalogue = get_available_models()
    return {"models": catalogue}
|
||||
|
||||
|
||||
@router.post("/switch-model")
async def switch_model_endpoint(
    model: str = Form(...),
    api_key: str = Depends(verify_api_key),
):
    """Activate the Whisper model named by the ``model`` form field.

    Returns the dictionary produced by ``switch_model``.

    Raises:
        HTTPException: 400 when ``switch_model`` raises ``ValueError``;
            500 for any other exception. In both cases the exception text
            becomes the response detail.
    """
    try:
        return switch_model(model)
    except ValueError as bad_request:
        raise HTTPException(status_code=400, detail=str(bad_request))
    except Exception as failure:
        raise HTTPException(status_code=500, detail=str(failure))
|
||||
|
||||
|
||||
@router.post("/reload-model")
async def reload_model_endpoint(api_key: str = Depends(verify_api_key)):
    """Reload the current model and return the outcome of ``reload_model``.

    ``reload_model()`` is called directly from this handler, so the response
    is only sent once that call has returned.

    Raises:
        HTTPException: 500 with the exception text as detail when the reload
            fails for any reason.
    """
    try:
        return reload_model()
    except Exception as failure:
        raise HTTPException(status_code=500, detail=str(failure))
|
||||
|
||||
|
||||
@router.delete("/delete-model/{model_name}")
async def delete_model_endpoint(
    model_name: str,
    api_key: str = Depends(verify_api_key),
):
    """Delete the downloaded model identified by the ``model_name`` path segment.

    Returns the dictionary produced by ``delete_model``.

    Raises:
        HTTPException: 400 when ``delete_model`` raises ``ValueError``;
            500 for any other exception. The exception text becomes the
            response detail.
    """
    try:
        return delete_model(model_name)
    except ValueError as bad_request:
        raise HTTPException(status_code=400, detail=str(bad_request))
    except Exception as failure:
        raise HTTPException(status_code=500, detail=str(failure))
|
||||
|
||||
|
||||
@router.post("/audio/transcriptions")
|
||||
async def create_transcription(
|
||||
file: UploadFile = File(...),
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import whisper
|
||||
import torch
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, List
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
import time
|
||||
import shutil
|
||||
|
||||
from src.config import settings
|
||||
|
||||
# Global model cache
|
||||
_model = None
|
||||
_current_model_name = settings.whisper_model
|
||||
_executor = ThreadPoolExecutor(max_workers=1)
|
||||
_model_lock = threading.Lock()
|
||||
|
||||
@@ -25,6 +27,18 @@ _model_status = {
|
||||
"is_loaded": False
|
||||
}
|
||||
|
||||
# Available Whisper models with the approximate size of their checkpoint file.
#
# The "size" string is shown in the UI and is also parsed by get_model_info()
# ("<number> MB" / "<number> GB", 1024-based) to estimate download progress, so
# it must describe the file on disk, not the model's parameter count.
AVAILABLE_MODELS = {
    "tiny": {"size": "72 MB", "description": "Fastest, lowest accuracy", "english_only": False},
    "base": {"size": "139 MB", "description": "Fast, good for testing", "english_only": False},
    "small": {"size": "461 MB", "description": "Balanced speed/accuracy", "english_only": False},
    "medium": {"size": "1.42 GB", "description": "Good accuracy", "english_only": False},
    "large-v1": {"size": "2.87 GB", "description": "High accuracy (legacy)", "english_only": False},
    "large-v2": {"size": "2.87 GB", "description": "Higher accuracy", "english_only": False},
    "large-v3": {"size": "2.88 GB", "description": "Best accuracy", "english_only": False},
    "large": {"size": "2.88 GB", "description": "Alias for large-v3", "english_only": False},
}
|
||||
|
||||
|
||||
def _download_hook(progress_bytes, total_bytes):
|
||||
"""Hook to track download progress"""
|
||||
@@ -37,28 +51,39 @@ def _download_hook(progress_bytes, total_bytes):
|
||||
_model_status["status_message"] = f"Downloading: {_model_status['download_percentage']}%"
|
||||
|
||||
|
||||
def load_model():
|
||||
def load_model(model_name: str = None):
|
||||
"""Load Whisper model"""
|
||||
global _model, _model_status
|
||||
global _model, _model_status, _current_model_name
|
||||
|
||||
if model_name is None:
|
||||
model_name = settings.whisper_model
|
||||
|
||||
with _model_lock:
|
||||
# If a different model is loaded, unload it first
|
||||
if _model is not None and _current_model_name != model_name:
|
||||
print(f"Switching from {_current_model_name} to {model_name}")
|
||||
_model = None
|
||||
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
||||
|
||||
if _model is None:
|
||||
_model_status["is_downloading"] = True
|
||||
_model_status["status_message"] = "Starting download..."
|
||||
_model_status["model_name"] = model_name
|
||||
|
||||
print(f"Loading Whisper model: {settings.whisper_model}")
|
||||
print(f"Loading Whisper model: {model_name}")
|
||||
try:
|
||||
# Whisper doesn't have a direct progress callback, but we can monitor the models directory
|
||||
_model = whisper.load_model(
|
||||
settings.whisper_model,
|
||||
model_name,
|
||||
device=settings.whisper_device,
|
||||
download_root=settings.models_path
|
||||
)
|
||||
_current_model_name = model_name
|
||||
_model_status["is_downloading"] = False
|
||||
_model_status["is_loaded"] = True
|
||||
_model_status["download_percentage"] = 100
|
||||
_model_status["status_message"] = "Model loaded successfully"
|
||||
print(f"Model loaded on {settings.whisper_device}")
|
||||
print(f"Model {model_name} loaded on {settings.whisper_device}")
|
||||
except Exception as e:
|
||||
_model_status["is_downloading"] = False
|
||||
_model_status["status_message"] = f"Error: {str(e)}"
|
||||
@@ -69,7 +94,7 @@ def load_model():
|
||||
|
||||
def get_model_info():
|
||||
"""Get model information"""
|
||||
global _model_status
|
||||
global _model_status, _current_model_name
|
||||
|
||||
# Check if model files exist in the models directory
|
||||
model_files = []
|
||||
@@ -87,15 +112,23 @@ def get_model_info():
|
||||
except:
|
||||
pass
|
||||
|
||||
# large-v3 is approximately 2.9GB
|
||||
expected_size = 2.9 * 1024 * 1024 * 1024 # 2.9 GB in bytes
|
||||
# Get expected size for current model
|
||||
expected_size_gb = 2.9 # default to large-v3
|
||||
if _current_model_name in AVAILABLE_MODELS:
|
||||
size_str = AVAILABLE_MODELS[_current_model_name]["size"]
|
||||
if "GB" in size_str:
|
||||
expected_size_gb = float(size_str.replace(" GB", ""))
|
||||
elif "MB" in size_str:
|
||||
expected_size_gb = float(size_str.replace(" MB", "")) / 1024
|
||||
|
||||
expected_size = expected_size_gb * 1024 * 1024 * 1024
|
||||
if total_size > 0:
|
||||
estimated_percentage = min(99, round((total_size / expected_size) * 100, 2))
|
||||
_model_status["download_percentage"] = estimated_percentage
|
||||
_model_status["status_message"] = f"Downloading: {estimated_percentage}%"
|
||||
|
||||
return {
|
||||
"name": settings.whisper_model,
|
||||
"name": _current_model_name,
|
||||
"device": settings.whisper_device,
|
||||
"loaded": _model is not None,
|
||||
"is_downloading": _model_status["is_downloading"],
|
||||
@@ -110,6 +143,94 @@ def get_model_status():
|
||||
return get_model_info()
|
||||
|
||||
|
||||
def get_available_models():
    """Describe every entry of ``AVAILABLE_MODELS`` with its local state.

    Returns:
        A list with one dict per known model, containing ``name``, ``size``,
        ``description``, ``english_only``, ``is_downloaded`` (a
        ``<name>.pt`` file exists in ``settings.models_path``) and
        ``is_active`` (the model is the one currently held in memory).
    """
    models_dir = settings.models_path

    # Collect the stems of all "*.pt" files. splitext removes only the
    # trailing extension, so a name containing ".pt" elsewhere is not mangled.
    downloaded_models = set()
    if os.path.isdir(models_dir):
        for entry in os.listdir(models_dir):
            stem, extension = os.path.splitext(entry)
            if extension == ".pt":
                downloaded_models.add(stem)

    return [
        {
            "name": model_name,
            "size": info["size"],
            "description": info["description"],
            "english_only": info["english_only"],
            "is_downloaded": model_name in downloaded_models,
            "is_active": model_name == _current_model_name and _model is not None,
        }
        for model_name, info in AVAILABLE_MODELS.items()
    ]
|
||||
|
||||
|
||||
def switch_model(model_name: str):
    """Make ``model_name`` the active model, loading it when necessary.

    Returns:
        A dict with ``status`` (``"already_active"`` or ``"success"``) and a
        human-readable ``message``.

    Raises:
        ValueError: ``model_name`` is not a key of ``AVAILABLE_MODELS``.
    """
    if model_name not in AVAILABLE_MODELS:
        raise ValueError(f"Unknown model: {model_name}")

    # Nothing to do when the requested model is the one already in memory.
    is_current = _model is not None and model_name == _current_model_name
    if is_current:
        return {"status": "already_active", "message": f"Model {model_name} is already active"}

    load_model(model_name)
    return {"status": "success", "message": f"Switched to model {model_name}"}
|
||||
|
||||
|
||||
def delete_model(model_name: str):
    """Delete the ``<model_name>.pt`` file from ``settings.models_path``.

    If the model is the one currently loaded it is unloaded first and
    ``_current_model_name`` is reset to ``None``.

    Returns:
        A dict with ``status`` (``"not_found"`` or ``"success"``) and a
        human-readable ``message``.

    Raises:
        ValueError: ``model_name`` is not a key of ``AVAILABLE_MODELS``.
        OSError: the file exists but could not be removed.
    """
    global _model, _current_model_name

    if model_name not in AVAILABLE_MODELS:
        raise ValueError(f"Unknown model: {model_name}")

    model_file = os.path.join(settings.models_path, f"{model_name}.pt")
    not_found = {"status": "not_found", "message": f"Model {model_name} is not downloaded"}

    # Cheap early exit that does not need to wait for the lock.
    if not os.path.exists(model_file):
        return not_found

    # load_model() and reload_model() change the shared model state while
    # holding _model_lock; do the same here so a delete cannot interleave
    # with a load of the same model.
    with _model_lock:
        if model_name == _current_model_name and _model is not None:
            _model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            _current_model_name = None

        try:
            os.remove(model_file)
        except FileNotFoundError:
            # The file vanished between the existence check and the removal.
            return not_found

    return {"status": "success", "message": f"Model {model_name} deleted"}
|
||||
|
||||
|
||||
def reload_model():
    """Drop the current model, delete its file and load it again.

    The model reloaded is ``_current_model_name`` or, when that is unset,
    ``settings.whisper_model``. Removing the ``.pt`` file first means the
    following ``load_model`` call has to fetch it again.

    Returns:
        ``{"status": "success", "message": ...}`` once ``load_model`` returns.
    """
    global _model

    with _model_lock:
        model_name = _current_model_name or settings.whisper_model

        # Unload current model
        if _model is not None:
            _model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Delete model file if it exists (EAFP avoids an exists/remove race).
        model_file = os.path.join(settings.models_path, f"{model_name}.pt")
        try:
            os.remove(model_file)
        except FileNotFoundError:
            pass

    # load_model() takes _model_lock itself and threading.Lock is not
    # reentrant, so this call must happen after the lock has been released.
    load_model(model_name)
    return {"status": "success", "message": f"Model {model_name} reloaded"}
|
||||
|
||||
|
||||
def _transcribe_sync(
|
||||
audio_path: str,
|
||||
language: Optional[str] = None,
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
<div class="nav">
|
||||
<a href="/admin" class="active">Dashboard</a>
|
||||
<a href="/admin/keys">API Keys</a>
|
||||
<a href="/admin/models">Models</a>
|
||||
<a href="/admin/logout">Logout</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
<div class="nav">
|
||||
<a href="/admin">Dashboard</a>
|
||||
<a href="/admin/keys" class="active">API Keys</a>
|
||||
<a href="/admin/models">Models</a>
|
||||
<a href="/admin/logout">Logout</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
124
src/templates/models.html
Normal file
124
src/templates/models.html
Normal file
@@ -0,0 +1,124 @@
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}Models - Whisper API Admin{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>🤖 Model Management</h1>
|
||||
<div class="nav">
|
||||
<a href="/admin">Dashboard</a>
|
||||
<a href="/admin/keys">API Keys</a>
|
||||
<a href="/admin/models" class="active">Models</a>
|
||||
<a href="/admin/logout">Logout</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Current Model Status -->
|
||||
<div class="card">
|
||||
<h2>Current Model Status</h2>
|
||||
<div id="current-model-info">
|
||||
<p><strong>Active Model:</strong> {{ current_status.name }}</p>
|
||||
<p><strong>Device:</strong> {{ current_status.device }}</p>
|
||||
<p><strong>Status:</strong>
|
||||
{% if current_status.loaded %}
|
||||
<span style="color: #48bb78;">✅ Loaded & Ready</span>
|
||||
{% elif current_status.is_downloading %}
|
||||
<span style="color: #ed8936;">⏳ Downloading ({{ current_status.download_percentage }}%)</span>
|
||||
{% else %}
|
||||
<span style="color: #718096;">⏸️ Not Loaded</span>
|
||||
{% endif %}
|
||||
</p>
|
||||
|
||||
{% if current_status.is_downloading %}
|
||||
<div style="background: #e2e8f0; border-radius: 10px; height: 20px; overflow: hidden; margin-top: 10px;">
|
||||
<div id="download-progress" style="background: linear-gradient(90deg, #667eea, #764ba2); height: 100%; width: {{ current_status.download_percentage }}%; transition: width 0.5s ease;"></div>
|
||||
</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Available Models -->
|
||||
<div class="card">
|
||||
<h2>Available Models</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Model</th>
|
||||
<th>Size</th>
|
||||
<th>Description</th>
|
||||
<th>Status</th>
|
||||
<th>Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for model in models %}
|
||||
<tr>
|
||||
<td><strong>{{ model.name }}</strong></td>
|
||||
<td>{{ model.size }}</td>
|
||||
<td>{{ model.description }}</td>
|
||||
<td>
|
||||
{% if model.is_active %}
|
||||
<span class="badge badge-success">Active</span>
|
||||
{% elif model.is_downloaded %}
|
||||
<span class="badge" style="background: #bee3f8; color: #2c5282;">Downloaded</span>
|
||||
{% else %}
|
||||
<span class="badge badge-danger">Not Downloaded</span>
|
||||
{% endif %}
|
||||
</td>
|
||||
<td>
|
||||
{% if not model.is_active %}
|
||||
<form method="POST" action="/admin/models/switch" style="display: inline;">
|
||||
<input type="hidden" name="model_name" value="{{ model.name }}">
|
||||
<button type="submit" class="btn" style="padding: 5px 10px; font-size: 12px;"
|
||||
{% if not model.is_downloaded %}disabled{% endif %}>
|
||||
Activate
|
||||
</button>
|
||||
</form>
|
||||
{% endif %}
|
||||
|
||||
{% if model.is_downloaded %}
|
||||
<form method="POST" action="/admin/models/delete" style="display: inline; margin-left: 5px;"
|
||||
onsubmit="return confirm('Are you sure you want to delete this model?');">
|
||||
<input type="hidden" name="model_name" value="{{ model.name }}">
|
||||
<button type="submit" class="btn btn-danger" style="padding: 5px 10px; font-size: 12px;"
|
||||
{% if model.is_active %}disabled{% endif %}>
|
||||
Delete
|
||||
</button>
|
||||
</form>
|
||||
{% else %}
|
||||
<form method="POST" action="/admin/models/download" style="display: inline; margin-left: 5px;">
|
||||
<input type="hidden" name="model_name" value="{{ model.name }}">
|
||||
<button type="submit" class="btn btn-success" style="padding: 5px 10px; font-size: 12px;">
|
||||
Download
|
||||
</button>
|
||||
</form>
|
||||
{% endif %}
|
||||
</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- Reload Current Model -->
|
||||
<div class="card">
|
||||
<h2>Reload Current Model</h2>
|
||||
<p>If you experience issues with the current model, you can reload it. This will delete and re-download the model files.</p>
|
||||
<form method="POST" action="/admin/models/reload" onsubmit="return confirm('This will delete and re-download the current model. Continue?');">
|
||||
<button type="submit" class="btn" style="background: #ed8936;">
|
||||
🔄 Reload Model
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Auto-refresh page every 5 seconds if downloading
|
||||
{% if current_status.is_downloading %}
|
||||
setTimeout(function() {
|
||||
window.location.reload();
|
||||
}, 5000);
|
||||
{% endif %}
|
||||
</script>
|
||||
{% endblock %}
|
||||
@@ -11,7 +11,7 @@ from src.config import settings
|
||||
from src.database.db import get_db
|
||||
from src.database.models import ApiKey, UsageLog
|
||||
from src.services.stats_service import get_usage_stats, hash_api_key
|
||||
from src.services.whisper_service import get_model_status
|
||||
from src.services.whisper_service import get_model_status, get_available_models
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory="src/templates")
|
||||
@@ -188,3 +188,131 @@ async def delete_key(
|
||||
db.commit()
|
||||
|
||||
return RedirectResponse(url="/admin/keys", status_code=302)
|
||||
|
||||
|
||||
@router.get("/models", response_class=HTMLResponse)
async def manage_models(request: Request, message: Optional[str] = None, error: Optional[str] = None):
    """Render the model management page (``models.html``).

    Requests that fail ``check_admin_auth`` are redirected to the admin login
    page. ``message`` and ``error`` are optional query parameters that are
    handed to the template unchanged.
    """
    try:
        check_admin_auth(request)
    except HTTPException:
        return RedirectResponse(url="/admin/login", status_code=302)

    context = {
        "request": request,
        "models": get_available_models(),
        "current_status": get_model_status(),
        "message": message,
        "error": error,
    }
    return templates.TemplateResponse("models.html", context)
|
||||
|
||||
|
||||
@router.post("/models/switch")
async def switch_model_admin(
    request: Request,
    model_name: str = Form(...)
):
    """Switch to a different model and redirect back to the models page.

    Unauthenticated requests are redirected to the admin login page. The
    outcome is passed to ``/admin/models`` as a ``message`` query parameter
    on success or an ``error`` query parameter on failure.
    """
    from urllib.parse import urlencode

    try:
        check_admin_auth(request)
    except HTTPException:
        return RedirectResponse(url="/admin/login", status_code=302)

    try:
        from src.services.whisper_service import switch_model
        result = switch_model(model_name)
        # urlencode escapes reserved characters (&, #, +, =) that would
        # otherwise corrupt the query string.
        query = urlencode({"message": result["message"]})
    except Exception as e:
        query = urlencode({"error": str(e)})

    return RedirectResponse(url=f"/admin/models?{query}", status_code=302)
|
||||
|
||||
|
||||
@router.post("/models/delete")
async def delete_model_admin(
    request: Request,
    model_name: str = Form(...)
):
    """Delete a model and redirect back to the models page.

    Unauthenticated requests are redirected to the admin login page. The
    outcome is passed to ``/admin/models`` as a ``message`` query parameter
    on success or an ``error`` query parameter on failure.
    """
    from urllib.parse import urlencode

    try:
        check_admin_auth(request)
    except HTTPException:
        return RedirectResponse(url="/admin/login", status_code=302)

    try:
        from src.services.whisper_service import delete_model
        result = delete_model(model_name)
        # urlencode escapes reserved characters (&, #, +, =) that would
        # otherwise corrupt the query string.
        query = urlencode({"message": result["message"]})
    except Exception as e:
        query = urlencode({"error": str(e)})

    return RedirectResponse(url=f"/admin/models?{query}", status_code=302)
|
||||
|
||||
|
||||
@router.post("/models/download")
async def download_model_admin(
    request: Request,
    model_name: str = Form(...)
):
    """Start loading (and thereby downloading) a model in a background thread.

    Unauthenticated requests are redirected to the admin login page. The
    handler returns immediately with a redirect to ``/admin/models``; the
    work itself happens in a daemon thread that calls ``load_model``.
    Names that are not keys of ``AVAILABLE_MODELS`` are rejected and
    reported through the ``error`` query parameter.
    """
    from urllib.parse import urlencode

    try:
        check_admin_auth(request)
    except HTTPException:
        return RedirectResponse(url="/admin/login", status_code=302)

    try:
        from src.services.whisper_service import AVAILABLE_MODELS, load_model

        # The name comes straight from a form field; only accept known
        # models, as switch_model() and delete_model() do.
        if model_name not in AVAILABLE_MODELS:
            raise ValueError(f"Unknown model: {model_name}")

        # Start download in background (non-blocking for API)
        import threading

        def download():
            try:
                # NOTE: load_model() also makes this model the loaded one.
                load_model(model_name)
            except Exception as e:
                print(f"Error downloading model {model_name}: {e}")

        thread = threading.Thread(target=download)
        thread.daemon = True
        thread.start()

        # urlencode escapes reserved characters that would otherwise
        # corrupt the query string.
        query = urlencode({"message": f"Started downloading model {model_name}"})
    except Exception as e:
        query = urlencode({"error": str(e)})

    return RedirectResponse(url=f"/admin/models?{query}", status_code=302)
|
||||
|
||||
|
||||
@router.post("/models/reload")
async def reload_model_admin(request: Request):
    """Reload the current model and redirect back to the models page.

    Unauthenticated requests are redirected to the admin login page. The
    outcome is passed to ``/admin/models`` as a ``message`` query parameter
    on success or an ``error`` query parameter on failure.
    """
    from urllib.parse import urlencode

    try:
        check_admin_auth(request)
    except HTTPException:
        return RedirectResponse(url="/admin/login", status_code=302)

    try:
        from src.services.whisper_service import reload_model
        result = reload_model()
        # urlencode escapes reserved characters (&, #, +, =) that would
        # otherwise corrupt the query string.
        query = urlencode({"message": result["message"]})
    except Exception as e:
        query = urlencode({"error": str(e)})

    return RedirectResponse(url=f"/admin/models?{query}", status_code=302)
|
||||
|
||||
Reference in New Issue
Block a user