Merge pull request #55 from QuentinFuxa/diart_integration_improvements

Diart integration improvements : Correct bugs
This commit is contained in:
Quentin Fuxa 2025-02-23 23:16:10 +01:00 committed by GitHub
commit d4096e7e11
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 407 additions and 415 deletions

View file

@ -1,26 +1,27 @@
import asyncio
import re
import threading
import numpy as np
from diart import SpeakerDiarization
from diart.inference import StreamingInference
from diart.sources import AudioSource
from rx.subject import Subject
import threading
import numpy as np
import asyncio
import re
def extract_number(s):
    """Return the first run of decimal digits found in ``s`` as an int.

    Returns ``None`` when ``s`` contains no digit at all.
    """
    found = re.search(r'\d+', s)
    if found is None:
        return None
    return int(found.group())
def extract_number(s: str) -> "int | None":
    """Return the first run of decimal digits found in ``s`` as an int.

    Args:
        s: Text to scan, e.g. a speaker label such as ``"speaker3"``.

    Returns:
        The integer value of the first digit run, or ``None`` when ``s``
        contains no digit. Callers that do arithmetic on the result must
        handle the ``None`` case themselves.
    """
    m = re.search(r'\d+', s)
    return int(m.group()) if m else None
class WebSocketAudioSource(AudioSource):
"""
Simple custom AudioSource that blocks in read()
until close() is called.
push_audio() is used to inject new PCM chunks.
Custom AudioSource that blocks in read() until close() is called.
Use push_audio() to inject PCM chunks.
"""
def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
super().__init__(uri, sample_rate)
self._close_event = threading.Event()
self._closed = False
self._close_event = threading.Event()
def read(self):
self._close_event.wait()
@ -32,99 +33,59 @@ class WebSocketAudioSource(AudioSource):
self._close_event.set()
def push_audio(self, chunk: np.ndarray):
chunk = np.expand_dims(chunk, axis=0)
if not self._closed:
self.stream.on_next(chunk)
self.stream.on_next(np.expand_dims(chunk, axis=0))
def create_pipeline(SAMPLE_RATE):
    """Build a streaming-inference object fed by a WebSocketAudioSource.

    Args:
        SAMPLE_RATE: Sample rate handed to the WebSocketAudioSource.

    Returns:
        Tuple ``(inference, ws_source)``: the StreamingInference instance and
        the audio source that feeds it.
    """
    pipeline = SpeakerDiarization()
    source = WebSocketAudioSource(
        uri="websocket_source",
        sample_rate=SAMPLE_RATE,
    )
    streaming = StreamingInference(
        pipeline=pipeline,
        source=source,
        do_plot=False,
        show_progress=False,
    )
    return streaming, source
def init_diart(SAMPLE_RATE, diar_instance):
    """Create the diarization pipeline, hook its results and start it.

    Args:
        SAMPLE_RATE: Sample rate forwarded to create_pipeline().
        diar_instance: Object whose ``processed_time`` attribute is advanced
            as results arrive.

    Returns:
        Tuple ``(inference, l_speakers_queue, ws_source)``.
    """
    # Same construction as create_pipeline(); reuse it instead of repeating it.
    inference, ws_source = create_pipeline(SAMPLE_RATE)
    l_speakers_queue = asyncio.Queue()

    def diar_hook(result):
        """Hook called each time Diart processes a chunk.

        ``result`` is ``(annotation, audio)``. For each detected speaker,
        push its segment info to the queue and update ``processed_time``.
        """
        annotation, audio = result
        if annotation._labels:
            for speaker, label in annotation._labels.items():
                segments_beg = label.segments_boundaries_[0]
                segments_end = label.segments_boundaries_[-1]
                if segments_end > diar_instance.processed_time:
                    diar_instance.processed_time = segments_end
                # NOTE(review): create_task needs a running event loop in the
                # calling thread -- confirm which thread invokes the hook.
                asyncio.create_task(
                    l_speakers_queue.put(
                        {"speaker": speaker, "beg": segments_beg, "end": segments_end}
                    )
                )
        else:
            # No speaker detected: still advance the processed time to the
            # end of the audio that was analysed.
            audio_duration = audio.extent.end
            if audio_duration > diar_instance.processed_time:
                diar_instance.processed_time = audio_duration

    inference.attach_hooks(diar_hook)
    # Run the inference callable in the default executor; the returned future
    # was never used, so it is not bound to a name.
    asyncio.get_event_loop().run_in_executor(None, inference)
    return inference, l_speakers_queue, ws_source
