Add model download status to admin panel and API

This commit is contained in:
Dominic Ballenthin
2026-01-29 01:38:19 +01:00
parent 0363b8b60e
commit 0f336428a0
4 changed files with 191 additions and 14 deletions

View File

@@ -6,7 +6,7 @@ import os
import hashlib import hashlib
from src.config import settings from src.config import settings
from src.services.whisper_service import transcribe_audio from src.services.whisper_service import transcribe_audio, get_model_status
from src.services.stats_service import log_usage from src.services.stats_service import log_usage
from src.database.db import get_db from src.database.db import get_db
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -18,6 +18,21 @@ from src.services.stats_service import hash_api_key
from src.database.db import SessionLocal from src.database.db import SessionLocal
from src.database.models import ApiKey from src.database.models import ApiKey
def check_model_loaded():
    """Gate request handling on model readiness.

    Returns the current model status dict. Raises HTTP 503 with the
    download percentage while the model is still downloading and not
    yet loaded, so clients can retry later.
    """
    status = get_model_status()
    downloading = status.get("is_downloading")
    loaded = status.get("loaded")
    if downloading and not loaded:
        pct = status.get('download_percentage', 0)
        raise HTTPException(
            status_code=503,
            detail={
                "error": "Model not ready",
                "message": f"Whisper model is still downloading: {pct}% complete",
                "status": status,
            },
        )
    return status
def verify_api_key(authorization: Optional[str] = Header(None)): def verify_api_key(authorization: Optional[str] = Header(None)):
"""Verify API key from Authorization header""" """Verify API key from Authorization header"""
if not authorization: if not authorization:
@@ -54,6 +69,9 @@ def verify_api_key(authorization: Optional[str] = Header(None)):
@router.get("/models") @router.get("/models")
async def list_models(api_key: str = Depends(verify_api_key)): async def list_models(api_key: str = Depends(verify_api_key)):
"""List available models (OpenAI compatible)""" """List available models (OpenAI compatible)"""
# Check model status and include it in response
model_status = get_model_status()
return { return {
"data": [ "data": [
{ {
@@ -68,10 +86,17 @@ async def list_models(api_key: str = Depends(verify_api_key)):
"created": 1698796800, "created": 1698796800,
"owned_by": "openai" "owned_by": "openai"
} }
] ],
"model_status": model_status
} }
@router.get("/model-status")
async def model_status_endpoint(api_key: str = Depends(verify_api_key)):
    """Get current model download/load status.

    Requires a valid API key (via `verify_api_key`). Returns the status
    dict produced by the whisper service (name, device, loaded flag,
    download percentage, status message, model files).
    """
    return get_model_status()
