From 23e6f2bcba254be7d39cfaed8221e59ec04752c2 Mon Sep 17 00:00:00 2001
From: Kunaaal13
Date: Fri, 16 Feb 2024 13:03:37 +0530
Subject: [PATCH] Added Chat API route protection

---
 app/api/chat/route.ts | 48 ++++++++++++++++++++++++++-----------------
 1 file changed, 29 insertions(+), 19 deletions(-)

diff --git a/app/api/chat/route.ts b/app/api/chat/route.ts
index f6aa946..8b2d663 100644
--- a/app/api/chat/route.ts
+++ b/app/api/chat/route.ts
@@ -1,22 +1,24 @@
 import { NextRequest, NextResponse } from 'next/server';
 import { Message as VercelChatMessage, StreamingTextResponse } from 'ai';
 
-import { createStuffDocumentsChain } from "langchain/chains/combine_documents";
-import { createRetrievalChain } from "langchain/chains/retrieval";
-import { createHistoryAwareRetriever } from "langchain/chains/history_aware_retriever";
+import { createStuffDocumentsChain } from 'langchain/chains/combine_documents';
+import { createRetrievalChain } from 'langchain/chains/retrieval';
+import { createHistoryAwareRetriever } from 'langchain/chains/history_aware_retriever';
 
-import { HumanMessage, AIMessage, ChatMessage } from "@langchain/core/messages";
+import { HumanMessage, AIMessage, ChatMessage } from '@langchain/core/messages';
 import { ChatTogetherAI } from '@langchain/community/chat_models/togetherai';
-import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts';
+import {
+  ChatPromptTemplate,
+  MessagesPlaceholder,
+} from '@langchain/core/prompts';
 import { PineconeStore } from '@langchain/pinecone';
 import { Document } from '@langchain/core/documents';
 import { RunnableSequence, RunnablePick } from '@langchain/core/runnables';
 import { TogetherAIEmbeddings } from '@langchain/community/embeddings/togetherai';
-import {
-  HttpResponseOutputParser,
-} from 'langchain/output_parsers';
+import { HttpResponseOutputParser } from 'langchain/output_parsers';
 import { Pinecone } from '@pinecone-database/pinecone';
+import { auth } from '@clerk/nextjs';
 
 export const runtime = 'edge';
 
 const formatVercelMessages = (message: VercelChatMessage) => {
@@ -30,17 +32,19 @@ const formatVercelMessages = (message: VercelChatMessage) => {
   } else if (message.role === 'assistant') {
     return new AIMessage(message.content);
   } else {
-    console.warn(`Unknown message type passed: "${message.role}". Falling back to generic message type.`);
+    console.warn(
+      `Unknown message type passed: "${message.role}". Falling back to generic message type.`,
+    );
     return new ChatMessage({ content: message.content, role: message.role });
   }
 };
 
 const historyAwarePrompt = ChatPromptTemplate.fromMessages([
-  new MessagesPlaceholder("chat_history"),
-  ["user", "{input}"],
+  new MessagesPlaceholder('chat_history'),
+  ['user', '{input}'],
   [
-    "user",
-    "Given the above conversation, generate a concise vector store search query to look up in order to get information relevant to the conversation.",
+    'user',
+    'Given the above conversation, generate a concise vector store search query to look up in order to get information relevant to the conversation.',
   ],
 ]);
 
@@ -55,9 +59,9 @@ If the question is not related to the context, politely respond that you are tun
 Please return your answer in markdown with clear headings and lists.`;
 
 const answerPrompt = ChatPromptTemplate.fromMessages([
-  ["system", ANSWER_SYSTEM_TEMPLATE],
-  new MessagesPlaceholder("chat_history"),
-  ["user", "{input}"],
+  ['system', ANSWER_SYSTEM_TEMPLATE],
+  new MessagesPlaceholder('chat_history'),
+  ['user', '{input}'],
 ]);
 
 /**
@@ -69,10 +73,16 @@ const answerPrompt = ChatPromptTemplate.fromMessages([
  */
 export async function POST(req: NextRequest) {
   try {
+    const { userId } = auth();
+
+    if (!userId) {
+      return new Response('Unauthorized', { status: 401 });
+    }
+
     const body = await req.json();
     const messages = body.messages ?? [];
     if (!messages.length) {
-      throw new Error("No messages provided.");
+      throw new Error('No messages provided.');
     }
     const formattedPreviousMessages = messages
       .slice(0, -1)
@@ -139,8 +149,8 @@ export async function POST(req: NextRequest) {
     // "Pick" the answer from the retrieval chain output object and stream it as bytes.
     const outputChain = RunnableSequence.from([
       conversationalRetrievalChain,
-      new RunnablePick({ keys: "answer" }),
-      new HttpResponseOutputParser({ contentType: "text/plain" }),
+      new RunnablePick({ keys: 'answer' }),
+      new HttpResponseOutputParser({ contentType: 'text/plain' }),
     ]);
 
     const stream = await outputChain.stream({