class DiartDiarization:
    """Streaming speaker diarization wrapper around a diart pipeline.

    Audio is injected through a WebSocketAudioSource; results reported by the
    inference hook are queued and later matched against transcription chunks.
    """

    def __init__(self, sample_rate: int):
        """Build the pipeline, attach the result hook and start inference.

        Args:
            sample_rate: Sample rate handed to the WebSocketAudioSource.
        """
        # Latest end time reported by the hook.
        self.processed_time = 0
        # Speaker segments drained from the queue by the last diarize() call.
        self.segment_speakers = []
        self.speakers_queue = asyncio.Queue()
        self.pipeline = SpeakerDiarization()
        self.source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
        self.inference = StreamingInference(
            pipeline=self.pipeline,
            source=self.source,
            do_plot=False,
            show_progress=False,
        )
        # Attach the hook function and start inference in the background.
        self.inference.attach_hooks(self._diar_hook)
        asyncio.get_event_loop().run_in_executor(None, self.inference)

    def _diar_hook(self, result):
        """Record the outcome of one processed chunk.

        Args:
            result: ``(annotation, audio)`` pair passed by the inference.
        """
        annotation, audio = result
        if annotation._labels:
            for speaker, label in annotation._labels.items():
                beg = label.segments_boundaries_[0]
                end = label.segments_boundaries_[-1]
                if end > self.processed_time:
                    self.processed_time = end
                # NOTE(review): create_task needs a running event loop in the
                # calling thread -- confirm which thread invokes the hook.
                asyncio.create_task(self.speakers_queue.put({
                    "speaker": speaker,
                    "beg": beg,
                    "end": end,
                }))
        else:
            # No speaker label: still advance processed_time to the end of
            # the analysed audio.
            dur = audio.extent.end
            if dur > self.processed_time:
                self.processed_time = dur

    async def diarize(self, pcm_array: np.ndarray):
        """Push ``pcm_array`` to the source and collect the queued segments.

        ``segment_speakers`` is cleared and refilled with whatever is in the
        queue at that moment.
        """
        self.source.push_audio(pcm_array)
        self.segment_speakers.clear()
        while not self.speakers_queue.empty():
            self.segment_speakers.append(await self.speakers_queue.get())

    def close(self):
        """Close the audio source."""
        # Only ``self.source`` exists on instances built by __init__; the
        # former ``self.ws_source`` attribute is no longer created.
        self.source.close()

    def assign_speakers_to_chunks(self, chunks: list) -> float:
        """Label every chunk that overlaps a known speaker segment.

        Args:
            chunks: Dicts with ``"beg"`` and ``"end"`` keys; the ``"speaker"``
                key is written in place for overlapping chunks.

        Returns:
            The ``"end"`` of the last chunk that received a speaker, or 0 when
            no chunk overlapped any segment.
        """
        end_attributed_speaker = 0
        for chunk in chunks:
            for segment in self.segment_speakers:
                overlaps = segment["end"] > chunk["beg"] and segment["beg"] < chunk["end"]
                if overlaps:
                    # Speaker numbers shown to the user are 1-based.
                    chunk["speaker"] = extract_number(segment["speaker"]) + 1
                    end_attributed_speaker = chunk["end"]
        return end_attributed_speaker

View file

