|
8 | 8 | import sys |
9 | 9 | import uuid |
10 | 10 | from contextlib import asynccontextmanager |
| 11 | +from http import HTTPMethod |
11 | 12 | from typing import Any, Dict, List, Optional, TypedDict |
12 | 13 |
|
13 | 14 | import uvicorn |
14 | 15 | from bot import run_bot |
15 | 16 | from dotenv import load_dotenv |
16 | | -from fastapi import BackgroundTasks, FastAPI, Request |
| 17 | +from fastapi import BackgroundTasks, FastAPI, Request, Response |
17 | 18 | from fastapi.responses import RedirectResponse |
18 | 19 | from loguru import logger |
19 | 20 | from pipecat.transports.smallwebrtc.connection import IceServer |
20 | 21 | from pipecat.transports.smallwebrtc.request_handler import ( |
| 22 | + IceCandidate, |
21 | 23 | SmallWebRTCPatchRequest, |
22 | 24 | SmallWebRTCRequest, |
23 | 25 | SmallWebRTCRequestHandler, |
@@ -93,10 +95,50 @@ class StartBotResult(TypedDict, total=False): |
93 | 95 |
|
94 | 96 | result: StartBotResult = {"sessionId": session_id} |
95 | 97 | if request_data.get("enableDefaultIceServers"): |
96 | | - result["iceConfig"] = IceConfig(iceServers=[IceServer(urls="stun:stun.l.google.com:19302")]) |
| 98 | + result["iceConfig"] = IceConfig( |
| 99 | + iceServers=[IceServer(urls=["stun:stun.l.google.com:19302"])] |
| 100 | + ) |
97 | 101 |
|
98 | 102 | return result |
99 | 103 |
|
| 104 | +@app.api_route( |
| 105 | + "/sessions/{session_id}/{path:path}", |
| 106 | + methods=["GET", "POST", "PUT", "PATCH", "DELETE"], |
| 107 | +) |
| 108 | +async def proxy_request( |
| 109 | + session_id: str, path: str, request: Request, background_tasks: BackgroundTasks |
| 110 | +): |
| 111 | + """Mimic Pipecat Cloud's proxy.""" |
| 112 | + active_session = active_sessions.get(session_id) |
| 113 | + if active_session is None: |
| 114 | + return Response(content="Invalid or not-yet-ready session_id", status_code=404) |
| 115 | + |
| 116 | + if path.endswith("api/offer"): |
| 117 | + # Parse the request body and convert to SmallWebRTCRequest |
| 118 | + try: |
| 119 | + request_data = await request.json() |
| 120 | + if request.method == HTTPMethod.POST.value: |
| 121 | + webrtc_request = SmallWebRTCRequest( |
| 122 | + sdp=request_data["sdp"], |
| 123 | + type=request_data["type"], |
| 124 | + pc_id=request_data.get("pc_id"), |
| 125 | + restart_pc=request_data.get("restart_pc"), |
| 126 | + request_data=request_data, |
| 127 | + ) |
| 128 | + return await offer(webrtc_request, background_tasks) |
| 129 | + elif request.method == HTTPMethod.PATCH.value: |
| 130 | + patch_request = SmallWebRTCPatchRequest( |
| 131 | + pc_id=request_data["pc_id"], |
| 132 | + candidates=[IceCandidate(**c) for c in request_data.get("candidates", [])], |
| 133 | + ) |
| 134 | + return await ice_candidate(patch_request) |
| 135 | + except Exception as e: |
| 136 | + logger.error(f"Failed to parse WebRTC request: {e}") |
| 137 | + return Response(content="Invalid WebRTC request", status_code=400) |
| 138 | + |
| 139 | + logger.info(f"Received request for path: {path}") |
| 140 | + return Response(status_code=200) |
| 141 | + |
100 | 142 |
|
101 | 143 | @asynccontextmanager |
102 | 144 | async def lifespan(app: FastAPI): |
|
0 commit comments