Lint scripts and tests
This commit is contained in:
parent
cf6c49f502
commit
74c4dc791d
5 changed files with 17 additions and 17 deletions
|
|
@ -6,6 +6,7 @@ Produces one JSON file per audio with: [{word, start, end}, ...]
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
AUDIO_DIR = os.path.dirname(os.path.abspath(__file__))
|
AUDIO_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,6 @@ sys.path.insert(0, str(Path(__file__).parent))
|
||||||
from test_backend_offline import (
|
from test_backend_offline import (
|
||||||
AUDIO_TESTS_DIR,
|
AUDIO_TESTS_DIR,
|
||||||
SAMPLE_RATE,
|
SAMPLE_RATE,
|
||||||
TestResult,
|
|
||||||
create_engine,
|
create_engine,
|
||||||
discover_audio_files,
|
discover_audio_files,
|
||||||
download_sample_audio,
|
download_sample_audio,
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import io
|
||||||
import math
|
import math
|
||||||
import pathlib
|
import pathlib
|
||||||
import sys
|
import sys
|
||||||
from typing import List, Optional, Sequence, Tuple, Union
|
from typing import Sequence, Tuple, Union
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -24,7 +24,7 @@ sys.path.insert(0, str(REPO_ROOT))
|
||||||
sys.path.insert(0, str(WHISPER_ROOT))
|
sys.path.insert(0, str(WHISPER_ROOT))
|
||||||
|
|
||||||
from whisper import load_model
|
from whisper import load_model
|
||||||
from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
|
from whisper.audio import log_mel_spectrogram, pad_or_trim
|
||||||
from whisper.tokenizer import get_tokenizer
|
from whisper.tokenizer import get_tokenizer
|
||||||
|
|
||||||
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
|
AudioInput = Union[str, pathlib.Path, np.ndarray, torch.Tensor]
|
||||||
|
|
@ -85,7 +85,7 @@ def _parse_args():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset-config",
|
"--dataset-config",
|
||||||
type=str,
|
type=str,
|
||||||
default="clean"
|
default="clean"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dataset-split",
|
"--dataset-split",
|
||||||
|
|
|
||||||
|
|
@ -1,40 +1,39 @@
|
||||||
"""Copy core files from web directory to Chrome extension directory."""
|
"""Copy core files from web directory to Chrome extension directory."""
|
||||||
|
|
||||||
import os
|
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
def sync_extension_files():
|
def sync_extension_files():
|
||||||
|
|
||||||
web_dir = Path("whisperlivekit/web")
|
web_dir = Path("whisperlivekit/web")
|
||||||
extension_dir = Path("chrome-extension")
|
extension_dir = Path("chrome-extension")
|
||||||
|
|
||||||
files_to_sync = [
|
files_to_sync = [
|
||||||
"live_transcription.html", "live_transcription.js", "live_transcription.css"
|
"live_transcription.html", "live_transcription.js", "live_transcription.css"
|
||||||
]
|
]
|
||||||
|
|
||||||
svg_files = [
|
svg_files = [
|
||||||
"system_mode.svg",
|
"system_mode.svg",
|
||||||
"light_mode.svg",
|
"light_mode.svg",
|
||||||
"dark_mode.svg",
|
"dark_mode.svg",
|
||||||
"settings.svg"
|
"settings.svg"
|
||||||
]
|
]
|
||||||
|
|
||||||
for file in files_to_sync:
|
for file in files_to_sync:
|
||||||
src_path = web_dir / file
|
src_path = web_dir / file
|
||||||
dest_path = extension_dir / file
|
dest_path = extension_dir / file
|
||||||
|
|
||||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
shutil.copy2(src_path, dest_path)
|
shutil.copy2(src_path, dest_path)
|
||||||
|
|
||||||
for svg_file in svg_files:
|
for svg_file in svg_files:
|
||||||
src_path = web_dir / "src" / svg_file
|
src_path = web_dir / "src" / svg_file
|
||||||
dest_path = extension_dir / "web" / "src" / svg_file
|
dest_path = extension_dir / "web" / "src" / svg_file
|
||||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
shutil.copy2(src_path, dest_path)
|
shutil.copy2(src_path, dest_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
sync_extension_files()
|
sync_extension_files()
|
||||||
|
|
|
||||||
|
|
@ -36,8 +36,8 @@ import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from dataclasses import dataclass, asdict, field
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -157,6 +157,7 @@ def create_engine(
|
||||||
):
|
):
|
||||||
"""Create a TranscriptionEngine with the given backend config."""
|
"""Create a TranscriptionEngine with the given backend config."""
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
from whisperlivekit.core import TranscriptionEngine
|
from whisperlivekit.core import TranscriptionEngine
|
||||||
|
|
||||||
# Reset singleton so we get a fresh instance
|
# Reset singleton so we get a fresh instance
|
||||||
|
|
@ -320,7 +321,7 @@ async def run_test(
|
||||||
transcription = _extract_text_from_response(last)
|
transcription = _extract_text_from_response(last)
|
||||||
|
|
||||||
# --- Compute WER and timestamp accuracy against ground truth ---
|
# --- Compute WER and timestamp accuracy against ground truth ---
|
||||||
from whisperlivekit.metrics import compute_wer, compute_timestamp_accuracy
|
from whisperlivekit.metrics import compute_timestamp_accuracy, compute_wer
|
||||||
|
|
||||||
wer_val = None
|
wer_val = None
|
||||||
wer_details = None
|
wer_details = None
|
||||||
|
|
@ -434,7 +435,7 @@ async def run_all_tests(
|
||||||
file_lan = lan
|
file_lan = lan
|
||||||
if "french" in audio_path.name.lower() and lan == "en":
|
if "french" in audio_path.name.lower() and lan == "en":
|
||||||
file_lan = "fr"
|
file_lan = "fr"
|
||||||
logger.info(f"Auto-detected language 'fr' from filename")
|
logger.info("Auto-detected language 'fr' from filename")
|
||||||
|
|
||||||
audio = load_audio(str(audio_path))
|
audio = load_audio(str(audio_path))
|
||||||
|
|
||||||
|
|
@ -495,7 +496,7 @@ def print_benchmark_summary(results: List[TestResult]):
|
||||||
print(f"{'=' * 110}")
|
print(f"{'=' * 110}")
|
||||||
|
|
||||||
# Print transcription excerpts
|
# Print transcription excerpts
|
||||||
print(f"\nTRANSCRIPTIONS:")
|
print("\nTRANSCRIPTIONS:")
|
||||||
print(f"{'-' * 110}")
|
print(f"{'-' * 110}")
|
||||||
for r in results:
|
for r in results:
|
||||||
excerpt = r.transcription[:120] + "..." if len(r.transcription) > 120 else r.transcription
|
excerpt = r.transcription[:120] + "..." if len(r.transcription) > 120 else r.transcription
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue