Add model download status to admin panel and API
This commit is contained in:
@@ -6,7 +6,7 @@ import os
|
||||
import hashlib
|
||||
|
||||
from src.config import settings
|
||||
from src.services.whisper_service import transcribe_audio
|
||||
from src.services.whisper_service import transcribe_audio, get_model_status
|
||||
from src.services.stats_service import log_usage
|
||||
from src.database.db import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -18,6 +18,21 @@ from src.services.stats_service import hash_api_key
|
||||
from src.database.db import SessionLocal
|
||||
from src.database.models import ApiKey
|
||||
|
||||
|
||||
def check_model_loaded():
    """Ensure the Whisper model is usable before serving a request.

    Returns:
        The status dict from get_model_status() when the model is ready
        (or has not started loading yet).

    Raises:
        HTTPException: 503 with progress details while a download is in
            progress and the model is not yet loaded.
    """
    model_state = get_model_status()
    downloading = model_state.get("is_downloading")
    loaded = model_state.get("loaded")

    # Only block requests while an active download has not finished.
    if not (downloading and not loaded):
        return model_state

    percent = model_state.get('download_percentage', 0)
    raise HTTPException(
        status_code=503,
        detail={
            "error": "Model not ready",
            "message": f"Whisper model is still downloading: {percent}% complete",
            "status": model_state
        }
    )
|
||||
|
||||
def verify_api_key(authorization: Optional[str] = Header(None)):
|
||||
"""Verify API key from Authorization header"""
|
||||
if not authorization:
|
||||
@@ -54,6 +69,9 @@ def verify_api_key(authorization: Optional[str] = Header(None)):
|
||||
@router.get("/models")
|
||||
async def list_models(api_key: str = Depends(verify_api_key)):
|
||||
"""List available models (OpenAI compatible)"""
|
||||
# Check model status and include it in response
|
||||
model_status = get_model_status()
|
||||
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
@@ -68,10 +86,17 @@ async def list_models(api_key: str = Depends(verify_api_key)):
|
||||
"created": 1698796800,
|
||||
"owned_by": "openai"
|
||||
}
|
||||
]
|
||||
],
|
||||
"model_status": model_status
|
||||
}
|
||||
|
||||
|
||||
@router.get("/model-status")
async def model_status_endpoint(api_key: str = Depends(verify_api_key)):
    """Return the Whisper model's current download/load state.

    Requires a valid API key via the Authorization header.
    """
    current_status = get_model_status()
    return current_status
