qwen3: reuse encoder kv cache

This commit is contained in:
Quentin Fuxa 2026-03-15 22:31:39 +01:00
parent f24481dc29
commit dd48997674

View file

@ -66,6 +66,64 @@ class Qwen3SimulConfig:
max_alignment_heads: int = 20 # Use only top N alignment heads
@dataclass
class _AudioEmbedCache:
"""Cached audio encoder outputs for incremental encoding.
The Qwen3-ASR audio encoder processes mel features in chunks of
``n_window * 2`` mel frames with windowed self-attention spanning
``n_window_infer`` mel frames (800 for both 0.6B and 1.7B = 8s of
audio). Within one attention window chunks can attend to each other,
but across windows they cannot.
We cache the audio embeddings (output of ``get_audio_features``) for
all *complete attention windows* whose input mel frames are unchanged.
When the audio buffer grows, only the tail (last incomplete window +
new audio) is re-encoded through the audio encoder, and the result is
concatenated with the cached prefix.
When the audio buffer is trimmed from the front (e.g. max_len exceeded),
the cache is fully invalidated and rebuilt on the next call.
"""
# Number of audio *samples* (PCM @ 16kHz) that have been fully encoded.
# This always equals the number of samples whose mel features were fed
# to the audio encoder for the cached embeddings.
encoded_samples: int = 0
# Cached audio embeddings tensor, shape (1, n_cached_tokens, hidden_dim).
# None means "no cache yet".
embeddings: Optional[torch.Tensor] = None
# Number of mel frames that produced ``embeddings``.
# Used to verify cache validity (mel length must match).
encoded_mel_frames: int = 0
# Number of audio tokens (embeddings.shape[1]) that are in *complete*
# attention windows and can be safely reused. Tokens from the last
# (potentially incomplete) window are always re-encoded.
stable_tokens: int = 0
def trim_front(self, trim_samples: int, sample_rate: int = 16000):
"""Invalidate cache entries for audio trimmed from the front.
Called when ``insert_audio_chunk`` trims the buffer. Rather than
attempting complex partial invalidation (which could introduce subtle
bugs if the mel/token math doesn't align perfectly), we simply reset
the cache. The next ``_encode_audio_cached`` call will rebuild it.
This is safe because trimming only happens when the audio buffer
exceeds ``audio_max_len`` (~15s), which is relatively infrequent.
"""
self.reset()
def reset(self):
"""Fully invalidate the cache."""
self.encoded_samples = 0
self.embeddings = None
self.encoded_mel_frames = 0
self.stable_tokens = 0
@dataclass
class Qwen3SimulState:
"""Per-session mutable state for Qwen3 SimulStreaming."""
@ -89,6 +147,9 @@ class Qwen3SimulState:
detected_language: Optional[str] = None
last_infer_samples: int = 0 # audio_buffer length at last inference
# Audio embedding cache for incremental encoding
audio_cache: _AudioEmbedCache = field(default_factory=_AudioEmbedCache)
class Qwen3SimulStreamingASR:
"""
@ -346,6 +407,8 @@ class Qwen3SimulStreamingOnlineProcessor:
self.state.cumulative_time_offset += trim / self.SAMPLING_RATE
# Adjust throttle counter so it tracks position within the trimmed buffer
self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim)
# Trim audio embedding cache to match
self.state.audio_cache.trim_front(trim, self.SAMPLING_RATE)
def start_silence(self) -> Tuple[List[ASRToken], float]:
"""Handle start of silence -- flush all pending tokens.
@ -388,6 +451,248 @@ class Qwen3SimulStreamingOnlineProcessor:
"""Get the current unvalidated buffer."""
return Transcript.from_tokens(tokens=self.buffer, sep='')
def _encode_audio_cached(self) -> Optional[torch.Tensor]:
    """Encode the audio buffer, reusing cached embeddings where possible.

    Returns:
        The full audio embeddings tensor ``(n_audio_tokens, hidden_dim)``,
        or ``None`` when the buffer is empty or anything in the cached
        path fails (the caller should then fall back to the
        processor-based path). On failure the cache is reset.

    Caching strategy:
        - The audio encoder uses windowed attention; the window length in
          mel frames is read from the audio tower config
          (``n_window_infer``).
        - Tokens within one window can attend to each other but not
          across windows, so tokens of *complete* windows are treated as
          stable and are cached.
        - On a cache hit only the *tail* (from the last complete window
          boundary onward) is sent through the audio encoder and then
          concatenated with the cached prefix.
        - The stitched result is only used when its length equals the
          token count expected for the full audio; otherwise the full
          audio is re-encoded.

    NOTE(review): the reuse assumes the mel frames of the stable prefix do
    not change when audio is appended. The Whisper-style feature extractor
    is believed to clamp the log-mel floor relative to the global maximum
    of the clip, which would make this only approximately true -- confirm.
    """
    asr = self.asr
    state = self.state
    cache = state.audio_cache

    if len(state.audio_buffer) == 0:
        return None

    try:
        from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
            _get_feat_extract_output_lengths,
        )

        def _drop_batch_dim(embeds: torch.Tensor) -> torch.Tensor:
            # get_audio_features is expected to return (n_tokens, hidden);
            # tolerate a leading batch dimension of 1 as well.
            return embeds[0] if embeds.dim() == 3 else embeds

        def _store(embeds_2d: torch.Tensor, mel_frames: int, n_stable: int) -> None:
            # detach() so the cache can never keep an autograd graph alive
            # if a caller runs this outside no_grad/inference_mode.
            cache.embeddings = embeds_2d.detach().unsqueeze(0)  # (1, n_tokens, hidden)
            cache.encoded_samples = len(state.audio_buffer)
            cache.encoded_mel_frames = mel_frames
            cache.stable_tokens = n_stable

        # Step 1: mel features for the FULL audio (the original notes this
        # CPU FFT step is cheap compared with the encoder).
        feat_out = asr.processor.feature_extractor(
            [state.audio_buffer],
            sampling_rate=16000,
            padding=True,
            truncation=False,
            return_attention_mask=True,
            return_tensors="pt",
        )
        input_features = feat_out["input_features"].to(asr.device).to(asr.dtype)
        feature_attention_mask = feat_out["attention_mask"].to(asr.device)
        total_mel_frames = int(feature_attention_mask.sum().item())

        # Step 2: token count the encoder should produce for the full audio.
        total_audio_tokens = int(
            _get_feat_extract_output_lengths(torch.tensor(total_mel_frames)).item()
        )

        # Step 3: how much of the audio lies in complete attention windows.
        # 400 is only a fallback for configs lacking the attribute.
        # NOTE(review): earlier comments quoted both 400 and 800 frames
        # for this window; a fallback smaller than the real window would
        # cache tokens that are not actually stable -- confirm.
        audio_cfg = asr.model.thinker.audio_tower.config
        n_window_infer = getattr(audio_cfg, "n_window_infer", 400)
        n_complete_windows = total_mel_frames // n_window_infer

        if n_complete_windows <= 0:
            # Shorter than one window: nothing is stable yet. Encode all of
            # it and remember the result with zero reusable tokens.
            audio_embeds = _drop_batch_dim(
                asr.model.thinker.get_audio_features(
                    input_features, feature_attention_mask=feature_attention_mask,
                )
            )
            _store(audio_embeds, total_mel_frames, 0)
            return audio_embeds

        stable_mel = n_complete_windows * n_window_infer
        stable_tokens = int(
            _get_feat_extract_output_lengths(torch.tensor(stable_mel)).item()
        )

        # Step 4: the cached prefix is reusable only if
        #   - embeddings exist and cover exactly the current stable prefix,
        #   - the buffer has not shrunk since they were produced (samples
        #     and mel frames both monotonic), and
        #   - the cached tensor really contains that many token rows.
        cached = cache.embeddings
        can_reuse = (
            cached is not None
            and 0 < cache.stable_tokens == stable_tokens
            and cache.encoded_samples <= len(state.audio_buffer)
            and cache.encoded_mel_frames <= total_mel_frames
            and cached.shape[-2] >= stable_tokens
        )

        audio_embeds = None
        if can_reuse:
            cached_prefix = _drop_batch_dim(cached)[:stable_tokens]

            # Encode only the tail, starting at the window boundary so the
            # encoder's chunk/window grid stays aligned with a full encode.
            tail_features = input_features[:, :, stable_mel:]
            tail_mel_frames = total_mel_frames - stable_mel
            if tail_mel_frames > 0 and tail_features.shape[2] > 0:
                tail_mask = torch.ones(
                    (1, tail_features.shape[2]),
                    dtype=feature_attention_mask.dtype,
                    device=feature_attention_mask.device,
                )
                tail_embeds = _drop_batch_dim(
                    asr.model.thinker.get_audio_features(
                        tail_features, feature_attention_mask=tail_mask,
                    )
                )
                stitched = torch.cat([cached_prefix, tail_embeds], dim=0)
            else:
                stitched = cached_prefix

            if stitched.shape[0] == total_audio_tokens:
                audio_embeds = stitched
                logger.info(
                    "Audio cache HIT: reused %d/%d tokens, re-encoded %d tail tokens",
                    stable_tokens, total_audio_tokens,
                    total_audio_tokens - stable_tokens,
                )
            else:
                # Prefix + tail disagrees with what a full encode would
                # yield, so the stitched result cannot be trusted.
                logger.warning(
                    "Audio cache stitch mismatch: got %d tokens, expected %d; "
                    "re-encoding full audio",
                    stitched.shape[0], total_audio_tokens,
                )

        if audio_embeds is None:
            # Cache miss, stale cache, or rejected stitch: encode everything.
            audio_embeds = _drop_batch_dim(
                asr.model.thinker.get_audio_features(
                    input_features, feature_attention_mask=feature_attention_mask,
                )
            )
            logger.info(
                "Audio cache MISS: encoded full %d tokens (was: %d stable cached)",
                total_audio_tokens,
                cache.stable_tokens if cache.embeddings is not None else 0,
            )

        # Step 5: remember the result for the next call.
        _store(audio_embeds, total_mel_frames, stable_tokens)

        return audio_embeds  # (n_tokens, hidden_dim)

    except Exception as e:
        # Best-effort path: any failure drops the cache and lets the caller
        # use the processor-based encoding instead.
        logger.warning("Audio cache encoding failed, falling back: %s", e)
        cache.reset()
        return None
