Skip to content

Commit

Permalink
feat: support app access token (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mini256 committed Jun 2, 2024
1 parent 5701267 commit 446afb1
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/app/api/auth/[...nextauth]/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { type NextRequest, NextResponse } from 'next/server';

declare module 'next-auth' {
interface User {
role?: 'anonymous' | 'admin';
role?: 'anonymous' | 'user' | 'admin' | 'app' | 'cronjob';
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/app/api/v1/chats/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ const ChatRequest = z.object({
}).array(),
sessionId: z.string().optional(),
name: z.string().optional(),
namespaces: z.string().array().optional(),
index: z.string().optional(),
// TODO: using engine name instead.
engine: z.number().int().optional(),
regenerate: z.boolean().optional(),
messageId: z.coerce.number().int().optional(),
Expand All @@ -30,12 +30,12 @@ const DEFAULT_CHAT_TITLE = 'Untitled';

export const POST = defineHandler({
body: ChatRequest,
auth: 'anonymous',
auth: ['anonymous', 'app'],
}, async ({
body,
auth,
}) => {
const userId = auth.user.id!;
const userId = auth?.user?.id!;
let {
index: indexName = 'default',
messages,
Expand Down
6 changes: 6 additions & 0 deletions src/core/db/schema.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ export type JsonPrimitive = boolean | number | string | null;

export type JsonValue = JsonArray | JsonObject | JsonPrimitive;

export interface AppAccessToken {
app_id: string;
token: string;
}

export interface AuthenticationProvider {
config: Json;
enabled: number;
Expand Down Expand Up @@ -201,6 +206,7 @@ export interface Status {
}

export interface DB {
app_access_token: AppAccessToken;
authentication_provider: AuthenticationProvider;
chat: Chat;
chat_engine: ChatEngine;
Expand Down
11 changes: 11 additions & 0 deletions src/core/repositories/app_access_token.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { getDb } from '@/core/db';

export async function getAppAccessToken (rawAccessToken: string) {
return await getDb()
.selectFrom('app_access_token')
.selectAll()
.where('token', '=', eb =>
eb.fn('SHA2', [eb.val(rawAccessToken), eb.val(256)])
)
.executeTakeFirst();
}
4 changes: 3 additions & 1 deletion src/core/services/llamaindex/retrieving.ts
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,16 @@ export class LlamaindexRetrieverWrapper implements BaseRetriever {
const detailedChunks = await this.retrieveService.extendResultDetails(chunks);

return detailedChunks.map(chunk => {
const { url, ...restMetadata } = chunk.document_metadata;
return {
node: new TextNode({
id_: chunk.document_chunk_node_id,
text: chunk.text,
metadata: {
//// MARK: we don't need the metadata from extractors, they are for embedding.
// ...chunk.metadata,
sourceUri: chunk.document_uri,
sourceUri: chunk.document_uri || url,
...restMetadata,
},
relationships: Object.fromEntries(Object.entries(chunk.relationships).map(([k, v]) => {
return [k, { nodeId: v.chunk_node_id, metadata: v.metadata } satisfies RelatedNodeInfo];
Expand Down
7 changes: 7 additions & 0 deletions src/lib/encrypt.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import crypto from 'crypto';

export function generateSHA256Hash(data: string) {
const hash = crypto.createHash('sha256');
hash.update(data);
return hash.digest('hex');
}
7 changes: 7 additions & 0 deletions src/lib/errors/api_errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ export const AUTH_REQUIRE_AUTHED_ERROR = APIError.new('Require authentication',

export const AUTH_FORBIDDEN_ERROR = APIError.new('Forbidden', 403);

/**
* The third-party application related errors
*/

export const APP_AUTH_REQUIRE_AUTH_TOKEN_ERROR = APIError.new('Require Access Token for the third-party application', 401);

export const APP_AUTH_INVALID_AUTH_TOKEN_ERROR = APIError.new('Invalid Access Token for the third-party application', 401);

/**
* CronJob related errors
Expand Down
136 changes: 101 additions & 35 deletions src/lib/next/handler.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import {auth as authFn} from '@/app/api/auth/[...nextauth]/auth';
import {getAppAccessToken} from "@/core/repositories/app_access_token";
import {generateSHA256Hash} from "@/lib/encrypt";
import {
APIError,
APIError, APP_AUTH_INVALID_AUTH_TOKEN_ERROR, APP_AUTH_REQUIRE_AUTH_TOKEN_ERROR,
AUTH_FORBIDDEN_ERROR,
AUTH_REQUIRE_AUTHED_ERROR,
CRONJOB_INVALID_AUTH_TOKEN_ERROR,
Expand All @@ -9,12 +11,13 @@ import {
import {parseBody, parseParams, parseSearchParams} from '@/lib/next/parse';
import {type RouteProps} from '@/lib/next/types';
import type {Rewrite} from '@/lib/type-utils';
import {DateTime} from "luxon";
import type {Session, User} from 'next-auth';
import {isNextRouterError} from 'next/dist/client/components/is-next-router-error';
import {type NextRequest, NextResponse} from 'next/server';
import {z, ZodError, type ZodObject, type ZodType} from 'zod';

export type AppAuthType = 'anonymous' | 'user' | 'admin' | 'cronjob' | void;
export type AppAuthType = 'anonymous' | 'user' | 'admin' | 'cronjob' | 'app' | void;

export function defineHandler<
ZSearchParams extends ZodObject<any> | void = void,
Expand All @@ -27,7 +30,7 @@ export function defineHandler<
searchParams?: ZSearchParams | void;
params?: ZParams | void;
body?: ZBody | void;
auth?: Auth | void;
auth?: Auth | Auth[] | void;
testOnly?: boolean;
},
handler: (
Expand All @@ -53,38 +56,10 @@ export function defineHandler<
let searchParams: any;
let params: any;
let body: any;
let session: Session | undefined;

let auth: Session | null = null;
if (options.auth) {
auth = await authFn();
if (options.auth === 'cronjob') {
const authHeader = request.headers.get('authorization');
if (!authHeader) {
return CRONJOB_REQUIRE_AUTH_TOKEN_ERROR.toResponse();
}
if (authHeader !== `Bearer ${process.env.CRON_SECRET}`) {
return CRONJOB_INVALID_AUTH_TOKEN_ERROR.toResponse();
}
} else if (options.auth) {
if (!auth?.user) {
return AUTH_REQUIRE_AUTHED_ERROR.toResponse();
}
switch (options.auth) {
case 'anonymous':
// PASS
break
case 'user':
if (auth.user.role === 'anonymous') {
return AUTH_REQUIRE_AUTHED_ERROR.toResponse();
}
break
case 'admin':
if ( auth.user.role !== 'admin') {
return AUTH_FORBIDDEN_ERROR.toResponse();
}
break
}
}
session = await loadSession(request, options.auth);
}

if (options.params) {
Expand All @@ -101,7 +76,7 @@ export function defineHandler<
params,
searchParams,
body,
auth: auth as any,
auth: session as any,
request,
ctx,
});
Expand All @@ -119,7 +94,9 @@ export function defineHandler<
throw e;
}

if (e instanceof ZodError) {
if (e instanceof APIError) {
return e.toResponse();
} else if (e instanceof ZodError) {
return NextResponse.json({
name: 'ZodError',
message: e.message,
Expand All @@ -135,3 +112,92 @@ export function defineHandler<
}
};
}

async function loadSession<Auth extends AppAuthType>(request: NextRequest, auth: Auth | Auth[]): Promise<Session | undefined> {
if (!auth) {
return;
}
if (Array.isArray(auth)) {
for (let i = 0; i < auth.length; i++) {
try {
return await verifyAuth(auth[i], request);
} catch (e) {
if (i !== auth.length - 1) {
continue;
}
if (e instanceof APIError) {
continue;
}
throw e;
}
}
throw AUTH_REQUIRE_AUTHED_ERROR;
} else {
return await verifyAuth(auth, request);
}
}

async function verifyAuth(authType: AppAuthType, request: NextRequest): Promise<Session> {
const session = await authFn();
switch (authType) {
case 'cronjob':
return verifyCronJobAuth(request);
case 'app':
return verifyAppAuth(request);
case 'anonymous':
case 'user':
case 'admin':
return verifyUserAuth(session, authType);
default:
throw new Error('Invalid auth type');
}
}

async function verifyCronJobAuth(request: NextRequest): Promise<Session> {
const accessToken = request.headers.get('authorization')?.replace('Bearer ', '')?.trim();
if (!accessToken) {
throw CRONJOB_REQUIRE_AUTH_TOKEN_ERROR;
}
if (accessToken !== process.env.CRON_SECRET) {
throw CRONJOB_INVALID_AUTH_TOKEN_ERROR;
}
return {
user: {
id: `cronjob-${DateTime.now().toISO()}`,
role: 'cronjob',
},
expires: DateTime.now().plus({ day: 1 }).toISO(),
};
}

async function verifyAppAuth(request: NextRequest): Promise<Session> {
const accessToken = request.headers.get('authorization')?.replace('Bearer ', '')?.trim();
if (!accessToken) {
throw APP_AUTH_REQUIRE_AUTH_TOKEN_ERROR;
}
const aat = await getAppAccessToken(accessToken);
if (!aat || generateSHA256Hash(accessToken) !== aat?.token) {
throw APP_AUTH_INVALID_AUTH_TOKEN_ERROR;
}
return {
user: {
id: aat.app_id,
role: 'app',
},
expires: DateTime.now().plus({ month: 12 }).toISO(),
}
}

async function verifyUserAuth(session: Session | undefined, requiredRole: AppAuthType) {
session = await authFn();
if (!session?.user) {
throw AUTH_REQUIRE_AUTHED_ERROR;
}
if (requiredRole === 'user' && session.user.role === 'anonymous') {
throw AUTH_REQUIRE_AUTHED_ERROR;
}
if (requiredRole === 'admin' && session.user.role !== 'admin') {
throw AUTH_FORBIDDEN_ERROR;
}
return session;
}

0 comments on commit 446afb1

Please sign in to comment.