|
||||
|
||||
|
||||
@router.post("/audio/transcriptions")
|
||||
async def create_transcription(
|
||||
file: UploadFile = File(...),
|
||||
@@ -94,6 +119,9 @@ async def create_transcription(
|
||||
- **timestamp_granularities**: word, segment (for verbose_json)
|
||||
"""
|
||||
|
||||
# Check if model is loaded first
|
||||
check_model_loaded()
|
||||
|
||||
start_time = time.time()
|
||||
temp_path = None
|
||||
|
||||
|
||||
@@ -4,38 +4,112 @@ import os
|
||||
from typing import Optional, Dict, Any
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
import time
|
||||
|
||||
from src.config import settings
|
||||
|
||||
# Global model cache
_model = None  # lazily-initialized Whisper model instance (set by load_model)
_executor = ThreadPoolExecutor(max_workers=1)  # single worker: serialize transcriptions off the event loop
_model_lock = threading.Lock()  # guards the one-time model load in load_model()

# Model download status
# Mutated by load_model()/_download_hook(); read by get_model_info().
# NOTE(review): mutated from multiple threads without the lock — dict item
# assignment is atomic in CPython, but confirm if stricter guarantees are needed.
_model_status = {
    "is_downloading": False,  # True while weights are being fetched/loaded
    "download_progress": 0,  # bytes fetched so far (set by _download_hook)
    "download_total": 0,  # expected total bytes (0 when unknown)
    "download_percentage": 0,  # 0-100; estimated from on-disk size while downloading
    "status_message": "Not started",
    "model_name": settings.whisper_model,
    "is_loaded": False
}
|
||||
|
||||
|
||||
def _download_hook(progress_bytes, total_bytes):
    """Record download progress in the module-level _model_status dict.

    Args:
        progress_bytes: Bytes fetched so far.
        total_bytes: Expected total size in bytes (0 or less when unknown).
    """
    global _model_status
    _model_status.update(
        is_downloading=True,
        download_progress=progress_bytes,
        download_total=total_bytes,
    )
    # Percentage/message are only meaningful when the total size is known.
    if total_bytes > 0:
        pct = round((progress_bytes / total_bytes) * 100, 2)
        _model_status["download_percentage"] = pct
        _model_status["status_message"] = f"Downloading: {pct}%"
|
||||
|
||||
|
||||
def load_model():
    """Load (and cache) the Whisper model, tracking progress in _model_status.

    Thread-safe: _model_lock ensures only one caller performs the (possibly
    long) download/load; later callers return the cached instance.

    Returns:
        The loaded whisper model.

    Raises:
        Whatever whisper.load_model raises; the error message is recorded in
        _model_status before re-raising.
    """
    # Fix: the extracted diff carried both the old `global _model` and the new
    # `global _model, _model_status` statements; only the new one belongs here.
    global _model, _model_status

    with _model_lock:
        if _model is None:
            _model_status["is_downloading"] = True
            _model_status["status_message"] = "Starting download..."

            print(f"Loading Whisper model: {settings.whisper_model}")
            try:
                # Whisper doesn't have a direct progress callback; progress is
                # estimated separately (see get_model_info) from the models dir.
                _model = whisper.load_model(
                    settings.whisper_model,
                    device=settings.whisper_device,
                    download_root=settings.models_path
                )
                _model_status["is_downloading"] = False
                _model_status["is_loaded"] = True
                _model_status["download_percentage"] = 100
                _model_status["status_message"] = "Model loaded successfully"
                print(f"Model loaded on {settings.whisper_device}")
            except Exception as e:
                _model_status["is_downloading"] = False
                _model_status["status_message"] = f"Error: {str(e)}"
                raise

    return _model
|
||||
|
||||
|
||||
def get_model_info():
    """Return a snapshot of the Whisper model's load/download state.

    Fix 1: the extracted diff carried both `"loaded": _model is not None` and
    the comma-terminated replacement line — a syntax error; only the new line
    is kept.
    Fix 2: the original called load_model() here, which blocks until the full
    multi-GB download finishes — that hangs every status endpoint and makes
    the "is_downloading" state unobservable. A status query must not force a
    load, so the call is removed; "loaded" is derived from the cached _model.

    Returns:
        Dict with name, device, loaded, is_downloading, download_percentage,
        status_message and model_files keys.
    """
    global _model_status

    # Check if model files exist in the models directory.
    model_files = []
    models_dir = settings.models_path
    if os.path.exists(models_dir):
        model_files = [f for f in os.listdir(models_dir)
                       if f.endswith('.pt') or f.endswith('.bin')]

    # While downloading, estimate progress from the bytes already on disk.
    if _model_status["is_downloading"] and not _model_status["is_loaded"]:
        total_size = 0
        for f in model_files:
            try:
                total_size += os.path.getsize(os.path.join(models_dir, f))
            except OSError:
                # File may be renamed/removed mid-download; skip it.
                pass

        # large-v3 is approximately 2.9GB; cap the estimate at 99% until the
        # load actually completes.
        # NOTE(review): this assumes the configured model is large-v3 — the
        # estimate is wrong for smaller models; confirm against settings.
        expected_size = 2.9 * 1024 * 1024 * 1024  # 2.9 GB in bytes
        if total_size > 0:
            estimated_percentage = min(99, round((total_size / expected_size) * 100, 2))
            _model_status["download_percentage"] = estimated_percentage
            _model_status["status_message"] = f"Downloading: {estimated_percentage}%"

    return {
        "name": settings.whisper_model,
        "device": settings.whisper_device,
        "loaded": _model is not None,
        "is_downloading": _model_status["is_downloading"],
        "download_percentage": _model_status["download_percentage"],
        "status_message": _model_status["status_message"],
        "model_files": model_files
    }
|
||||
|
||||
|
||||
def get_model_status():
    """Alias for get_model_info(): the model's current download/load status."""
    return get_model_info()
|
||||
|
||||
|
||||
def _transcribe_sync(
|
||||
audio_path: str,
|
||||
language: Optional[str] = None,
|
||||
|
||||
@@ -32,6 +32,39 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Model Download Status -->
|
||||
<div class="card" id="model-status-card">
|
||||
<h2>🤖 Model Status</h2>
|
||||
<div id="model-status-content">
|
||||
{% if model_status.loaded %}
|
||||
<p><strong>Status:</strong> <span style="color: #48bb78;">✅ Loaded & Ready</span></p>
|
||||
<p><strong>Model:</strong> {{ model_status.name }}</p>
|
||||
<p><strong>Device:</strong> {{ model_status.device }}</p>
|
||||
{% elif model_status.is_downloading %}
|
||||
<p><strong>Status:</strong> <span style="color: #ed8936;">⏳ Downloading...</span></p>
|
||||
<p>{{ model_status.status_message }}</p>
|
||||
<div style="background: #e2e8f0; border-radius: 10px; height: 20px; overflow: hidden; margin-top: 10px;">
|
||||
<div style="background: linear-gradient(90deg, #667eea, #764ba2); height: 100%; width: {{ model_status.download_percentage }}%; transition: width 0.5s ease;"></div>
|
||||
</div>
|
||||
<p style="text-align: center; margin-top: 5px; font-weight: bold; color: #667eea;">
    {{ model_status.download_percentage }}%
</p>
|
||||
{% else %}
|
||||
<p><strong>Status:</strong> <span style="color: #718096;">⏸️ Not Started</span></p>
|
||||
<p>The model will be loaded on first transcription request.</p>
|
||||
{% endif %}
|
||||
|
||||
{% if model_status.model_files %}
|
||||
<p style="margin-top: 10px;"><strong>Model Files:</strong></p>
|
||||
<ul style="margin-left: 20px;">
|
||||
{% for file in model_status.model_files %}
|
||||
<li>{{ file }}</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card">
|
||||
<h2>📊 Usage Chart (Last 30 Days)</h2>
|
||||
<canvas id="usageChart" height="100"></canvas>
|
||||
@@ -84,6 +117,7 @@
|
||||
|
||||
{% block extra_js %}
|
||||
<script>
|
||||
// Usage Chart
|
||||
const ctx = document.getElementById('usageChart').getContext('2d');
|
||||
const dailyStats = {{ stats.daily_stats | tojson }};
|
||||
|
||||
@@ -112,5 +146,43 @@
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Model Status Updates
async function updateModelStatus() {
    try {
        // The /v1/model-status endpoint requires an API key, so the admin
        // panel polls the unauthenticated /health endpoint instead.
        const response = await fetch('/health');
        const data = await response.json();

        const statusCard = document.getElementById('model-status-content');
        if (!statusCard) {
            return;  // status card not rendered on this page
        }
        // Fix: the progress elements (#model-progress-container/-bar/-text)
        // do not exist in this template — only #model-status-card and
        // #model-status-content are rendered above — so the original
        // unconditional `.style` accesses threw a TypeError on null every
        // poll. Guard the lookup; the unused -bar/-text lookups are dropped.
        const progressContainer = document.getElementById('model-progress-container');

        if (data.gpu && data.gpu.available) {
            statusCard.innerHTML = `
                <p><strong>Status:</strong> <span style="color: #48bb78;">✅ GPU Ready</span></p>
                <p><strong>GPU:</strong> ${data.gpu.name}</p>
                <p><strong>VRAM:</strong> ${data.gpu.vram_used_gb} GB / ${data.gpu.vram_total_gb} GB</p>
                <p><strong>Model:</strong> ${data.model}</p>
            `;
            if (progressContainer) progressContainer.style.display = 'none';
        } else {
            statusCard.innerHTML = `
                <p><strong>Status:</strong> <span style="color: #ed8936;">⏳ Loading...</span></p>
                <p>Model is being downloaded. Please wait...</p>
            `;
            if (progressContainer) progressContainer.style.display = 'block';
        }
    } catch (error) {
        console.error('Error fetching model status:', error);
    }
}

// Update status every 5 seconds
updateModelStatus();
setInterval(updateModelStatus, 5000);
|
||||
</script>
|
||||
{% endblock %}
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.config import settings
|
||||
from src.database.db import get_db
|
||||
from src.database.models import ApiKey, UsageLog
|
||||
from src.services.stats_service import get_usage_stats, hash_api_key
|
||||
from src.services.whisper_service import get_model_status
|
||||
|
||||
router = APIRouter()
|
||||
templates = Jinja2Templates(directory="src/templates")
|
||||
@@ -91,12 +92,14 @@ async def dashboard(request: Request, db: Session = Depends(get_db)):
|
||||
return RedirectResponse(url="/admin/login", status_code=302)
|
||||
|
||||
stats = await get_usage_stats(db, days=30)
|
||||
model_status = get_model_status()
|
||||
|
||||
return templates.TemplateResponse("dashboard.html", {
|
||||
"request": request,
|
||||
"stats": stats,
|
||||
"model": settings.whisper_model,
|
||||
"retention_days": settings.log_retention_days
|
||||
"retention_days": settings.log_retention_days,
|
||||
"model_status": model_status
|
||||
})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user