Merge pull request #322 from eschmidbauer/fix/thread-safety-issues
Fix kv cache not being properly cleaned between sessions
This commit is contained in:
commit
34e4abd455
4 changed files with 174 additions and 8 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
from whisperlivekit.local_agreement.online_asr import OnlineASRProcessor
|
||||||
|
|
@ -19,16 +20,26 @@ logger = logging.getLogger(__name__)
|
||||||
class TranscriptionEngine:
|
class TranscriptionEngine:
|
||||||
_instance = None
|
_instance = None
|
||||||
_initialized = False
|
_initialized = False
|
||||||
|
_lock = threading.Lock() # Thread-safe singleton lock
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
|
# Double-checked locking pattern for thread-safe singleton
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
with cls._lock:
|
||||||
|
# Check again inside lock to prevent race condition
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
if TranscriptionEngine._initialized:
|
# Thread-safe initialization check
|
||||||
return
|
with TranscriptionEngine._lock:
|
||||||
|
if TranscriptionEngine._initialized:
|
||||||
|
return
|
||||||
|
# Set flag immediately to prevent re-initialization
|
||||||
|
TranscriptionEngine._initialized = True
|
||||||
|
|
||||||
|
# Perform initialization outside lock to avoid holding lock during slow operations
|
||||||
global_params = {
|
global_params = {
|
||||||
"host": "localhost",
|
"host": "localhost",
|
||||||
"port": 8000,
|
"port": 8000,
|
||||||
|
|
@ -172,7 +183,6 @@ class TranscriptionEngine:
|
||||||
}
|
}
|
||||||
translation_params = update_with_kwargs(translation_params, kwargs)
|
translation_params = update_with_kwargs(translation_params, kwargs)
|
||||||
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
self.translation_model = load_model([self.args.lan], **translation_params) #in the future we want to handle different languages for different speakers
|
||||||
TranscriptionEngine._initialized = True
|
|
||||||
|
|
||||||
|
|
||||||
def online_factory(args, asr):
|
def online_factory(args, asr):
|
||||||
|
|
|
||||||
|
|
@ -47,9 +47,24 @@ class DecoderState:
|
||||||
|
|
||||||
def clean_cache(self):
|
def clean_cache(self):
|
||||||
"""Clean the kv_cache after each inference step."""
|
"""Clean the kv_cache after each inference step."""
|
||||||
self.kv_cache = {}
|
# Explicitly delete tensor references to free GPU memory
|
||||||
|
if self.kv_cache:
|
||||||
|
for key in list(self.kv_cache.keys()):
|
||||||
|
tensor = self.kv_cache.pop(key, None)
|
||||||
|
if tensor is not None:
|
||||||
|
del tensor
|
||||||
|
|
||||||
|
# Clear the dict
|
||||||
|
self.kv_cache.clear()
|
||||||
|
|
||||||
|
# Force GPU cache cleanup (only if CUDA is available)
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
if self.decoder_type == "beam" and self.inference is not None:
|
if self.decoder_type == "beam" and self.inference is not None:
|
||||||
self.inference.kv_cache = self.kv_cache
|
# Create NEW dict instead of sharing reference
|
||||||
|
self.inference.kv_cache = {}
|
||||||
if self.token_decoder is not None:
|
if self.token_decoder is not None:
|
||||||
self.token_decoder.reset()
|
self.token_decoder.reset()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -626,8 +626,10 @@ class AlignAtt:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
current_timestamp = l_absolute_timestamps[timestamp_idx]
|
||||||
except:
|
except IndexError:
|
||||||
pass
|
# Use last timestamp if index out of range
|
||||||
|
logger.warning(f"Timestamp index {timestamp_idx} out of range, using last timestamp")
|
||||||
|
current_timestamp = l_absolute_timestamps[-1] if l_absolute_timestamps else 0.0
|
||||||
timestamp_idx += len(word_tokens)
|
timestamp_idx += len(word_tokens)
|
||||||
|
|
||||||
timestamp_entry = ASRToken(
|
timestamp_entry = ASRToken(
|
||||||
|
|
|
||||||
139
whisperlivekit/thread_safety.py
Normal file
139
whisperlivekit/thread_safety.py
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
"""
|
||||||
|
Thread Safety Configuration for WhisperLiveKit
|
||||||
|
|
||||||
|
This module provides thread safety configuration and utilities.
|
||||||
|
|
||||||
|
Environment Variables:
|
||||||
|
WHISPERLIVEKIT_MODEL_LOCK: Enable/disable model locking (default: 1)
|
||||||
|
Set to "0" to disable for single-connection deployments
|
||||||
|
|
||||||
|
WHISPERLIVEKIT_LOCK_TIMEOUT: Lock acquisition timeout in seconds (default: 30)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Enable model locking (default)
|
||||||
|
export WHISPERLIVEKIT_MODEL_LOCK=1
|
||||||
|
|
||||||
|
# Disable for single-connection deployment
|
||||||
|
export WHISPERLIVEKIT_MODEL_LOCK=0
|
||||||
|
|
||||||
|
# Custom timeout
|
||||||
|
export WHISPERLIVEKIT_LOCK_TIMEOUT=60
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Configuration
_DEFAULT_LOCK_TIMEOUT = 30.0

# Values of WHISPERLIVEKIT_MODEL_LOCK that switch locking off. Every other
# value (including "true", "yes" or a typo) keeps locking on, so a
# misconfigured variable cannot silently remove the serialization.
_LOCK_DISABLED_VALUES = frozenset({"0", "false", "no", "off"})


def _read_use_model_lock():
    """Return False only when WHISPERLIVEKIT_MODEL_LOCK explicitly disables locking."""
    raw = os.environ.get("WHISPERLIVEKIT_MODEL_LOCK", "1")
    return raw.strip().lower() not in _LOCK_DISABLED_VALUES


def _read_lock_timeout():
    """Return WHISPERLIVEKIT_LOCK_TIMEOUT as a float.

    Falls back to the 30 second default (with a warning) when the variable is
    set to something ``float()`` cannot parse, instead of failing the import.
    """
    raw = os.environ.get("WHISPERLIVEKIT_LOCK_TIMEOUT")
    if raw is None:
        return _DEFAULT_LOCK_TIMEOUT
    try:
        return float(raw)
    except ValueError:
        logger.warning(
            "Invalid WHISPERLIVEKIT_LOCK_TIMEOUT=%r, using default %ss",
            raw,
            _DEFAULT_LOCK_TIMEOUT,
        )
        return _DEFAULT_LOCK_TIMEOUT


USE_MODEL_LOCK = _read_use_model_lock()
LOCK_TIMEOUT = _read_lock_timeout()

# Global model lock
_model_lock = threading.Lock()

# Log configuration on import
if USE_MODEL_LOCK:
    logger.info(f"Model locking ENABLED (timeout: {LOCK_TIMEOUT}s)")
    logger.info("For single-connection deployments, set WHISPERLIVEKIT_MODEL_LOCK=0")
else:
    logger.warning("Model locking DISABLED - only safe for single-connection deployments")
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_lock():
    """Return the module-level model lock object.

    Every call hands back the same ``_model_lock`` instance; no new lock is
    created here.
    """
    shared_lock = _model_lock
    return shared_lock
|
||||||
|
|
||||||
|
|
||||||
|
def acquire_model_lock(timeout=None):
    """
    Acquire model lock with timeout.

    Args:
        timeout: Lock acquisition timeout in seconds. ``None`` (the default)
            means use LOCK_TIMEOUT. An explicit ``0`` is passed through to
            the lock as-is, i.e. a single non-blocking attempt.

    Returns:
        bool: True if lock acquired (or model locking is disabled),
        False on timeout
    """
    if not USE_MODEL_LOCK:
        return True

    # Compare against None explicitly: `timeout or LOCK_TIMEOUT` would treat
    # an explicit 0 as "not given" and wait for the full default instead.
    if timeout is None:
        timeout = LOCK_TIMEOUT

    acquired = _model_lock.acquire(timeout=timeout)

    if not acquired:
        logger.error(f"Failed to acquire model lock within {timeout}s")

    return acquired
|
||||||
|
|
||||||
|
|
||||||
|
def release_model_lock():
    """Release model lock.

    Does nothing when model locking is disabled. Releasing while the lock is
    not held does not raise; it is logged as a warning because it means the
    acquire/release calls are unbalanced.
    """
    if not USE_MODEL_LOCK:
        return

    try:
        _model_lock.release()
    except RuntimeError:
        # Lock not held. Keep this non-fatal for callers that release
        # unconditionally, but leave a trace instead of hiding it.
        logger.warning("release_model_lock() called while the model lock was not held")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLockContext:
    """Context manager wrapper around acquire_model_lock / release_model_lock.

    The value bound by ``with ... as ok`` is the boolean returned by
    ``acquire_model_lock``. The ``with`` body runs even when that value is
    False, so callers that need exclusivity must check it themselves.
    Exceptions raised inside the body are never suppressed.
    """

    def __init__(self, timeout=None):
        # Forwarded unchanged to acquire_model_lock() on entry.
        self.timeout = timeout
        # Result of the most recent acquire attempt; False until entered.
        self.acquired = False

    def __enter__(self):
        got_lock = acquire_model_lock(self.timeout)
        self.acquired = got_lock
        return got_lock

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Only undo an acquisition that actually succeeded.
        if not self.acquired:
            return False
        release_model_lock()
        return False
|
||||||
|
|
||||||
|
|
||||||
|
# Concurrency recommendations
# The recommendation is one connection per worker whether or not model locking
# is enabled, so this is a plain constant rather than a conditional whose two
# branches are the same value.
RECOMMENDED_CONNECTIONS_PER_WORKER = 1
RECOMMENDED_WORKERS = 4
|
||||||
|
|
||||||
|
def print_deployment_recommendations():
    """Write the deployment recommendation banner to stdout.

    The text between the two rules depends on USE_MODEL_LOCK: a multi-worker
    gunicorn command when locking is enabled, a single-worker uvicorn command
    when it is disabled.
    """
    rule = "=" * 60

    if USE_MODEL_LOCK:
        body = [
            "⚠️ Model locking is ENABLED",
            " This serializes inference across connections.",
            "",
            "Recommended deployment:",
            f" gunicorn -w {RECOMMENDED_WORKERS} \\",
            " -k uvicorn.workers.UvicornWorker \\",
            " --worker-connections 1 \\",
            " whisperlivekit.basic_server:app",
            "",
            "Expected capacity:",
            f" - {RECOMMENDED_WORKERS} concurrent users (1 per worker)",
            f" - Memory: ~{RECOMMENDED_WORKERS}x model size",
        ]
    else:
        body = [
            "✅ Model locking is DISABLED",
            " ⚠️ ONLY safe for single-connection deployments",
            "",
            "Recommended deployment:",
            " uvicorn whisperlivekit.basic_server:app \\",
            " --host 0.0.0.0 --port 8000 \\",
            " --workers 1",
            "",
            "Expected capacity:",
            " - 1 concurrent user only",
        ]

    header = ["\n" + rule, "WhisperLiveKit Deployment Recommendations", rule]
    footer = [rule + "\n"]
    # One print per line, in order, exactly as the banner is laid out.
    for text in header + body + footer:
        print(text)


# Running the module as a script prints the recommendations.
if __name__ == "__main__":
    print_deployment_recommendations()
|
||||||
Loading…
Reference in a new issue