@router.post("/audio/transcriptions") @router.post("/audio/transcriptions")
async def create_transcription( async def create_transcription(
file: UploadFile = File(...), file: UploadFile = File(...),
@@ -94,6 +119,9 @@ async def create_transcription(
- **timestamp_granularities**: word, segment (for verbose_json) - **timestamp_granularities**: word, segment (for verbose_json)
""" """
# Check if model is loaded first
check_model_loaded()
start_time = time.time() start_time = time.time()
temp_path = None temp_path = None

View File

@@ -4,38 +4,112 @@ import os
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
import threading
import time
from src.config import settings from src.config import settings
# Global model cache # Global model cache
_model = None _model = None
_executor = ThreadPoolExecutor(max_workers=1) _executor = ThreadPoolExecutor(max_workers=1)
_model_lock = threading.Lock()
# Model download status: module-level mutable state shared by the loader
# and the status endpoints. Updated in place by load_model()/_download_hook()
# and read (and partly re-estimated) by get_model_info().
_model_status = {
    "is_downloading": False,      # True while load_model() is fetching/loading
    "download_progress": 0,       # bytes downloaded so far (set by _download_hook)
    "download_total": 0,          # total bytes expected (set by _download_hook)
    "download_percentage": 0,     # 0-100; estimated from on-disk file size when no hook fires
    "status_message": "Not started",
    "model_name": settings.whisper_model,
    "is_loaded": False            # True once whisper.load_model() has returned
}
def _download_hook(progress_bytes, total_bytes):
    """Record download progress into the shared `_model_status` dict.

    NOTE(review): nothing in this diff appears to register this hook with
    whisper — confirm it is actually wired up, or progress only comes from
    the file-size estimate in get_model_info().
    """
    global _model_status
    status = _model_status
    status["is_downloading"] = True
    status["download_progress"] = progress_bytes
    status["download_total"] = total_bytes
    if total_bytes > 0:
        pct = round((progress_bytes / total_bytes) * 100, 2)
        status["download_percentage"] = pct
        status["status_message"] = f"Downloading: {pct}%"
def load_model():
    """Load the Whisper model once, caching it in the module global `_model`.

    Thread-safe via `_model_lock` so concurrent first requests do not
    trigger duplicate downloads. Updates `_model_status` before, after,
    and on failure so the admin panel/API can report progress.

    Returns:
        The loaded whisper model instance.

    Raises:
        Re-raises whatever whisper.load_model raises, after recording the
        error message in `_model_status`.
    """
    global _model, _model_status

    with _model_lock:
        if _model is None:
            _model_status["is_downloading"] = True
            _model_status["status_message"] = "Starting download..."

            print(f"Loading Whisper model: {settings.whisper_model}")
            try:
                # Whisper doesn't have a direct progress callback, but we can
                # monitor the models directory instead (see get_model_info()).
                _model = whisper.load_model(
                    settings.whisper_model,
                    device=settings.whisper_device,
                    download_root=settings.models_path
                )
                _model_status["is_downloading"] = False
                _model_status["is_loaded"] = True
                _model_status["download_percentage"] = 100
                _model_status["status_message"] = "Model loaded successfully"
                print(f"Model loaded on {settings.whisper_device}")
            except Exception as e:
                _model_status["is_downloading"] = False
                _model_status["status_message"] = f"Error: {str(e)}"
                raise

    return _model
def get_model_info():
    """Get model information and download status.

    Lists model files found under `settings.models_path` and, while a
    download is in flight, estimates the percentage from the bytes
    already on disk (capped at 99 so only a successful load reports 100).

    Returns:
        dict with name, device, loaded flag, download state, status
        message, and the list of model file names found on disk.
    """
    global _model_status

    # Check which model files already exist in the models directory.
    model_files = []
    models_dir = settings.models_path
    if os.path.exists(models_dir):
        model_files = [f for f in os.listdir(models_dir)
                       if f.endswith(('.pt', '.bin'))]

    # Estimate download progress from on-disk size while downloading.
    if _model_status["is_downloading"] and not _model_status["is_loaded"]:
        total_size = 0
        for f in model_files:
            try:
                total_size += os.path.getsize(os.path.join(models_dir, f))
            except OSError:
                # File may vanish/move mid-download; skip it.
                pass
        # large-v3 is approximately 2.9GB.
        # NOTE(review): hard-coded for large-v3 — smaller models will show
        # misleadingly low percentages; consider a per-model size table.
        expected_size = 2.9 * 1024 * 1024 * 1024  # 2.9 GB in bytes
        if total_size > 0:
            estimated_percentage = min(99, round((total_size / expected_size) * 100, 2))
            _model_status["download_percentage"] = estimated_percentage
            _model_status["status_message"] = f"Downloading: {estimated_percentage}%"

    return {
        "name": settings.whisper_model,
        "device": settings.whisper_device,
        "loaded": _model is not None,
        "is_downloading": _model_status["is_downloading"],
        "download_percentage": _model_status["download_percentage"],
        "status_message": _model_status["status_message"],
        "model_files": model_files
    }
def get_model_status():
    """Get current model download/load status.

    Thin public alias for get_model_info(), kept so API/admin callers
    have a semantically named entry point.
    """
    return get_model_info()
def _transcribe_sync( def _transcribe_sync(
audio_path: str, audio_path: str,
language: Optional[str] = None, language: Optional[str] = None,

View File

@@ -32,6 +32,39 @@
</div> </div>
</div> </div>
<!-- Model Download Status.
     Note: the polling JS looks up model-progress-container/-bar/-text by id,
     so those ids must exist on the progress elements below. -->
<div class="card" id="model-status-card">
    <h2>🤖 Model Status</h2>
    <div id="model-status-content">
        {% if model_status.loaded %}
        <p><strong>Status:</strong> <span style="color: #48bb78;">✅ Loaded &amp; Ready</span></p>
        <p><strong>Model:</strong> {{ model_status.name }}</p>
        <p><strong>Device:</strong> {{ model_status.device }}</p>
        {% elif model_status.is_downloading %}
        <p><strong>Status:</strong> <span style="color: #ed8936;">⏳ Downloading...</span></p>
        <p>{{ model_status.status_message }}</p>
        <div id="model-progress-container" style="background: #e2e8f0; border-radius: 10px; height: 20px; overflow: hidden; margin-top: 10px;">
            <div id="model-progress-bar" style="background: linear-gradient(90deg, #667eea, #764ba2); height: 100%; width: {{ model_status.download_percentage }}%; transition: width 0.5s ease;"></div>
        </div>
        <!-- color fixed from #667aea to match the #667eea used in the gradient -->
        <p id="model-progress-text" style="text-align: center; margin-top: 5px; font-weight: bold; color: #667eea;">
            {{ model_status.download_percentage }}%
        </p>
        {% else %}
        <p><strong>Status:</strong> <span style="color: #718096;">⏸️ Not Started</span></p>
        <p>The model will be loaded on first transcription request.</p>
        {% endif %}

        {% if model_status.model_files %}
        <p style="margin-top: 10px;"><strong>Model Files:</strong></p>
        <ul style="margin-left: 20px;">
            {% for file in model_status.model_files %}
            <li>{{ file }}</li>
            {% endfor %}
        </ul>
        {% endif %}
    </div>
</div>
<div class="card"> <div class="card">
<h2>📊 Usage Chart (Last 30 Days)</h2> <h2>📊 Usage Chart (Last 30 Days)</h2>
<canvas id="usageChart" height="100"></canvas> <canvas id="usageChart" height="100"></canvas>
@@ -84,6 +117,7 @@
{% block extra_js %} {% block extra_js %}
<script> <script>
// Usage Chart
const ctx = document.getElementById('usageChart').getContext('2d'); const ctx = document.getElementById('usageChart').getContext('2d');
const dailyStats = {{ stats.daily_stats | tojson }}; const dailyStats = {{ stats.daily_stats | tojson }};
@@ -112,5 +146,43 @@
} }
} }
}); });
// Model Status Updates
// Polls the unauthenticated /health endpoint (the /model-status API
// endpoint requires an API key, which the admin page does not hold).
async function updateModelStatus() {
    try {
        const response = await fetch('/health');
        const data = await response.json();

        const statusCard = document.getElementById('model-status-content');
        // Progress elements only exist while the "downloading" branch of the
        // template is rendered, so every access must be null-guarded.
        const progressContainer = document.getElementById('model-progress-container');

        if (!statusCard) {
            return; // card not on this page; nothing to update
        }

        if (data.gpu && data.gpu.available) {
            statusCard.innerHTML = `
                <p><strong>Status:</strong> <span style="color: #48bb78;">✅ GPU Ready</span></p>
                <p><strong>GPU:</strong> ${data.gpu.name}</p>
                <p><strong>VRAM:</strong> ${data.gpu.vram_used_gb} GB / ${data.gpu.vram_total_gb} GB</p>
                <p><strong>Model:</strong> ${data.model}</p>
            `;
            if (progressContainer) {
                progressContainer.style.display = 'none';
            }
        } else {
            statusCard.innerHTML = `
                <p><strong>Status:</strong> <span style="color: #ed8936;">⏳ Loading...</span></p>
                <p>Model is being downloaded. Please wait...</p>
            `;
            if (progressContainer) {
                progressContainer.style.display = 'block';
            }
        }
    } catch (error) {
        console.error('Error fetching model status:', error);
    }
}

