Skip to content

Commit

Permalink
fix typings
Browse files Browse the repository at this point in the history
  • Loading branch information
mishig25 committed Jan 12, 2024
1 parent 8ffee0e commit f8b2ec5
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 16 deletions.
3 changes: 2 additions & 1 deletion src/lib/buildPrompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { downloadImgFile } from "./server/files/downloadFile";
import type { Conversation } from "./types/Conversation";
import RAGs from "./server/rag/rag";
import type { RagContext } from "./types/rag";
import type { RagContextWebSearch } from "./types/WebSearch";

export type BuildPromptMessage = Pick<Message, "from" | "content" | "files">;

Expand All @@ -26,7 +27,7 @@ export async function buildPrompt({
}: buildPromptOptions): Promise<string> {
if (ragContext) {
const { type: ragType } = ragContext;
messages = RAGs[ragType].buildPrompt(messages, ragContext);
messages = RAGs[ragType].buildPrompt(messages, ragContext as RagContextWebSearch);
}

// section to handle potential files input
Expand Down
8 changes: 5 additions & 3 deletions src/lib/components/chat/ChatMessage.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import type { Model } from "$lib/types/Model";
import OpenRAGResults from "../OpenRAGResults.svelte";
import type { RAGUpdate } from "$lib/types/MessageUpdate";
import type { RAGUpdate, WebSearchUpdate } from "$lib/types/MessageUpdate";
import { ragTypes } from "$lib/types/rag";
function sanitizeMd(md: string) {
Expand Down Expand Up @@ -117,8 +117,10 @@
$: downloadLink =
message.from === "user" ? `${$page.url.pathname}/message/${message.id}/prompt` : undefined;
$: webSearchSources =
ragUpdates && ragUpdates?.filter(({ messageType }) => messageType === "sources")?.[0]?.sources;
$: webSearchSources = (
ragUpdates &&
(ragUpdates?.filter(({ messageType }) => messageType === "sources")?.[0] as WebSearchUpdate)
)?.sources;
$: if (isCopied) {
setTimeout(() => {
Expand Down
10 changes: 6 additions & 4 deletions src/lib/server/endpoints/openai/endpointOai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { buildPrompt } from "$lib/buildPrompt";
import { OPENAI_API_KEY } from "$env/static/private";
import type { Endpoint } from "../endpoints";
import { format } from "date-fns";
import type { RagContextWebSearch } from "$lib/types/WebSearch";

export const endpointOAIParametersSchema = z.object({
weight: z.number().int().positive().default(1),
Expand Down Expand Up @@ -56,9 +57,10 @@ export async function endpointOai(
} else if (completion === "chat_completions") {
return async ({ conversation }) => {
let messages = conversation.messages;
const webSearch = conversation.messages[conversation.messages.length - 1].webSearch;
const ragContext = conversation.messages[conversation.messages.length - 1].ragContext;

if (webSearch && webSearch.context) {
if (ragContext && ragContext.type === "webSearch") {
const webSearchContext = ragContext as RagContextWebSearch;
const lastMsg = messages.slice(-1)[0];
const messagesWithoutLastUsrMsg = messages.slice(0, -1);
const previousUserMessages = messages.filter((el) => el.from === "user").slice(0, -1);
Expand All @@ -74,9 +76,9 @@ export async function endpointOai(
...messagesWithoutLastUsrMsg,
{
from: "user",
content: `I searched the web using the query: ${webSearch.searchQuery}. Today is ${currentDate} and here are the results:
content: `I searched the web using the query: ${webSearchContext.searchQuery}. Today is ${currentDate} and here are the results:
=====================
${webSearch.context}
${webSearchContext.context}
=====================
${previousQuestions}
Answer the question: ${lastMsg.content}
Expand Down
5 changes: 3 additions & 2 deletions src/lib/types/MessageUpdate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export interface RAGUpdate {
}

export interface WebSearchUpdate extends RAGUpdate {
type: "websearch";
type: "webSearch";
messageType: RAGUpdate["messageType"] | "sources";
sources?: WebSearchSource[];
}
Expand All @@ -51,6 +51,7 @@ export type MessageUpdate =
| FinalAnswer
| TextStreamUpdate
| AgentUpdate
| RAGUpdate
| WebSearchUpdate
| PdfSearchUpdate
| StatusUpdate
| ErrorUpdate;
5 changes: 2 additions & 3 deletions src/routes/conversation/[id]/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
import type { MessageUpdate, RAGUpdate } from "$lib/types/MessageUpdate";
import titleUpdate from "$lib/stores/titleUpdate";
import file2base64 from "$lib/utils/file2base64";
import { PdfUploadStatus, type PdfUpload } from "$lib/types/PdfChat.js";
import { ragTypes } from "$lib/types/rag.js";
import { PdfUploadStatus, type PdfUpload } from "$lib/types/PdfChat";
export let data;
let messages = data.messages;
Expand Down Expand Up @@ -197,7 +196,7 @@
lastMessage.content += update.token;
messages = [...messages];
}
} else if (ragTypes.includes(update.type)) {
} else if (update.type === "webSearch" || update.type === "pdfChat") {
RAGMessages = [...RAGMessages, update];
} else if (update.type === "status") {
if (update.status === "title" && update.message) {
Expand Down
11 changes: 8 additions & 3 deletions src/routes/conversation/[id]/+server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { summarize } from "$lib/server/summarize";
import { uploadImgFile } from "$lib/server/files/uploadFile";
import sizeof from "image-size";
import RAGs from "$lib/server/rag/rag";
import type { RagContext } from "$lib/types/rag";

export async function POST({ request, locals, params, getClientAddress }) {
const id = z.string().parse(params.id);
Expand Down Expand Up @@ -235,12 +236,16 @@ export async function POST({ request, locals, params, getClientAddress }) {
let webSearchResults: RagContextWebSearch | undefined;

if (webSearch) {
webSearchResults = await RAGs["webSearch"].retrieveRagContext(conv, newPrompt, update);
webSearchResults = (await RAGs["webSearch"].retrieveRagContext(
conv,
newPrompt,
update
)) as RagContextWebSearch;
}

messages[messages.length - 1].ragContext = webSearchResults;

let pdfSearchResults: PdfSearch | undefined;
let pdfSearchResults: RagContext | undefined;
const pdfSearch = await collections.files.findOne({ filename: `${convId.toString()}-pdf` });
if (pdfSearch) {
pdfSearchResults = await RAGs["pdfChat"].retrieveRagContext(conv, newPrompt, update);
Expand Down Expand Up @@ -274,7 +279,7 @@ export async function POST({ request, locals, params, getClientAddress }) {
{
from: "assistant",
content: output.token.text.trimStart(),
webSearch: webSearchResults,
ragContext: webSearchResults,
updates: updates,
id: (responseId as Message["id"]) || crypto.randomUUID(),
createdAt: new Date(),
Expand Down

0 comments on commit f8b2ec5

Please sign in to comment.