def _build_inputs_with_cached_audio(
    self, audio_embeds: torch.Tensor,
) -> Optional[dict]:
    """Build generate() inputs from pre-computed audio embeddings.

    Instead of passing ``input_features`` (which would make the model run
    the audio encoder in its forward pass), this method:

    1. Expands the audio placeholder in the base prompt to one token per
       row of ``audio_embeds`` and tokenizes the result.
    2. Appends the most recent committed context token ids (at most
       ``asr.cfg.max_context_tokens``) and extends the attention mask.
    3. Embeds all token ids via ``get_input_embeddings()``.
    4. Overwrites the embeddings at the audio placeholder positions with
       ``audio_embeds``.

    Args:
        audio_embeds: Audio embeddings, one row per audio token
            (``audio_embeds.shape[0]`` is used as the token count).

    Returns:
        A dict with ``"inputs_embeds"`` and ``"input_ids"`` (the ids are
        returned alongside the embeddings because the caller still needs
        them, e.g. for position computation), plus ``"attention_mask"``
        when the tokenizer produced one. ``None`` if the placeholder count
        does not match the number of embeddings or if construction fails;
        the caller should then fall back to the processor-based path.
    """
    asr = self.asr
    state = self.state
    thinker = asr.model.thinker

    try:
        n_audio_tokens = audio_embeds.shape[0]

        # The processor's ``replace_multimodal_special_tokens`` expands the
        # single <|audio_pad|> in the prompt into ``n_audio_tokens`` copies.
        prompt_with_placeholders = asr.processor.replace_multimodal_special_tokens(
            [self._base_prompt],
            iter([n_audio_tokens]),
        )[0]

        text_ids = asr.processor.tokenizer(
            [prompt_with_placeholders],
            return_tensors="pt",
            padding=True,
        )
        input_ids = text_ids["input_ids"].to(asr.device)
        attention_mask = text_ids.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(asr.device)

        # Append committed context tokens so generation continues from the
        # already-emitted text; capped to bound prompt growth.
        if state.committed_token_ids:
            ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
            ctx_ids = torch.tensor(
                [ctx], dtype=input_ids.dtype, device=input_ids.device,
            )
            input_ids = torch.cat([input_ids, ctx_ids], dim=1)
            if attention_mask is not None:
                ctx_mask = torch.ones_like(ctx_ids)
                attention_mask = torch.cat([attention_mask, ctx_mask], dim=1)

        # Embed every token, then replace the placeholder rows below.
        inputs_embeds = thinker.get_input_embeddings()(input_ids)

        # Locate the audio placeholder positions and make sure there is
        # exactly one slot per embedding row before scattering.
        audio_mask = (input_ids == asr.audio_token_id)
        n_placeholders = audio_mask.sum().item()
        if n_placeholders != n_audio_tokens:
            logger.warning(
                "Audio token mismatch: %d placeholders vs %d embeddings",
                n_placeholders, n_audio_tokens,
            )
            return None

        # masked_scatter fills the True positions of the mask, in order,
        # with consecutive elements of the source tensor.
        audio_embeds_for_scatter = audio_embeds.to(
            inputs_embeds.device, inputs_embeds.dtype,
        )
        expand_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(
            expand_mask, audio_embeds_for_scatter,
        )

        result = {
            "inputs_embeds": inputs_embeds,
            "input_ids": input_ids,  # needed for position_ids/rope computation
        }
        if attention_mask is not None:
            result["attention_mask"] = attention_mask

        return result

    except Exception as e:
        # Best-effort: report and let the caller use the fallback path.
        logger.warning("Failed to build inputs with cached audio: %s", e)
        return None