@ -1,339 +1,361 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<title>Audio Transcription</title>
<style>
body {
font-family: 'Inter', sans-serif;
margin: 20px;
text-align: center;
}
#recordButton {
width: 80px;
height: 80px;
font-size: 36px;
border: none;
border-radius: 50%;
background-color: white;
cursor: pointer;
box-shadow: 0 0px 10px rgba(0, 0, 0, 0.2);
transition: background-color 0.3s ease, transform 0.2s ease;
}
#recordButton.recording {
background-color: #ff4d4d;
color: white;
}
#recordButton:active {
transform: scale(0.95);
}
#status {
margin-top: 20px;
font-size: 16px;
color: #333;
}
.settings-container {
display: flex;
justify-content: center;
align-items: center;
gap: 15px;
margin-top: 20px;
}
.settings {
display: flex;
flex-direction: column;
align-items: flex-start;
gap: 5px;
}
#chunkSelector,
#websocketInput {
font-size: 16px;
padding: 5px;
border-radius: 5px;
border: 1px solid #ddd;
background-color: #f9f9f9;
}
#websocketInput {
width: 200px;
}
#chunkSelector:focus,
#websocketInput:focus {
outline: none;
border-color: #007bff;
}
label {
font-size: 14px;
}
/* Speaker-labeled transcript area */
#linesTranscript {
margin: 20px auto;
max-width: 600px;
text-align: left;
font-size: 16px;
}
#linesTranscript p {
margin: 5px 0;
}
#linesTranscript strong {
color: #333;
}
#speaker {
background-color: #dcefff;
border-radius: 30px;
padding: 2px 10px;
font-size: 14px;
}
#timeInfo {
color: #666;
margin-left: 10px;
}
.textcontent {
font-size: 16px;
margin-left: 10px;
padding-left: 10px;
border-left: 2px solid #dcefff;
margin-bottom: 10px;
}
.buffer {
color: rgb(180, 180, 180);
font-style: italic;
margin-left: 4px;
}
.spinner {
display: inline-block;
width: 8px;
height: 8px;
border: 2px solid rgba(0, 0, 0, 0.2);
border-top: 2px solid #333;
border-radius: 50%;
animation: spin 0.6s linear infinite;
vertical-align: middle;
margin-bottom: 2px;
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Audio Transcription</title>
<style>
body {
font-family: 'Inter', sans-serif;
margin: 20px;
text-align: center;
}
#recordButton {
width: 80px;
height: 80px;
font-size: 36px;
border: none;
border-radius: 50%;
background-color: white;
cursor: pointer;
box-shadow: 0 0px 10px rgba(0, 0, 0, 0.2);
transition: background-color 0.3s ease, transform 0.2s ease;
}
#recordButton.recording {
background-color: #ff4d4d;
color: white;
}
#recordButton:active {
transform: scale(0.95);
}
#status {
margin-top: 20px;
font-size: 16px;
color: #333;
}
.settings-container {
display: flex;
justify-content: center;
align-items: center;
gap: 15px;
margin-top: 20px;
}
.settings {
display: flex;
flex-direction: column;
align-items: flex-start;
gap: 5px;
}
#chunkSelector,
#websocketInput {
font-size: 16px;
padding: 5px;
border-radius: 5px;
border: 1px solid #ddd;
background-color: #f9f9f9;
}
#websocketInput {
width: 200px;
}
#chunkSelector:focus,
#websocketInput:focus {
outline: none;
border-color: #007bff;
}
label {
font-size: 14px;
}
/* Speaker-labeled transcript area */
#linesTranscript {
margin: 20px auto;
max-width: 600px;
text-align: left;
font-size: 16px;
}
#linesTranscript p {
margin: 5px 0;
}
#linesTranscript strong {
color: #333;
}
#speaker {
background-color: #dcefff;
border-radius: 30px;
padding: 2px 10px;
font-size: 14px;
}
#timeInfo {
color: #666;
margin-left: 10px;
}
.textcontent {
font-size: 16px;
margin-left: 10px;
padding-left: 10px;
border-left: 2px solid #dcefff;
margin-bottom: 10px;
}
.buffer {
color: rgb(180, 180, 180);
font-style: italic;
margin-left: 4px;
}
.spinner {
display: inline-block;
width: 8px;
height: 8px;
border: 2px solid rgba(0, 0, 0, 0.2);
border-top: 2px solid #333;
border-radius: 50%;
animation: spin 0.6s linear infinite;
vertical-align: middle;
margin-bottom: 2px;
}
@keyframes spin {
to {
transform: rotate(360deg);
to {
transform: rotate(360deg);
}
}
.silence {
color: #666;
background-color: #f3f3f3;
font-size: 13px;
border-radius: 30px;
padding: 2px 10px;
}
.silence {
color: #666;
background-color: #f3f3f3;
font-size: 13px;
border-radius: 30px;
padding: 2px 10px;
}
.loading {
color: #666;
background-color: #eff9ff;
font-size: 14px;
border-radius: 30px;
padding: 2px 10px;
}
</style>
.loading {
color: #666;
background-color: #eff9ff;
font-size: 14px;
border-radius: 30px;
padding: 2px 10px;
}
</style>
</head>
<body>
<div class="settings-container">
<button id="recordButton">🎙️</button>
<div class="settings">
<div>
<label for="chunkSelector">Chunk size (ms):</label>
<select id="chunkSelector">
<option value="500">500 ms</option>
<option value="1000" selected>1000 ms</option>
<option value="2000">2000 ms</option>
<option value="3000">3000 ms</option>
<option value="4000">4000 ms</option>
<option value="5000">5000 ms</option>
</select>
</div>
<div>
<label for="websocketInput">WebSocket URL:</label>
<input id="websocketInput" type="text" value="ws://localhost:8000/asr" />
</div>
<div class="settings-container">
<button id="recordButton">🎙️</button>
<div class="settings">
<div>
<label for="chunkSelector">Chunk size (ms):</label>
<select id="chunkSelector">
<option value="500">500 ms</option>
<option value="1000" selected>1000 ms</option>
<option value="2000">2000 ms</option>
<option value="3000">3000 ms</option>
<option value="4000">4000 ms</option>
<option value="5000">5000 ms</option>
</select>
</div>
<div>
<label for="websocketInput">WebSocket URL:</label>
<input id="websocketInput" type="text" value="ws://localhost:8000/asr" />
</div>
</div>
</div>
</div>
<p id="status"></p>
<p id="status"></p>
<!-- Speaker-labeled transcript -->
<div id="linesTranscript"></div>
<!-- Speaker-labeled transcript -->
<div id="linesTranscript"></div>
<script>
let isRecording = false;
let websocket = null;
let recorder = null;
let chunkDuration = 1000;
let websocketUrl = "ws://localhost:8000/asr";
let userClosing = false;
<script>
let isRecording = false;
let websocket = null;
let recorder = null;
let chunkDuration = 1000;
let websocketUrl = "ws://localhost:8000/asr";
let userClosing = false;
const statusText = document.getElementById("status");
const recordButton = document.getElementById("recordButton");
const chunkSelector = document.getElementById("chunkSelector");
const websocketInput = document.getElementById("websocketInput");
const linesTranscriptDiv = document.getElementById("linesTranscript");
const statusText = document.getElementById("status");
const recordButton = document.getElementById("recordButton");
const chunkSelector = document.getElementById("chunkSelector");
const websocketInput = document.getElementById("websocketInput");
const linesTranscriptDiv = document.getElementById("linesTranscript");
chunkSelector.addEventListener("change", () => {
chunkDuration = parseInt(chunkSelector.value);
});
chunkSelector.addEventListener("change", () => {
chunkDuration = parseInt(chunkSelector.value);
});
websocketInput.addEventListener("change", () => {
const urlValue = websocketInput.value.trim();
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
return;
}
websocketUrl = urlValue;
statusText.textContent = "WebSocket URL updated. Ready to connect.";
});
function setupWebSocket() {
return new Promise((resolve, reject) => {
try {
websocket = new WebSocket(websocketUrl);
} catch (error) {
statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
reject(error);
return;
}
websocket.onopen = () => {
statusText.textContent = "Connected to server.";
resolve();
};
websocket.onclose = () => {
if (userClosing) {
statusText.textContent = "WebSocket closed by user.";
} else {
statusText.textContent =
"Disconnected from the WebSocket server. (Check logs if model is loading.)";
}
userClosing = false;
};
websocket.onerror = () => {
statusText.textContent = "Error connecting to WebSocket.";
reject(new Error("Error connecting to WebSocket"));
};
// Handle messages from server
websocket.onmessage = (event) => {
const data = JSON.parse(event.data);
/*
The server might send:
{
"lines": [
{"speaker": 0, "text": "Hello.", "beg": "00:00", "end": "00:01"},
{"speaker": -2, "text": "Hi, no speaker here.", "beg": "00:01", "end": "00:02"},
{"speaker": -1, "text": "...", "beg": "00:02", "end": "00:03" },
...
],
"buffer": "..."
websocketInput.addEventListener("change", () => {
const urlValue = websocketInput.value.trim();
if (!urlValue.startsWith("ws://") && !urlValue.startsWith("wss://")) {
statusText.textContent = "Invalid WebSocket URL (must start with ws:// or wss://)";
return;
}
*/
const { lines = [], buffer = "" } = data;
renderLinesWithBuffer( lines, buffer);
};
});
}
websocketUrl = urlValue;
statusText.textContent = "WebSocket URL updated. Ready to connect.";
});
function renderLinesWithBuffer(lines, buffer) {
// Clears if no lines
if (!Array.isArray(lines) || lines.length === 0) {
linesTranscriptDiv.innerHTML = "";
return;
}
/**
 * Open the WebSocket at `websocketUrl` and wire its event handlers.
 *
 * @returns {Promise<void>} Resolves once the socket is open; rejects when the
 *   URL is refused by the WebSocket constructor or a connection error fires.
 */
function setupWebSocket() {
    return new Promise((resolve, reject) => {
        try {
            websocket = new WebSocket(websocketUrl);
        } catch (error) {
            statusText.textContent = "Invalid WebSocket URL. Please check and try again.";
            reject(error);
            return;
        }

        websocket.onopen = () => {
            statusText.textContent = "Connected to server.";
            resolve();
        };

        websocket.onclose = () => {
            if (userClosing) {
                statusText.textContent = "WebSocket closed by user.";
            } else {
                statusText.textContent =
                    "Disconnected from the WebSocket server. (Check logs if model is loading.)";
            }
            // Reset so the next close is reported as unexpected unless the user asks for it.
            userClosing = false;
        };

        websocket.onerror = () => {
            statusText.textContent = "Error connecting to WebSocket.";
            reject(new Error("Error connecting to WebSocket"));
        };

        // Handle messages from server. Expected payload:
        //   { "lines": [ {"speaker": ..., "text": ..., "beg": ..., "end": ...}, ... ],
        //     "buffer": "..." }
        websocket.onmessage = (event) => {
            let data;
            try {
                data = JSON.parse(event.data);
            } catch (parseError) {
                // A malformed frame must not break the handler for later messages.
                console.error("Ignoring non-JSON WebSocket message", parseError);
                return;
            }
            const { lines = [], buffer = "" } = data;
            renderLinesWithBuffer(lines, buffer);
        };
    });
}
let speakerLabel = "";
if (item.speaker === -2) {
speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
} else if (item.speaker == -1) {
speakerLabel = `<span class='loading'> <span class="spinner"></span><span id='timeInfo'>${item.diff} second(s) of audio are undergoing diarization</span></span>`;
} else if (item.speaker == -3) {
speakerLabel = `<span id="speaker"><span id='timeInfo'>${timeInfo}</span>`;
} else if (item.speaker !== -1) {
speakerLabel = `<span id="speaker">Speaker ${item.speaker}<span id='timeInfo'>${timeInfo}</span></span>`;
}
/**
 * Render the speaker-labelled transcript and the pending buffer text.
 *
 * @param {Array<Object>} lines  Entries carrying `speaker`, `text`, `beg`, `end` and `diff`.
 * @param {string} buffer        Pending text, shown in the "buffer" style after the last line.
 */
function renderLinesWithBuffer(lines, buffer) {
    // Values come from the server and are injected through innerHTML, so
    // neutralise HTML metacharacters before interpolating them.
    const escapeHtml = (value) => String(value)
        .replace(/&/g, "&amp;")
        .replace(/</g, "&lt;")
        .replace(/>/g, "&gt;")
        .replace(/"/g, "&quot;")
        .replace(/'/g, "&#39;");

    // No lines yet: show only the buffer (if any).
    if (!Array.isArray(lines) || lines.length === 0) {
        if (buffer) {
            linesTranscriptDiv.innerHTML = `<span class="buffer">${escapeHtml(buffer)}</span>`;
        } else {
            linesTranscriptDiv.innerHTML = "";
        }
        return;
    }

    const linesHtml = lines.map((item, idx) => {
        let timeInfo = "";
        if (item.beg !== undefined && item.end !== undefined) {
            timeInfo = ` ${escapeHtml(item.beg)} - ${escapeHtml(item.end)}`;
        }

        let speakerLabel = "";
        if (item.speaker === -2) {
            speakerLabel = `<span class="silence">Silence<span id='timeInfo'>${timeInfo}</span></span>`;
        } else if (item.speaker == 0) {
            // Speaker 0: audio still being diarized.
            speakerLabel = `<span class='loading'><span class="spinner"></span><span id='timeInfo'>${escapeHtml(item.diff)} second(s) of audio are undergoing diarization</span></span>`;
        } else if (item.speaker == -1) {
            // Speaker -1: time range only. The outer span is now closed properly.
            speakerLabel = `<span id="speaker"><span id='timeInfo'>${timeInfo}</span></span>`;
        } else {
            speakerLabel = `<span id="speaker">Speaker ${escapeHtml(item.speaker)}<span id='timeInfo'>${timeInfo}</span></span>`;
        }

        let textContent = item.text ? escapeHtml(item.text) : "";
        // The pending buffer is appended to the last line only.
        if (idx === lines.length - 1 && buffer) {
            textContent += `<span class="buffer">${escapeHtml(buffer)}</span>`;
        }

        return textContent
            ? `<p>${speakerLabel}<br/><div class='textcontent'>${textContent}</div></p>`
            : `<p>${speakerLabel}<br/></p>`;
    }).join("");

    linesTranscriptDiv.innerHTML = linesHtml;
}
return textContent
? `<p>${speakerLabel}<br/><div class='textcontent'>${textContent}</div></p>`
: `<p >${speakerLabel}<br/></p>`;
}).join("");
linesTranscriptDiv.innerHTML = linesHtml;
}
// Request microphone access and start a MediaRecorder that forwards each
// recorded chunk to the WebSocket. Any failure inside the try block is
// reported in the status line as a microphone-access error.
async function startRecording() {
try {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
recorder = new MediaRecorder(stream, { mimeType: "audio/webm" });
recorder.ondataavailable = (e) => {
// Chunks are sent only while the socket is open; otherwise they are dropped.
if (websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(e.data);
}
};
// chunkDuration (ms) is the timeslice between dataavailable events.
recorder.start(chunkDuration);
isRecording = true;
updateUI();
} catch (err) {
statusText.textContent = "Error accessing microphone. Please allow microphone access.";
}
}
// Stop the recorder, close the WebSocket and refresh the UI.
function stopRecording() {
// Flag read by the socket's onclose handler to report a user-initiated close.
userClosing = true;
if (recorder) {
recorder.stop();
recorder = null;
}
isRecording = false;
if (websocket) {
websocket.close();
websocket = null;
}
updateUI();
}
async function toggleRecording() {
if (!isRecording) {
linesTranscriptDiv.innerHTML = "";
try {
await setupWebSocket();
await startRecording();
} catch (err) {
statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
/**
 * Request microphone access and start recording audio/webm chunks that are
 * forwarded to the WebSocket. Any failure is reported in the status line.
 */
async function startRecording() {
    // Forward a recorded chunk only while the socket is open.
    const forwardChunk = (event) => {
        const socketReady = websocket && websocket.readyState === WebSocket.OPEN;
        if (socketReady) {
            websocket.send(event.data);
        }
    };

    try {
        const micStream = await navigator.mediaDevices.getUserMedia({ audio: true });
        recorder = new MediaRecorder(micStream, { mimeType: "audio/webm" });
        recorder.ondataavailable = forwardChunk;
        recorder.start(chunkDuration);
        isRecording = true;
        updateUI();
    } catch (err) {
        statusText.textContent = "Error accessing microphone. Please allow microphone access.";
    }
}
} else {
stopRecording();
}
}
// Reflect the recording state on the record button and in the status line.
function updateUI() {
// Second argument forces the class on when recording and off otherwise.
recordButton.classList.toggle("recording", isRecording);
statusText.textContent = isRecording ? "Recording..." : "Click to start transcription";
}
/**
 * Stop recording: halt the MediaRecorder, release the microphone, close the
 * WebSocket and refresh the UI.
 */
function stopRecording() {
    // Flag read by the socket's onclose handler to report a user-initiated close.
    userClosing = true;

    if (recorder) {
        recorder.stop();
        // Stopping the recorder does not end the capture; stop the tracks so
        // the browser releases the microphone.
        if (recorder.stream) {
            recorder.stream.getTracks().forEach((track) => track.stop());
        }
        recorder = null;
    }
    isRecording = false;

    if (websocket) {
        websocket.close();
        websocket = null;
    }
    updateUI();
}
/**
 * Click handler for the record button: stop when recording, otherwise clear
 * the transcript, connect the WebSocket and start recording.
 */
async function toggleRecording() {
    if (isRecording) {
        stopRecording();
        return;
    }

    linesTranscriptDiv.innerHTML = "";
    try {
        await setupWebSocket();
        await startRecording();
    } catch (err) {
        statusText.textContent = "Could not connect to WebSocket or access mic. Aborted.";
    }
}
/** Reflect the recording state on the record button and in the status line. */
function updateUI() {
    recordButton.classList.toggle("recording", isRecording);
    if (isRecording) {
        statusText.textContent = "Recording...";
    } else {
        statusText.textContent = "Click to start transcription";
    }
}
recordButton.addEventListener("click", toggleRecording);
</script>
</body>
</html>

View file

@ -208,6 +208,7 @@ async def websocket_endpoint(websocket: WebSocket):
"beg": transcription.start,
"end": transcription.end,
"text": transcription.text,
"speaker": -1
})
full_transcription += transcription.text if transcription else ""
buffer = online.get_buffer()
@ -218,23 +219,32 @@ async def websocket_endpoint(websocket: WebSocket):
"beg": time() - beg_loop,
"end": time() - beg_loop + 1,
"text": '',
"speaker": -1
})
sleep(1)
buffer = ''
if args.diarization:
await diarization.diarize(pcm_array)
diarization.assign_speakers_to_chunks(chunk_history)
end_attributed_speaker = diarization.assign_speakers_to_chunks(chunk_history)
current_speaker = 0
current_speaker = -10
lines = []
last_end_diarized = 0
previous_speaker = -1
for ind, ch in enumerate(chunk_history):
speaker = ch.get("speaker", -3)
if speaker == -1 and ind < len(chunk_history) - 1:
continue
elif speaker != current_speaker:
speaker = ch.get("speaker")
if args.diarization:
if speaker == -1 or speaker == 0:
if ch['end'] < end_attributed_speaker:
speaker = previous_speaker
else:
speaker = 0
else:
last_end_diarized = max(ch['end'], last_end_diarized)
if speaker != current_speaker:
lines.append(
{
"speaker": speaker,
@ -245,12 +255,11 @@ async def websocket_endpoint(websocket: WebSocket):
}
)
current_speaker = speaker
elif speaker != -1:
else:
lines[-1]["text"] += ch['text']
lines[-1]["end"] = format_time(ch['end'])
if speaker != -1:
last_end_diarized = max(ch['end'], last_end_diarized)
lines[-1]["diff"] = round(ch['end'] - last_end_diarized, 2)
response = {"lines": lines, "buffer": buffer}
await websocket.send_json(response)