Feature/Complete functionality to Protect Chat API route #9

Open · wants to merge 1 commit into main
48 changes: 29 additions & 19 deletions app/api/chat/route.ts
@@ -1,22 +1,24 @@
 import { NextRequest, NextResponse } from 'next/server';
 import { Message as VercelChatMessage, StreamingTextResponse } from 'ai';

-import { createStuffDocumentsChain } from "langchain/chains/combine_documents";
-import { createRetrievalChain } from "langchain/chains/retrieval";
-import { createHistoryAwareRetriever } from "langchain/chains/history_aware_retriever";
+import { createStuffDocumentsChain } from 'langchain/chains/combine_documents';
+import { createRetrievalChain } from 'langchain/chains/retrieval';
+import { createHistoryAwareRetriever } from 'langchain/chains/history_aware_retriever';

-import { HumanMessage, AIMessage, ChatMessage } from "@langchain/core/messages";
+import { HumanMessage, AIMessage, ChatMessage } from '@langchain/core/messages';
 import { ChatTogetherAI } from '@langchain/community/chat_models/togetherai';
-import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts';
+import {
+  ChatPromptTemplate,
+  MessagesPlaceholder,
+} from '@langchain/core/prompts';
 import { PineconeStore } from '@langchain/pinecone';
 import { Document } from '@langchain/core/documents';
 import { RunnableSequence, RunnablePick } from '@langchain/core/runnables';
 import { TogetherAIEmbeddings } from '@langchain/community/embeddings/togetherai';
-import {
-  HttpResponseOutputParser,
-} from 'langchain/output_parsers';
+import { HttpResponseOutputParser } from 'langchain/output_parsers';

 import { Pinecone } from '@pinecone-database/pinecone';
+import { auth } from '@clerk/nextjs';

 export const runtime = 'edge';

@@ -30,17 +32,19 @@ const formatVercelMessages = (message: VercelChatMessage) => {
   } else if (message.role === 'assistant') {
     return new AIMessage(message.content);
   } else {
-    console.warn(`Unknown message type passed: "${message.role}". Falling back to generic message type.`);
+    console.warn(
+      `Unknown message type passed: "${message.role}". Falling back to generic message type.`,
+    );
     return new ChatMessage({ content: message.content, role: message.role });
   }
 };

 const historyAwarePrompt = ChatPromptTemplate.fromMessages([
-  new MessagesPlaceholder("chat_history"),
-  ["user", "{input}"],
+  new MessagesPlaceholder('chat_history'),
+  ['user', '{input}'],
   [
-    "user",
-    "Given the above conversation, generate a concise vector store search query to look up in order to get information relevant to the conversation.",
+    'user',
+    'Given the above conversation, generate a concise vector store search query to look up in order to get information relevant to the conversation.',
   ],
 ]);

@@ -55,9 +59,9 @@ If the question is not related to the context, politely respond that you are tun
 Please return your answer in markdown with clear headings and lists.`;

 const answerPrompt = ChatPromptTemplate.fromMessages([
-  ["system", ANSWER_SYSTEM_TEMPLATE],
-  new MessagesPlaceholder("chat_history"),
-  ["user", "{input}"],
+  ['system', ANSWER_SYSTEM_TEMPLATE],
+  new MessagesPlaceholder('chat_history'),
+  ['user', '{input}'],
 ]);

 /**
@@ -69,10 +73,16 @@ const answerPrompt = ChatPromptTemplate.fromMessages([
  */
 export async function POST(req: NextRequest) {
   try {
+    const { userId } = auth();
+
+    if (!userId) {
+      return new Response('Unauthorized', { status: 401 });
+    }
+
     const body = await req.json();
     const messages = body.messages ?? [];
     if (!messages.length) {
-      throw new Error("No messages provided.");
+      throw new Error('No messages provided.');
     }
     const formattedPreviousMessages = messages
       .slice(0, -1)
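
Note: `auth()` on the edge runtime only resolves a session if Clerk's middleware runs in front of the route. That file is not part of this diff; the following is a minimal sketch of a root-level middleware.ts, assuming `@clerk/nextjs` v4 (the version whose root export provides `auth()`), with an illustrative public-route list:

// middleware.ts — minimal sketch; assumes @clerk/nextjs v4, publicRoutes list is illustrative
import { authMiddleware } from '@clerk/nextjs';

export default authMiddleware({
  // Leave the landing page public; every other route, including /api/chat, requires a session.
  publicRoutes: ['/'],
});

export const config = {
  // Skip static assets and _next internals; always run on API routes.
  matcher: ['/((?!.+\\.[\\w]+$|_next).*)', '/', '/(api|trpc)(.*)'],
};

With that in place, the `userId` check above rejects anonymous callers with 401 before the request body is even parsed.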
@@ -139,8 +149,8 @@ export async function POST(req: NextRequest) {
// "Pick" the answer from the retrieval chain output object and stream it as bytes.
const outputChain = RunnableSequence.from([
conversationalRetrievalChain,
new RunnablePick({ keys: "answer" }),
new HttpResponseOutputParser({ contentType: "text/plain" }),
new RunnablePick({ keys: 'answer' }),
new HttpResponseOutputParser({ contentType: 'text/plain' }),
]);

const stream = await outputChain.stream({
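
With the guard in place, the route's behavior can be verified without touching the LangChain or Pinecone setup. A hypothetical smoke check (endpoint shape taken from this file; fetch semantics are standard):

// Hypothetical check: an unauthenticated POST never reaches the retrieval chain.
const res = await fetch('/api/chat', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ messages: [{ role: 'user', content: 'Hello' }] }),
});

console.log(res.status); // 401 when signed out; 200 with a streamed text/plain answer when signed in

Because the `userId` check sits at the top of the `try` block, a rejected request costs no model or vector-store calls.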