// Update status every 5 seconds
updateModelStatus();
setInterval(updateModelStatus, 5000);
</script> </script>
{% endblock %} {% endblock %}

View File

@@ -11,6 +11,7 @@ from src.config import settings
from src.database.db import get_db from src.database.db import get_db
from src.database.models import ApiKey, UsageLog from src.database.models import ApiKey, UsageLog
from src.services.stats_service import get_usage_stats, hash_api_key from src.services.stats_service import get_usage_stats, hash_api_key
from src.services.whisper_service import get_model_status
router = APIRouter() router = APIRouter()
templates = Jinja2Templates(directory="src/templates") templates = Jinja2Templates(directory="src/templates")
@@ -91,12 +92,14 @@ async def dashboard(request: Request, db: Session = Depends(get_db)):
return RedirectResponse(url="/admin/login", status_code=302) return RedirectResponse(url="/admin/login", status_code=302)
stats = await get_usage_stats(db, days=30) stats = await get_usage_stats(db, days=30)
model_status = get_model_status()
return templates.TemplateResponse("dashboard.html", { return templates.TemplateResponse("dashboard.html", {
"request": request, "request": request,
"stats": stats, "stats": stats,
"model": settings.whisper_model, "model": settings.whisper_model,
"retention_days": settings.log_retention_days "retention_days": settings.log_retention_days,
"model_status": model_status
}) })