@torch.inference_mode()
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
"""
@ -405,7 +710,8 @@ class Qwen3SimulStreamingOnlineProcessor:
return [], self.end
# Throttle: skip inference if less than 1s of new audio since last run.
# Each inference re-encodes the full buffer, so calling too often wastes
# Audio embedding caching avoids re-encoding the stable prefix, but
# the decoder still runs a full prefill, so calling too often wastes
# GPU/CPU time and causes lag to spiral.
new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples
min_new_seconds = 1.0
@ -435,11 +741,50 @@ class Qwen3SimulStreamingOnlineProcessor:
during generation. The Qwen3-ASR decoder layer discards attention
weights (hidden_states, _ = self.self_attn(...)), so output_attentions
via generate() would return None. Hooks capture them before discard.
Audio embedding caching: instead of re-encoding the entire audio buffer
through the audio encoder on every call, we cache embeddings for the
stable prefix (complete attention windows) and only re-encode the tail.
This reduces the audio encoding cost from O(n) to O(1) per call for
the stable prefix, changing overall complexity from O(n^2) to O(n).
"""
asr = self.asr
state = self.state
# Prepare inputs
# --- Prepare inputs (with audio embedding cache) ---
#
# Try the cached path first: encode audio incrementally, then build
# inputs_embeds directly. If anything fails, fall back to the original
# processor-based path.
use_cached_path = False
audio_embeds = self._encode_audio_cached()
if audio_embeds is not None:
cached_inputs = self._build_inputs_with_cached_audio(audio_embeds)
if cached_inputs is not None:
input_ids_for_pos = cached_inputs["input_ids"]
inputs_embeds = cached_inputs["inputs_embeds"]
# Build the inputs dict for generate().
# We pass BOTH input_ids and inputs_embeds. The model's forward()
# checks: if inputs_embeds is not None, it skips embedding lookup.
# But input_ids is still needed for:
# - Finding audio placeholder positions (get_placeholder_mask)
# - Computing position_ids / rope_deltas
# We set input_features=None so the model does NOT re-run the
# audio encoder.
inputs = {
"input_ids": input_ids_for_pos,
"inputs_embeds": inputs_embeds,
"attention_mask": cached_inputs.get("attention_mask"),
}
# Remove None values
inputs = {k: v for k, v in inputs.items() if v is not None}
use_cached_path = True
if not use_cached_path:
# Fallback: original processor-based path (full re-encoding)
logger.info("Using fallback (non-cached) audio encoding path")
state.audio_cache.reset()
inputs = asr.processor(
text=[self._base_prompt],
audio=[state.audio_buffer],
@ -448,8 +793,7 @@ class Qwen3SimulStreamingOnlineProcessor:
)
inputs = inputs.to(asr.device).to(asr.dtype)
# Append committed token IDs as context so generate() continues from
# where it left off. Cap at max_context_tokens to prevent prompt growth.
# Append committed token IDs as context
if state.committed_token_ids:
ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
ctx_ids = torch.tensor(
@ -463,11 +807,20 @@ class Qwen3SimulStreamingOnlineProcessor:
[inputs.attention_mask, ctx_mask], dim=1,
)
# prompt_len = number of tokens in the input sequence (for slicing
# generated tokens from the output). generate() constructs output
# starting from input_ids, so use input_ids.shape[1] in both paths.
if use_cached_path:
prompt_len = inputs["input_ids"].shape[1]
else:
prompt_len = inputs.input_ids.shape[1]
# Find audio token range
input_ids = inputs.input_ids[0]
audio_mask = (input_ids == asr.audio_token_id)
# Find audio token range from input_ids
if use_cached_path:
ids_for_audio_range = inputs["input_ids"][0]
else:
ids_for_audio_range = inputs.input_ids[0]
audio_mask = (ids_for_audio_range == asr.audio_token_id)
audio_positions = audio_mask.nonzero(as_tuple=True)[0]
if len(audio_positions) == 0:
return []