|
5 | 5 | # |
6 | 6 |
|
7 | 7 | import argparse |
8 | | -import asyncio |
9 | 8 | import sys |
| 9 | +import uuid |
10 | 10 | from contextlib import asynccontextmanager |
11 | | -from typing import Dict |
| 11 | +from typing import Any, Dict, List, Optional, TypedDict |
12 | 12 |
|
13 | 13 | import uvicorn |
14 | 14 | from bot import run_bot |
15 | 15 | from dotenv import load_dotenv |
16 | | -from fastapi import BackgroundTasks, FastAPI |
| 16 | +from fastapi import BackgroundTasks, FastAPI, Request |
17 | 17 | from fastapi.responses import RedirectResponse |
18 | 18 | from loguru import logger |
19 | | -from pipecat.transports.smallwebrtc.connection import IceServer, SmallWebRTCConnection |
| 19 | +from pipecat.transports.smallwebrtc.connection import IceServer |
| 20 | +from pipecat.transports.smallwebrtc.request_handler import ( |
| 21 | + SmallWebRTCPatchRequest, |
| 22 | + SmallWebRTCRequest, |
| 23 | + SmallWebRTCRequestHandler, |
| 24 | +) |
20 | 25 | from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI |
21 | 26 |
|
22 | 27 | # Load environment variables |
23 | 28 | load_dotenv(override=True) |
24 | 29 |
|
25 | 30 | app = FastAPI() |
26 | 31 |
|
27 | | -# Store connections by pc_id |
28 | | -pcs_map: Dict[str, SmallWebRTCConnection] = {} |
29 | | - |
30 | | -ice_servers = [ |
31 | | - IceServer( |
32 | | - urls="stun:stun.l.google.com:19302", |
33 | | - ) |
34 | | -] |
35 | | - |
36 | 32 | # Mount the frontend at / |
37 | 33 | app.mount("/prebuilt", SmallWebRTCPrebuiltUI) |
38 | 34 |
|
| 35 | +# Initialize the SmallWebRTC request handler |
| 36 | +small_webrtc_handler: SmallWebRTCRequestHandler = SmallWebRTCRequestHandler() |
| 37 | + |
| 38 | +# In-memory store of active sessions: session_id -> session info |
| 39 | +active_sessions: Dict[str, Dict[str, Any]] = {} |
| 40 | + |
39 | 41 |
|
40 | 42 | @app.get("/", include_in_schema=False) |
41 | 43 | async def root_redirect(): |
42 | 44 | return RedirectResponse(url="/prebuilt/") |
43 | 45 |
|
44 | 46 |
|
45 | 47 | @app.post("/api/offer") |
46 | | -async def offer(request: dict, background_tasks: BackgroundTasks): |
47 | | - pc_id = request.get("pc_id") |
48 | | - |
49 | | - if pc_id and pc_id in pcs_map: |
50 | | - pipecat_connection = pcs_map[pc_id] |
51 | | - logger.info(f"Reusing existing connection for pc_id: {pc_id}") |
52 | | - await pipecat_connection.renegotiate( |
53 | | - sdp=request["sdp"], type=request["type"], restart_pc=request.get("restart_pc", False) |
54 | | - ) |
55 | | - else: |
56 | | - pipecat_connection = SmallWebRTCConnection(ice_servers) |
57 | | - await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"]) |
| 48 | +async def offer(request: SmallWebRTCRequest, background_tasks: BackgroundTasks): |
| 49 | + """Handle WebRTC offer requests via SmallWebRTCRequestHandler.""" |
58 | 50 |
|
59 | | - @pipecat_connection.event_handler("closed") |
60 | | - async def handle_disconnected(webrtc_connection: SmallWebRTCConnection): |
61 | | - logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}") |
62 | | - pcs_map.pop(webrtc_connection.pc_id, None) |
| 51 | + # Prepare runner arguments with the callback to run your bot |
| 52 | + async def webrtc_connection_callback(connection): |
| 53 | + background_tasks.add_task(run_bot, connection) |
63 | 54 |
|
64 | | - background_tasks.add_task(run_bot, pipecat_connection) |
| 55 | + # Delegate handling to SmallWebRTCRequestHandler |
| 56 | + answer = await small_webrtc_handler.handle_web_request( |
| 57 | + request=request, |
| 58 | + webrtc_connection_callback=webrtc_connection_callback, |
| 59 | + ) |
| 60 | + return answer |
65 | 61 |
|
66 | | - answer = pipecat_connection.get_answer() |
67 | | - # Updating the peer connection inside the map |
68 | | - pcs_map[answer["pc_id"]] = pipecat_connection |
69 | 62 |
|
70 | | - return answer |
| 63 | +@app.patch("/api/offer") |
| 64 | +async def ice_candidate(request: SmallWebRTCPatchRequest): |
| 65 | + """Handle WebRTC new ice candidate requests.""" |
| 66 | + logger.debug(f"Received patch request: {request}") |
| 67 | + await small_webrtc_handler.handle_patch_request(request) |
| 68 | + return {"status": "success"} |
| 69 | + |
| 70 | + |
| 71 | +@app.post("/start") |
| 72 | +async def rtvi_start(request: Request): |
| 73 | + """Mimic Pipecat Cloud's /start endpoint.""" |
| 74 | + |
| 75 | + class IceConfig(TypedDict): |
| 76 | + iceServers: List[IceServer] |
| 77 | + |
| 78 | + class StartBotResult(TypedDict, total=False): |
| 79 | + sessionId: str |
| 80 | + iceConfig: Optional[IceConfig] |
| 81 | + |
| 82 | + # Parse the request body |
| 83 | + try: |
| 84 | + request_data = await request.json() |
| 85 | + logger.debug(f"Received request: {request_data}") |
| 86 | + except Exception as e: |
| 87 | + logger.error(f"Failed to parse request body: {e}") |
| 88 | + request_data = {} |
| 89 | + |
| 90 | + # Store session info immediately in memory, replicate the behavior expected on Pipecat Cloud |
| 91 | + session_id = str(uuid.uuid4()) |
| 92 | + active_sessions[session_id] = request_data |
| 93 | + |
| 94 | + result: StartBotResult = {"sessionId": session_id} |
| 95 | + if request_data.get("enableDefaultIceServers"): |
| 96 | + result["iceConfig"] = IceConfig(iceServers=[IceServer(urls="stun:stun.l.google.com:19302")]) |
| 97 | + |
| 98 | + return result |
71 | 99 |
|
72 | 100 |
|
73 | 101 | @asynccontextmanager |
74 | 102 | async def lifespan(app: FastAPI): |
75 | 103 | yield # Run app |
76 | | - coros = [pc.disconnect() for pc in pcs_map.values()] |
77 | | - await asyncio.gather(*coros) |
78 | | - pcs_map.clear() |
| 104 | + await small_webrtc_handler.close() |
79 | 105 |
|
80 | 106 |
|
81 | 107 | if __name__ == "__main__": |
|
0 commit comments