qwen3: reuse encoder kv cache
This commit is contained in:
parent
f24481dc29
commit
dd48997674
1 changed files with 378 additions and 25 deletions
|
|
@ -66,6 +66,64 @@ class Qwen3SimulConfig:
|
|||
max_alignment_heads: int = 20 # Use only top N alignment heads
|
||||
|
||||
|
||||
@dataclass
|
||||
class _AudioEmbedCache:
|
||||
"""Cached audio encoder outputs for incremental encoding.
|
||||
|
||||
The Qwen3-ASR audio encoder processes mel features in chunks of
|
||||
``n_window * 2`` mel frames with windowed self-attention spanning
|
||||
``n_window_infer`` mel frames (800 for both 0.6B and 1.7B = 8s of
|
||||
audio). Within one attention window chunks can attend to each other,
|
||||
but across windows they cannot.
|
||||
|
||||
We cache the audio embeddings (output of ``get_audio_features``) for
|
||||
all *complete attention windows* whose input mel frames are unchanged.
|
||||
When the audio buffer grows, only the tail (last incomplete window +
|
||||
new audio) is re-encoded through the audio encoder, and the result is
|
||||
concatenated with the cached prefix.
|
||||
|
||||
When the audio buffer is trimmed from the front (e.g. max_len exceeded),
|
||||
the cache is fully invalidated and rebuilt on the next call.
|
||||
"""
|
||||
# Number of audio *samples* (PCM @ 16kHz) that have been fully encoded.
|
||||
# This always equals the number of samples whose mel features were fed
|
||||
# to the audio encoder for the cached embeddings.
|
||||
encoded_samples: int = 0
|
||||
|
||||
# Cached audio embeddings tensor, shape (1, n_cached_tokens, hidden_dim).
|
||||
# None means "no cache yet".
|
||||
embeddings: Optional[torch.Tensor] = None
|
||||
|
||||
# Number of mel frames that produced ``embeddings``.
|
||||
# Used to verify cache validity (mel length must match).
|
||||
encoded_mel_frames: int = 0
|
||||
|
||||
# Number of audio tokens (embeddings.shape[1]) that are in *complete*
|
||||
# attention windows and can be safely reused. Tokens from the last
|
||||
# (potentially incomplete) window are always re-encoded.
|
||||
stable_tokens: int = 0
|
||||
|
||||
def trim_front(self, trim_samples: int, sample_rate: int = 16000):
|
||||
"""Invalidate cache entries for audio trimmed from the front.
|
||||
|
||||
Called when ``insert_audio_chunk`` trims the buffer. Rather than
|
||||
attempting complex partial invalidation (which could introduce subtle
|
||||
bugs if the mel/token math doesn't align perfectly), we simply reset
|
||||
the cache. The next ``_encode_audio_cached`` call will rebuild it.
|
||||
|
||||
This is safe because trimming only happens when the audio buffer
|
||||
exceeds ``audio_max_len`` (~15s), which is relatively infrequent.
|
||||
"""
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Fully invalidate the cache."""
|
||||
self.encoded_samples = 0
|
||||
self.embeddings = None
|
||||
self.encoded_mel_frames = 0
|
||||
self.stable_tokens = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class Qwen3SimulState:
|
||||
"""Per-session mutable state for Qwen3 SimulStreaming."""
|
||||
|
|
@ -89,6 +147,9 @@ class Qwen3SimulState:
|
|||
detected_language: Optional[str] = None
|
||||
last_infer_samples: int = 0 # audio_buffer length at last inference
|
||||
|
||||
# Audio embedding cache for incremental encoding
|
||||
audio_cache: _AudioEmbedCache = field(default_factory=_AudioEmbedCache)
|
||||
|
||||
|
||||
class Qwen3SimulStreamingASR:
|
||||
"""
|
||||
|
|
@ -346,6 +407,8 @@ class Qwen3SimulStreamingOnlineProcessor:
|
|||
self.state.cumulative_time_offset += trim / self.SAMPLING_RATE
|
||||
# Adjust throttle counter so it tracks position within the trimmed buffer
|
||||
self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim)
|
||||
# Trim audio embedding cache to match
|
||||
self.state.audio_cache.trim_front(trim, self.SAMPLING_RATE)
|
||||
|
||||
def start_silence(self) -> Tuple[List[ASRToken], float]:
|
||||
"""Handle start of silence -- flush all pending tokens.
|
||||
|
|
@ -388,6 +451,248 @@ class Qwen3SimulStreamingOnlineProcessor:
|
|||
"""Get the current unvalidated buffer."""
|
||||
return Transcript.from_tokens(tokens=self.buffer, sep='')
|
||||
|
||||
def _encode_audio_cached(self) -> Optional[torch.Tensor]:
|
||||
"""Encode audio buffer using cached embeddings where possible.
|
||||
|
||||
Returns the full audio embeddings tensor (n_audio_tokens, hidden_dim),
|
||||
or None if caching is not possible (caller should fall back to the
|
||||
processor-based path).
|
||||
|
||||
Caching strategy:
|
||||
- The audio encoder uses windowed attention with window size
|
||||
``n_window_infer`` (800 mel frames = 8s of audio for both the
|
||||
0.6B and 1.7B models).
|
||||
- Tokens within one window can attend to each other, but not across
|
||||
windows. So all tokens in *complete* windows are deterministic
|
||||
and can be cached.
|
||||
- We only re-encode the *tail* of the audio (from the last complete
|
||||
window boundary onward) through the audio encoder.
|
||||
- The cached prefix embeddings are concatenated with the new tail
|
||||
embeddings to produce the full result.
|
||||
"""
|
||||
asr = self.asr
|
||||
state = self.state
|
||||
cache = state.audio_cache
|
||||
|
||||
if len(state.audio_buffer) == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
|
||||
_get_feat_extract_output_lengths,
|
||||
)
|
||||
|
||||
# Step 1: Compute mel features for the FULL audio.
|
||||
# WhisperFeatureExtractor is fast (CPU FFT), so this is cheap.
|
||||
feat_out = asr.processor.feature_extractor(
|
||||
[state.audio_buffer],
|
||||
sampling_rate=16000,
|
||||
padding=True,
|
||||
truncation=False,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_features = feat_out["input_features"].to(asr.device).to(asr.dtype)
|
||||
feature_attention_mask = feat_out["attention_mask"].to(asr.device)
|
||||
total_mel_frames = feature_attention_mask.sum().item()
|
||||
|
||||
# Step 2: Compute total audio tokens for the full audio.
|
||||
total_audio_tokens = _get_feat_extract_output_lengths(
|
||||
torch.tensor(total_mel_frames),
|
||||
).item()
|
||||
|
||||
# Step 3: Determine how many tokens are in stable (complete) windows.
|
||||
# The encoder processes mel in chunks of n_window*2 (200 frames).
|
||||
# Attention windows span n_window_infer (400 frames) = 2 chunks.
|
||||
# A window is "complete" if it has a full n_window_infer mel frames.
|
||||
audio_cfg = asr.model.thinker.audio_tower.config
|
||||
n_window_infer = getattr(audio_cfg, "n_window_infer", 400)
|
||||
|
||||
# Number of complete attention windows
|
||||
n_complete_windows = total_mel_frames // n_window_infer
|
||||
|
||||
if n_complete_windows <= 0:
|
||||
# Audio is shorter than one window -- no stable prefix to cache.
|
||||
# Encode the full audio and cache it (all unstable).
|
||||
audio_embeds = asr.model.thinker.get_audio_features(
|
||||
input_features, feature_attention_mask=feature_attention_mask,
|
||||
)
|
||||
# Update cache for next call
|
||||
cache.embeddings = audio_embeds.unsqueeze(0) if audio_embeds.dim() == 2 else audio_embeds
|
||||
cache.encoded_samples = len(state.audio_buffer)
|
||||
cache.encoded_mel_frames = total_mel_frames
|
||||
cache.stable_tokens = 0
|
||||
return cache.embeddings[0] if cache.embeddings.dim() == 3 else cache.embeddings
|
||||
|
||||
# Mel frames in the stable prefix (all complete windows)
|
||||
stable_mel = n_complete_windows * n_window_infer
|
||||
stable_tokens = _get_feat_extract_output_lengths(
|
||||
torch.tensor(stable_mel),
|
||||
).item()
|
||||
|
||||
# Step 4: Check if we have a valid cache for the stable prefix.
|
||||
# The cache is valid if:
|
||||
# - We have cached embeddings
|
||||
# - The number of stable tokens in the cache matches (or exceeds)
|
||||
# the current stable prefix
|
||||
# - The audio buffer hasn't been modified before the cached region
|
||||
can_reuse = (
|
||||
cache.embeddings is not None
|
||||
and cache.stable_tokens > 0
|
||||
and cache.stable_tokens <= stable_tokens
|
||||
# The encoded_samples tells us how much audio the cache covers.
|
||||
# If the current buffer starts with the same audio, the prefix
|
||||
# embeddings are still valid.
|
||||
and cache.encoded_samples <= len(state.audio_buffer)
|
||||
)
|
||||
|
||||
if can_reuse and cache.stable_tokens == stable_tokens:
|
||||
# The stable prefix hasn't changed -- reuse cached embeddings
|
||||
# for the stable part, only re-encode the tail.
|
||||
cached_prefix = cache.embeddings[0, :stable_tokens] if cache.embeddings.dim() == 3 else cache.embeddings[:stable_tokens]
|
||||
|
||||
# Encode only the tail (from stable_mel onward)
|
||||
tail_mel_start = stable_mel
|
||||
tail_features = input_features[:, :, tail_mel_start:]
|
||||
tail_mel_frames = total_mel_frames - tail_mel_start
|
||||
if tail_mel_frames > 0:
|
||||
tail_mask = torch.ones(
|
||||
(1, tail_features.shape[2]),
|
||||
dtype=feature_attention_mask.dtype,
|
||||
device=feature_attention_mask.device,
|
||||
)
|
||||
tail_embeds = asr.model.thinker.get_audio_features(
|
||||
tail_features, feature_attention_mask=tail_mask,
|
||||
)
|
||||
# get_audio_features returns (n_tokens, hidden_dim)
|
||||
if tail_embeds.dim() == 3:
|
||||
tail_embeds = tail_embeds[0]
|
||||
audio_embeds = torch.cat([cached_prefix, tail_embeds], dim=0)
|
||||
else:
|
||||
audio_embeds = cached_prefix
|
||||
|
||||
logger.info(
|
||||
"Audio cache HIT: reused %d/%d tokens, re-encoded %d tail tokens",
|
||||
stable_tokens, total_audio_tokens,
|
||||
total_audio_tokens - stable_tokens,
|
||||
)
|
||||
else:
|
||||
# Cache miss or stale -- encode the full audio
|
||||
audio_embeds = asr.model.thinker.get_audio_features(
|
||||
input_features, feature_attention_mask=feature_attention_mask,
|
||||
)
|
||||
if audio_embeds.dim() == 3:
|
||||
audio_embeds = audio_embeds[0]
|
||||
logger.info(
|
||||
"Audio cache MISS: encoded full %d tokens (was: %d stable cached)",
|
||||
total_audio_tokens, cache.stable_tokens if cache.embeddings is not None else 0,
|
||||
)
|
||||
|
||||
# Step 5: Update cache for next call.
|
||||
cache.embeddings = audio_embeds.unsqueeze(0) # (1, n_tokens, hidden)
|
||||
cache.encoded_samples = len(state.audio_buffer)
|
||||
cache.encoded_mel_frames = total_mel_frames
|
||||
cache.stable_tokens = stable_tokens
|
||||
|
||||
return audio_embeds # (n_tokens, hidden_dim)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Audio cache encoding failed, falling back: %s", e)
|
||||
cache.reset()
|
||||
return None
|
||||
|
||||
def _build_inputs_with_cached_audio(
|
||||
self, audio_embeds: torch.Tensor,
|
||||
) -> Optional[dict]:
|
||||
"""Build generate() inputs using pre-computed audio embeddings.
|
||||
|
||||
Instead of passing ``input_features`` (which triggers the audio encoder
|
||||
inside the model's forward), we:
|
||||
1. Tokenize the text prompt to get ``input_ids``
|
||||
2. Embed the text tokens via ``get_input_embeddings()``
|
||||
3. Replace audio placeholder positions with ``audio_embeds``
|
||||
4. Append committed context token embeddings
|
||||
5. Return ``inputs_embeds`` + ``attention_mask`` (no ``input_ids``,
|
||||
no ``input_features``)
|
||||
|
||||
Returns None if the construction fails (caller falls back).
|
||||
"""
|
||||
asr = self.asr
|
||||
state = self.state
|
||||
thinker = asr.model.thinker
|
||||
|
||||
try:
|
||||
from qwen_asr.core.transformers_backend.processing_qwen3_asr import (
|
||||
_get_feat_extract_output_lengths,
|
||||
)
|
||||
|
||||
n_audio_tokens = audio_embeds.shape[0]
|
||||
|
||||
# Tokenize the text prompt with the correct number of audio
|
||||
# placeholder tokens. The processor's
|
||||
# ``replace_multimodal_special_tokens`` expands the single
|
||||
# <|audio_pad|> into the right count.
|
||||
prompt_with_placeholders = asr.processor.replace_multimodal_special_tokens(
|
||||
[self._base_prompt],
|
||||
iter([n_audio_tokens]),
|
||||
)[0]
|
||||
text_ids = asr.processor.tokenizer(
|
||||
[prompt_with_placeholders],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
input_ids = text_ids["input_ids"].to(asr.device)
|
||||
attention_mask = text_ids.get("attention_mask")
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(asr.device)
|
||||
|
||||
# Append committed context tokens
|
||||
if state.committed_token_ids:
|
||||
ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
|
||||
ctx_ids = torch.tensor(
|
||||
[ctx], dtype=input_ids.dtype, device=input_ids.device,
|
||||
)
|
||||
input_ids = torch.cat([input_ids, ctx_ids], dim=1)
|
||||
if attention_mask is not None:
|
||||
ctx_mask = torch.ones_like(ctx_ids)
|
||||
attention_mask = torch.cat([attention_mask, ctx_mask], dim=1)
|
||||
|
||||
# Build inputs_embeds: embed text tokens, then scatter audio embeds
|
||||
inputs_embeds = thinker.get_input_embeddings()(input_ids)
|
||||
|
||||
# Find audio placeholder positions
|
||||
audio_mask = (input_ids == asr.audio_token_id)
|
||||
n_placeholders = audio_mask.sum().item()
|
||||
|
||||
if n_placeholders != n_audio_tokens:
|
||||
logger.warning(
|
||||
"Audio token mismatch: %d placeholders vs %d embeddings",
|
||||
n_placeholders, n_audio_tokens,
|
||||
)
|
||||
return None
|
||||
|
||||
# Scatter audio embeddings into placeholder positions
|
||||
audio_embeds_for_scatter = audio_embeds.to(
|
||||
inputs_embeds.device, inputs_embeds.dtype,
|
||||
)
|
||||
expand_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(
|
||||
expand_mask, audio_embeds_for_scatter,
|
||||
)
|
||||
|
||||
result = {
|
||||
"inputs_embeds": inputs_embeds,
|
||||
"input_ids": input_ids, # needed for position_ids/rope computation
|
||||
}
|
||||
if attention_mask is not None:
|
||||
result["attention_mask"] = attention_mask
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to build inputs with cached audio: %s", e)
|
||||
return None
|
||||
|
||||
@torch.inference_mode()
|
||||
def process_iter(self, is_last=False) -> Tuple[List[ASRToken], float]:
|
||||
"""
|
||||
|
|
@ -405,7 +710,8 @@ class Qwen3SimulStreamingOnlineProcessor:
|
|||
return [], self.end
|
||||
|
||||
# Throttle: skip inference if less than 1s of new audio since last run.
|
||||
# Each inference re-encodes the full buffer, so calling too often wastes
|
||||
# Audio embedding caching avoids re-encoding the stable prefix, but
|
||||
# the decoder still runs a full prefill, so calling too often wastes
|
||||
# GPU/CPU time and causes lag to spiral.
|
||||
new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples
|
||||
min_new_seconds = 1.0
|
||||
|
|
@ -435,39 +741,86 @@ class Qwen3SimulStreamingOnlineProcessor:
|
|||
during generation. The Qwen3-ASR decoder layer discards attention
|
||||
weights (hidden_states, _ = self.self_attn(...)), so output_attentions
|
||||
via generate() would return None. Hooks capture them before discard.
|
||||
|
||||
Audio embedding caching: instead of re-encoding the entire audio buffer
|
||||
through the audio encoder on every call, we cache embeddings for the
|
||||
stable prefix (complete attention windows) and only re-encode the tail.
|
||||
This reduces the audio encoding cost from O(n) to O(1) per call for
|
||||
the stable prefix, changing overall complexity from O(n^2) to O(n).
|
||||
"""
|
||||
asr = self.asr
|
||||
state = self.state
|
||||
|
||||
# Prepare inputs
|
||||
inputs = asr.processor(
|
||||
text=[self._base_prompt],
|
||||
audio=[state.audio_buffer],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
inputs = inputs.to(asr.device).to(asr.dtype)
|
||||
# --- Prepare inputs (with audio embedding cache) ---
|
||||
#
|
||||
# Try the cached path first: encode audio incrementally, then build
|
||||
# inputs_embeds directly. If anything fails, fall back to the original
|
||||
# processor-based path.
|
||||
use_cached_path = False
|
||||
audio_embeds = self._encode_audio_cached()
|
||||
if audio_embeds is not None:
|
||||
cached_inputs = self._build_inputs_with_cached_audio(audio_embeds)
|
||||
if cached_inputs is not None:
|
||||
input_ids_for_pos = cached_inputs["input_ids"]
|
||||
inputs_embeds = cached_inputs["inputs_embeds"]
|
||||
|
||||
# Append committed token IDs as context so generate() continues from
|
||||
# where it left off. Cap at max_context_tokens to prevent prompt growth.
|
||||
if state.committed_token_ids:
|
||||
ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
|
||||
ctx_ids = torch.tensor(
|
||||
[ctx], dtype=inputs.input_ids.dtype,
|
||||
device=inputs.input_ids.device,
|
||||
# Build the inputs dict for generate().
|
||||
# We pass BOTH input_ids and inputs_embeds. The model's forward()
|
||||
# checks: if inputs_embeds is not None, it skips embedding lookup.
|
||||
# But input_ids is still needed for:
|
||||
# - Finding audio placeholder positions (get_placeholder_mask)
|
||||
# - Computing position_ids / rope_deltas
|
||||
# We set input_features=None so the model does NOT re-run the
|
||||
# audio encoder.
|
||||
inputs = {
|
||||
"input_ids": input_ids_for_pos,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
"attention_mask": cached_inputs.get("attention_mask"),
|
||||
}
|
||||
# Remove None values
|
||||
inputs = {k: v for k, v in inputs.items() if v is not None}
|
||||
use_cached_path = True
|
||||
|
||||
if not use_cached_path:
|
||||
# Fallback: original processor-based path (full re-encoding)
|
||||
logger.info("Using fallback (non-cached) audio encoding path")
|
||||
state.audio_cache.reset()
|
||||
inputs = asr.processor(
|
||||
text=[self._base_prompt],
|
||||
audio=[state.audio_buffer],
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
)
|
||||
inputs["input_ids"] = torch.cat([inputs.input_ids, ctx_ids], dim=1)
|
||||
if "attention_mask" in inputs:
|
||||
ctx_mask = torch.ones_like(ctx_ids)
|
||||
inputs["attention_mask"] = torch.cat(
|
||||
[inputs.attention_mask, ctx_mask], dim=1,
|
||||
inputs = inputs.to(asr.device).to(asr.dtype)
|
||||
|
||||
# Append committed token IDs as context
|
||||
if state.committed_token_ids:
|
||||
ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:]
|
||||
ctx_ids = torch.tensor(
|
||||
[ctx], dtype=inputs.input_ids.dtype,
|
||||
device=inputs.input_ids.device,
|
||||
)
|
||||
inputs["input_ids"] = torch.cat([inputs.input_ids, ctx_ids], dim=1)
|
||||
if "attention_mask" in inputs:
|
||||
ctx_mask = torch.ones_like(ctx_ids)
|
||||
inputs["attention_mask"] = torch.cat(
|
||||
[inputs.attention_mask, ctx_mask], dim=1,
|
||||
)
|
||||
|
||||
prompt_len = inputs.input_ids.shape[1]
|
||||
# prompt_len = number of tokens in the input sequence (for slicing
|
||||
# generated tokens from the output). generate() constructs output
|
||||
# starting from input_ids, so use input_ids.shape[1] in both paths.
|
||||
if use_cached_path:
|
||||
prompt_len = inputs["input_ids"].shape[1]
|
||||
else:
|
||||
prompt_len = inputs.input_ids.shape[1]
|
||||
|
||||
# Find audio token range
|
||||
input_ids = inputs.input_ids[0]
|
||||
audio_mask = (input_ids == asr.audio_token_id)
|
||||
# Find audio token range from input_ids
|
||||
if use_cached_path:
|
||||
ids_for_audio_range = inputs["input_ids"][0]
|
||||
else:
|
||||
ids_for_audio_range = inputs.input_ids[0]
|
||||
audio_mask = (ids_for_audio_range == asr.audio_token_id)
|
||||
audio_positions = audio_mask.nonzero(as_tuple=True)[0]
|
||||
if len(audio_positions) == 0:
|
||||
return []
|
||||
|
|
|
|||
Loading…
Reference in a new issue