From dd4899767406d3fb84962f66749f26c2acd50178 Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Sun, 15 Mar 2026 22:31:39 +0100 Subject: [PATCH] qwen3: reuse encoder kv cache --- whisperlivekit/qwen3_simul.py | 403 +++++++++++++++++++++++++++++++--- 1 file changed, 378 insertions(+), 25 deletions(-) diff --git a/whisperlivekit/qwen3_simul.py b/whisperlivekit/qwen3_simul.py index eaac07e..0adec29 100644 --- a/whisperlivekit/qwen3_simul.py +++ b/whisperlivekit/qwen3_simul.py @@ -66,6 +66,64 @@ class Qwen3SimulConfig: max_alignment_heads: int = 20 # Use only top N alignment heads +@dataclass +class _AudioEmbedCache: + """Cached audio encoder outputs for incremental encoding. + + The Qwen3-ASR audio encoder processes mel features in chunks of + ``n_window * 2`` mel frames with windowed self-attention spanning + ``n_window_infer`` mel frames (800 for both 0.6B and 1.7B = 8s of + audio). Within one attention window chunks can attend to each other, + but across windows they cannot. + + We cache the audio embeddings (output of ``get_audio_features``) for + all *complete attention windows* whose input mel frames are unchanged. + When the audio buffer grows, only the tail (last incomplete window + + new audio) is re-encoded through the audio encoder, and the result is + concatenated with the cached prefix. + + When the audio buffer is trimmed from the front (e.g. max_len exceeded), + the cache is fully invalidated and rebuilt on the next call. + """ + # Number of audio *samples* (PCM @ 16kHz) that have been fully encoded. + # This always equals the number of samples whose mel features were fed + # to the audio encoder for the cached embeddings. + encoded_samples: int = 0 + + # Cached audio embeddings tensor, shape (1, n_cached_tokens, hidden_dim). + # None means "no cache yet". + embeddings: Optional[torch.Tensor] = None + + # Number of mel frames that produced ``embeddings``. + # Used to verify cache validity (mel length must match). 
+ encoded_mel_frames: int = 0 + + # Number of audio tokens (embeddings.shape[1]) that are in *complete* + # attention windows and can be safely reused. Tokens from the last + # (potentially incomplete) window are always re-encoded. + stable_tokens: int = 0 + + def trim_front(self, trim_samples: int, sample_rate: int = 16000): + """Invalidate cache entries for audio trimmed from the front. + + Called when ``insert_audio_chunk`` trims the buffer. Rather than + attempting complex partial invalidation (which could introduce subtle + bugs if the mel/token math doesn't align perfectly), we simply reset + the cache. The next ``_encode_audio_cached`` call will rebuild it. + + This is safe because trimming only happens when the audio buffer + exceeds ``audio_max_len`` (~15s), which is relatively infrequent. + """ + self.reset() + + def reset(self): + """Fully invalidate the cache.""" + self.encoded_samples = 0 + self.embeddings = None + self.encoded_mel_frames = 0 + self.stable_tokens = 0 + + @dataclass class Qwen3SimulState: """Per-session mutable state for Qwen3 SimulStreaming.""" @@ -89,6 +147,9 @@ class Qwen3SimulState: detected_language: Optional[str] = None last_infer_samples: int = 0 # audio_buffer length at last inference + # Audio embedding cache for incremental encoding + audio_cache: _AudioEmbedCache = field(default_factory=_AudioEmbedCache) + class Qwen3SimulStreamingASR: """ @@ -346,6 +407,8 @@ class Qwen3SimulStreamingOnlineProcessor: self.state.cumulative_time_offset += trim / self.SAMPLING_RATE # Adjust throttle counter so it tracks position within the trimmed buffer self.state.last_infer_samples = max(0, self.state.last_infer_samples - trim) + # Trim audio embedding cache to match + self.state.audio_cache.trim_front(trim, self.SAMPLING_RATE) def start_silence(self) -> Tuple[List[ASRToken], float]: """Handle start of silence -- flush all pending tokens. 
@@ -388,6 +451,248 @@ class Qwen3SimulStreamingOnlineProcessor: """Get the current unvalidated buffer.""" return Transcript.from_tokens(tokens=self.buffer, sep='') + def _encode_audio_cached(self) -> Optional[torch.Tensor]: + """Encode audio buffer using cached embeddings where possible. + + Returns the full audio embeddings tensor (n_audio_tokens, hidden_dim), + or None if caching is not possible (caller should fall back to the + processor-based path). + + Caching strategy: + - The audio encoder uses windowed attention with window size + ``n_window_infer`` (800 mel frames = 8s of audio for both the + 0.6B and 1.7B models). + - Tokens within one window can attend to each other, but not across + windows. So all tokens in *complete* windows are deterministic + and can be cached. + - We only re-encode the *tail* of the audio (from the last complete + window boundary onward) through the audio encoder. + - The cached prefix embeddings are concatenated with the new tail + embeddings to produce the full result. + """ + asr = self.asr + state = self.state + cache = state.audio_cache + + if len(state.audio_buffer) == 0: + return None + + try: + from qwen_asr.core.transformers_backend.processing_qwen3_asr import ( + _get_feat_extract_output_lengths, + ) + + # Step 1: Compute mel features for the FULL audio. + # WhisperFeatureExtractor is fast (CPU FFT), so this is cheap. + feat_out = asr.processor.feature_extractor( + [state.audio_buffer], + sampling_rate=16000, + padding=True, + truncation=False, + return_attention_mask=True, + return_tensors="pt", + ) + input_features = feat_out["input_features"].to(asr.device).to(asr.dtype) + feature_attention_mask = feat_out["attention_mask"].to(asr.device) + total_mel_frames = feature_attention_mask.sum().item() + + # Step 2: Compute total audio tokens for the full audio. 
+ total_audio_tokens = _get_feat_extract_output_lengths( + torch.tensor(total_mel_frames), + ).item() + + # Step 3: Determine how many tokens are in stable (complete) windows. + # The encoder chunks mel by n_window*2; attention windows span + # n_window_infer frames (800 per the docstring; 400 is only a fallback). + # A window is "complete" if it has a full n_window_infer mel frames. + audio_cfg = asr.model.thinker.audio_tower.config + n_window_infer = getattr(audio_cfg, "n_window_infer", 400) + + # Number of complete attention windows + n_complete_windows = total_mel_frames // n_window_infer + + if n_complete_windows <= 0: + # Audio is shorter than one window -- no stable prefix to cache. + # Encode the full audio and cache it (all unstable). + audio_embeds = asr.model.thinker.get_audio_features( + input_features, feature_attention_mask=feature_attention_mask, + ) + # Update cache for next call + cache.embeddings = audio_embeds.unsqueeze(0) if audio_embeds.dim() == 2 else audio_embeds + cache.encoded_samples = len(state.audio_buffer) + cache.encoded_mel_frames = total_mel_frames + cache.stable_tokens = 0 + return cache.embeddings[0] if cache.embeddings.dim() == 3 else cache.embeddings + + # Mel frames in the stable prefix (all complete windows) + stable_mel = n_complete_windows * n_window_infer + stable_tokens = _get_feat_extract_output_lengths( + torch.tensor(stable_mel), + ).item() + + # Step 4: Check if we have a valid cache for the stable prefix. + # The cache is valid if: + # - We have cached embeddings + # - The number of stable tokens in the cache does not exceed + # the current stable prefix + # - The audio buffer hasn't been modified before the cached region + can_reuse = ( + cache.embeddings is not None + and cache.stable_tokens > 0 + and cache.stable_tokens <= stable_tokens + # The encoded_samples tells us how much audio the cache covers. + # If the current buffer starts with the same audio, the prefix + # embeddings are still valid. 
+ and cache.encoded_samples <= len(state.audio_buffer) + ) + + if can_reuse and cache.stable_tokens == stable_tokens: + # The stable prefix hasn't changed -- reuse cached embeddings + # for the stable part, only re-encode the tail. + cached_prefix = cache.embeddings[0, :stable_tokens] if cache.embeddings.dim() == 3 else cache.embeddings[:stable_tokens] + + # Encode only the tail (from stable_mel onward) + tail_mel_start = stable_mel + tail_features = input_features[:, :, tail_mel_start:] + tail_mel_frames = total_mel_frames - tail_mel_start + if tail_mel_frames > 0: + tail_mask = torch.ones( + (1, tail_features.shape[2]), + dtype=feature_attention_mask.dtype, + device=feature_attention_mask.device, + ) + tail_embeds = asr.model.thinker.get_audio_features( + tail_features, feature_attention_mask=tail_mask, + ) + # get_audio_features returns (n_tokens, hidden_dim) + if tail_embeds.dim() == 3: + tail_embeds = tail_embeds[0] + audio_embeds = torch.cat([cached_prefix, tail_embeds], dim=0) + else: + audio_embeds = cached_prefix + + logger.info( + "Audio cache HIT: reused %d/%d tokens, re-encoded %d tail tokens", + stable_tokens, total_audio_tokens, + total_audio_tokens - stable_tokens, + ) + else: + # Cache miss or stale -- encode the full audio + audio_embeds = asr.model.thinker.get_audio_features( + input_features, feature_attention_mask=feature_attention_mask, + ) + if audio_embeds.dim() == 3: + audio_embeds = audio_embeds[0] + logger.info( + "Audio cache MISS: encoded full %d tokens (was: %d stable cached)", + total_audio_tokens, cache.stable_tokens if cache.embeddings is not None else 0, + ) + + # Step 5: Update cache for next call. 
+ cache.embeddings = audio_embeds.unsqueeze(0) # (1, n_tokens, hidden) + cache.encoded_samples = len(state.audio_buffer) + cache.encoded_mel_frames = total_mel_frames + cache.stable_tokens = stable_tokens + + return audio_embeds # (n_tokens, hidden_dim) + + except Exception as e: + logger.warning("Audio cache encoding failed, falling back: %s", e) + cache.reset() + return None + + def _build_inputs_with_cached_audio( + self, audio_embeds: torch.Tensor, + ) -> Optional[dict]: + """Build generate() inputs using pre-computed audio embeddings. + + Instead of passing ``input_features`` (which triggers the audio encoder + inside the model's forward), we: + 1. Tokenize the text prompt to get ``input_ids`` + 2. Embed the text tokens via ``get_input_embeddings()`` + 3. Replace audio placeholder positions with ``audio_embeds`` + 4. Append committed context token embeddings + 5. Return ``inputs_embeds`` and ``input_ids``, plus ``attention_mask`` + when the tokenizer provides one (no ``input_features``) + + Returns None if the construction fails (caller falls back). + """ + asr = self.asr + state = self.state + thinker = asr.model.thinker + + try: + from qwen_asr.core.transformers_backend.processing_qwen3_asr import ( + _get_feat_extract_output_lengths, + ) + + n_audio_tokens = audio_embeds.shape[0] + + # Tokenize the text prompt with the correct number of audio + # placeholder tokens. The processor's + # ``replace_multimodal_special_tokens`` expands the single + # <|audio_pad|> into the right count. 
+ prompt_with_placeholders = asr.processor.replace_multimodal_special_tokens( + [self._base_prompt], + iter([n_audio_tokens]), + )[0] + text_ids = asr.processor.tokenizer( + [prompt_with_placeholders], + return_tensors="pt", + padding=True, + ) + input_ids = text_ids["input_ids"].to(asr.device) + attention_mask = text_ids.get("attention_mask") + if attention_mask is not None: + attention_mask = attention_mask.to(asr.device) + + # Append committed context tokens + if state.committed_token_ids: + ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:] + ctx_ids = torch.tensor( + [ctx], dtype=input_ids.dtype, device=input_ids.device, + ) + input_ids = torch.cat([input_ids, ctx_ids], dim=1) + if attention_mask is not None: + ctx_mask = torch.ones_like(ctx_ids) + attention_mask = torch.cat([attention_mask, ctx_mask], dim=1) + + # Build inputs_embeds: embed text tokens, then scatter audio embeds + inputs_embeds = thinker.get_input_embeddings()(input_ids) + + # Find audio placeholder positions + audio_mask = (input_ids == asr.audio_token_id) + n_placeholders = audio_mask.sum().item() + + if n_placeholders != n_audio_tokens: + logger.warning( + "Audio token mismatch: %d placeholders vs %d embeddings", + n_placeholders, n_audio_tokens, + ) + return None + + # Scatter audio embeddings into placeholder positions + audio_embeds_for_scatter = audio_embeds.to( + inputs_embeds.device, inputs_embeds.dtype, + ) + expand_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter( + expand_mask, audio_embeds_for_scatter, + ) + + result = { + "inputs_embeds": inputs_embeds, + "input_ids": input_ids, # needed for position_ids/rope computation + } + if attention_mask is not None: + result["attention_mask"] = attention_mask + + return result + + except Exception as e: + logger.warning("Failed to build inputs with cached audio: %s", e) + return None + @torch.inference_mode() def process_iter(self, is_last=False) -> Tuple[List[ASRToken], 
float]: """ @@ -405,7 +710,8 @@ class Qwen3SimulStreamingOnlineProcessor: return [], self.end # Throttle: skip inference if less than 1s of new audio since last run. - # Each inference re-encodes the full buffer, so calling too often wastes + # Audio embedding caching avoids re-encoding the stable prefix, but + # the decoder still runs a full prefill, so calling too often wastes # GPU/CPU time and causes lag to spiral. new_samples = len(self.state.audio_buffer) - self.state.last_infer_samples min_new_seconds = 1.0 @@ -435,39 +741,86 @@ class Qwen3SimulStreamingOnlineProcessor: during generation. The Qwen3-ASR decoder layer discards attention weights (hidden_states, _ = self.self_attn(...)), so output_attentions via generate() would return None. Hooks capture them before discard. + + Audio embedding caching: instead of re-encoding the entire audio buffer + through the audio encoder on every call, we cache embeddings for the + stable prefix (complete attention windows) and only re-encode the tail. + This reduces the audio encoding cost from O(n) to O(1) per call for + the stable prefix, changing overall complexity from O(n^2) to O(n). """ asr = self.asr state = self.state - # Prepare inputs - inputs = asr.processor( - text=[self._base_prompt], - audio=[state.audio_buffer], - return_tensors="pt", - padding=True, - ) - inputs = inputs.to(asr.device).to(asr.dtype) + # --- Prepare inputs (with audio embedding cache) --- + # + # Try the cached path first: encode audio incrementally, then build + # inputs_embeds directly. If anything fails, fall back to the original + # processor-based path. 
+ use_cached_path = False + audio_embeds = self._encode_audio_cached() + if audio_embeds is not None: + cached_inputs = self._build_inputs_with_cached_audio(audio_embeds) + if cached_inputs is not None: + input_ids_for_pos = cached_inputs["input_ids"] + inputs_embeds = cached_inputs["inputs_embeds"] - # Append committed token IDs as context so generate() continues from - # where it left off. Cap at max_context_tokens to prevent prompt growth. - if state.committed_token_ids: - ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:] - ctx_ids = torch.tensor( - [ctx], dtype=inputs.input_ids.dtype, - device=inputs.input_ids.device, + # Build the inputs dict for generate(). + # We pass BOTH input_ids and inputs_embeds. The model's forward() + # checks: if inputs_embeds is not None, it skips embedding lookup. + # But input_ids is still needed for: + # - Finding audio placeholder positions (get_placeholder_mask) + # - Computing position_ids / rope_deltas + # We omit input_features so the model does NOT re-run the + # audio encoder. 
+ inputs = { + "input_ids": input_ids_for_pos, + "inputs_embeds": inputs_embeds, + "attention_mask": cached_inputs.get("attention_mask"), + } + # Remove None values + inputs = {k: v for k, v in inputs.items() if v is not None} + use_cached_path = True + + if not use_cached_path: + # Fallback: original processor-based path (full re-encoding) + logger.info("Using fallback (non-cached) audio encoding path") + state.audio_cache.reset() + inputs = asr.processor( + text=[self._base_prompt], + audio=[state.audio_buffer], + return_tensors="pt", + padding=True, ) - inputs["input_ids"] = torch.cat([inputs.input_ids, ctx_ids], dim=1) - if "attention_mask" in inputs: - ctx_mask = torch.ones_like(ctx_ids) - inputs["attention_mask"] = torch.cat( - [inputs.attention_mask, ctx_mask], dim=1, + inputs = inputs.to(asr.device).to(asr.dtype) + + # Append committed token IDs as context + if state.committed_token_ids: + ctx = state.committed_token_ids[-asr.cfg.max_context_tokens:] + ctx_ids = torch.tensor( + [ctx], dtype=inputs.input_ids.dtype, + device=inputs.input_ids.device, ) + inputs["input_ids"] = torch.cat([inputs.input_ids, ctx_ids], dim=1) + if "attention_mask" in inputs: + ctx_mask = torch.ones_like(ctx_ids) + inputs["attention_mask"] = torch.cat( + [inputs.attention_mask, ctx_mask], dim=1, + ) - prompt_len = inputs.input_ids.shape[1] + # prompt_len = number of tokens in the input sequence (for slicing + # generated tokens from the output). generate() constructs output + # starting from input_ids, so use input_ids.shape[1] in both paths. 
+ if use_cached_path: + prompt_len = inputs["input_ids"].shape[1] + else: + prompt_len = inputs.input_ids.shape[1] - # Find audio token range - input_ids = inputs.input_ids[0] - audio_mask = (input_ids == asr.audio_token_id) + # Find audio token range from input_ids + if use_cached_path: + ids_for_audio_range = inputs["input_ids"][0] + else: + ids_for_audio_range = inputs.input_ids[0] + audio_mask = (ids_for_audio_range == asr.audio_token_id) audio_positions = audio_mask.nonzero(as_tuple=True)[0] if len(audio_positions) == 0: return []