diff --git a/server-cloudflare/models/openai.ts b/server-cloudflare/models/openai.ts index ff15b48..2f0ec61 100644 --- a/server-cloudflare/models/openai.ts +++ b/server-cloudflare/models/openai.ts @@ -1,7 +1,7 @@ import { DurableObject } from "cloudflare:workers"; import type { Env } from "../src/types"; import { createOpusPacketizer } from "../src/opus"; -import { getSystemPrompt } from "../src/prompt"; +import { getFirstMessagePrompt, getSystemPrompt } from "../src/prompt"; const AUDIO_OUTPUT_SAMPLE_RATE = 24_000; @@ -42,15 +42,20 @@ async function transcribePcm(env: Env, audio: Uint8Array): Promise { async function generateOpenAIReply( env: Env, - transcript: string, + transcript: string | null, history: OpenAIChatMessage[], ): Promise { const messages: OpenAIChatMessage[] = [ { role: "system", content: getSystemPrompt(env) }, ...history, - { role: "user", content: transcript }, ]; + if (transcript && transcript.trim().length > 0) { + messages.push({ role: "user", content: transcript }); + } else { + messages.push({ role: "user", content: getFirstMessagePrompt(env) }); + } + const response = await fetch("https://api.openai.com/v1/chat/completions", { method: "POST", headers: { @@ -98,6 +103,7 @@ export class ElatoOpenAiVoiceAgent extends DurableObject { private audioBuffer = new Uint8Array(0); private isGenerating = false; private opusPromise: Promise>> | null = null; + private hasStartedConversation = false; constructor(ctx: DurableObjectState, env: Env) { super(ctx, env); @@ -130,35 +136,8 @@ export class ElatoOpenAiVoiceAgent extends DurableObject { return this.opusPromise; } - private async handleTurn( - websocket: WebSocket, - ) { + private async streamAssistantReply(websocket: WebSocket, reply: string) { const opus = await this.getOpusPacketizer(websocket); - const pcm = this.audioBuffer; - this.resetBufferedAudio(); - - if (pcm.byteLength === 0) { - return; - } - - websocket.send(createServerMessage("AUDIO.COMMITTED")); - - const transcript = await transcribePcm(this.env, pcm); - if (!transcript) { - websocket.send(createServerMessage("RESPONSE.ERROR")); - return; - } - /* Add user transcript DB call here */ - - const session = await this.loadSessionState(); - const reply = await generateOpenAIReply(this.env, transcript, session.history); - session.history.push( - { role: "user", content: transcript }, - { role: "assistant", content: reply }, - ); - await this.saveSessionState(session); - /* Add AI transcript DB call here */ - opus.reset(); websocket.send(createServerMessage("RESPONSE.CREATED")); @@ -184,6 +163,58 @@ export class ElatoOpenAiVoiceAgent extends DurableObject { } } + private async handleTurn( + websocket: WebSocket, + ) { + const pcm = this.audioBuffer; + this.resetBufferedAudio(); + + if (pcm.byteLength === 0) { + return; + } + + websocket.send(createServerMessage("AUDIO.COMMITTED")); + + const transcript = await transcribePcm(this.env, pcm); + if (!transcript) { + websocket.send(createServerMessage("RESPONSE.ERROR")); + return; + } + /* Add user transcript DB call here */ + + const session = await this.loadSessionState(); + const reply = await generateOpenAIReply(this.env, transcript, session.history); + session.history.push( + { role: "user", content: transcript }, + { role: "assistant", content: reply }, + ); + await this.saveSessionState(session); + /* Add AI transcript DB call here */ + await this.streamAssistantReply(websocket, reply); + } + + private async startInitialTurn(websocket: WebSocket) { + if (this.hasStartedConversation || this.isGenerating) { + return; + } + + this.hasStartedConversation = true; + this.isGenerating = true; + + try { + const session = await this.loadSessionState(); + const reply = await generateOpenAIReply(this.env, null, session.history); + session.history.push({ role: "assistant", content: reply }); + await this.saveSessionState(session); + /* Add AI transcript DB call here */ + await this.streamAssistantReply(websocket, reply); + } catch { + websocket.send(createServerMessage("RESPONSE.ERROR")); + } finally { + this.isGenerating = false; + } + } + async fetch(request: Request): Promise { if (request.headers.get("Upgrade") !== "websocket") { return new Response("Expected websocket", { status: 426 }); @@ -194,6 +225,7 @@ export class ElatoOpenAiVoiceAgent extends DurableObject { server.accept(); server.send(JSON.stringify(createAuthMessage())); + void this.startInitialTurn(server); server.addEventListener("message", (event) => { void this.ctx.blockConcurrencyWhile(async () => { diff --git a/server-cloudflare/src/prompt.ts b/server-cloudflare/src/prompt.ts index ccd51d3..3d966a5 100644 --- a/server-cloudflare/src/prompt.ts +++ b/server-cloudflare/src/prompt.ts @@ -3,6 +3,13 @@ import type { Env } from "./types"; const DEFAULT_PROMPT = "You are an Elato voice companion. Keep responses concise, natural to speak aloud, and friendly for a realtime conversation."; +const DEFAULT_FIRST_MESSAGE = + "Start the conversation now with a short spoken greeting. Introduce yourself naturally in one sentence."; + export function getSystemPrompt(env: Env): string { return env.ELATO_OPENAI_SYSTEM_PROMPT?.trim() || DEFAULT_PROMPT; } + +export function getFirstMessagePrompt(env: Env): string { + return env.ELATO_OPENAI_FIRST_MESSAGE?.trim() || DEFAULT_FIRST_MESSAGE; +} diff --git a/server-cloudflare/src/types.ts b/server-cloudflare/src/types.ts index 016c2d3..b2d8b9a 100644 --- a/server-cloudflare/src/types.ts +++ b/server-cloudflare/src/types.ts @@ -4,5 +4,6 @@ export interface Env { OPENAI_API_KEY: string; ELATO_OPENAI_MODEL?: string; ELATO_OPENAI_SYSTEM_PROMPT?: string; + ELATO_OPENAI_FIRST_MESSAGE?: string; ElatoOpenAiVoiceAgent: DurableObjectNamespace; }