This repository was archived by the owner on Sep 12, 2024. It is now read-only.

Commit 25aa9be (parent: 23cfb48)

feat: impl extensions for langchain

File tree

8 files changed: +1091 -89 lines

package-lock.json

Lines changed: 1008 additions & 78 deletions
Generated lockfile; diff not rendered by default.

package.json

Lines changed: 4 additions & 3 deletions
@@ -47,14 +47,15 @@
     "@tensorflow/tfjs-node": "^4.2.0",
     "@types/node": "^18.15.5",
     "@types/semver": "^7.3.13",
-    "axios": "^1.3.5",
+    "axios": "*",
     "glob": "^9.3.4",
     "rimraf": "^4.4.1",
     "semver": "^7.3.8",
     "tsup": "^6.7.0",
     "tsx": "^3.12.6",
     "typescript": "^5.0.4",
-    "vitest": "^0.29.8"
+    "vitest": "^0.29.8",
+    "langchain": "^0.0.56"
   },
   "dependencies": {
     "@llama-node/cli": "0.0.27"
@@ -68,4 +69,4 @@
     "@llama-node/core": "0.0.27",
     "@llama-node/llama-cpp": "0.0.27"
   }
-}
+}
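Two dependency changes ride along with the feature: langchain ^0.0.56 joins the devDependencies (it is marked external at build time; see tsup.config.ts below), and the axios constraint is loosened from ^1.3.5 to *, which accepts any published version.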

src/extensions/langchain.ts

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+import { AsyncCaller } from "langchain/dist/util/async_caller";
+import { Embeddings, type EmbeddingsParams } from "langchain/embeddings/base";
+import type { LLama } from "..";
+
+export class LLamaEmbeddings implements Embeddings {
+    caller: AsyncCaller;
+    llm: LLama;
+
+    constructor(params: EmbeddingsParams, llm: LLama) {
+        if ((params.maxConcurrency ?? 1) > 1) {
+            console.warn(
+                "maxConcurrency > 1 not officially supported for llama-node, use at your own risk"
+            );
+        }
+        this.caller = new AsyncCaller(params);
+        this.llm = llm;
+    }
+
+    embedDocuments(documents: string[]): Promise<number[][]> {
+        const promises = documents.map((doc) =>
+            this.llm.getDefaultEmbeddings(doc)
+        );
+        return Promise.all(promises);
+    }
+
+    embedQuery(document: string): Promise<number[]> {
+        return this.llm.getDefaultEmbeddings(document);
+    }
+}
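For orientation, a minimal usage sketch (not part of this commit). The import paths, LoadConfig fields, and model filename follow the project's README conventions of this era and are assumptions rather than guarantees; in particular, the llama.cpp backend is assumed to require embedding: true at load time before embeddings can be produced.

import path from "path";
import { LLama } from "llama-node";
import { LLamaCpp, LoadConfig } from "llama-node/dist/llm/llama-cpp.js";
import { LLamaEmbeddings } from "llama-node/dist/extensions/langchain.js";

// Hypothetical model file; substitute any ggml model you have locally.
const model = path.resolve(process.cwd(), "./ggml-vicuna-7b-q4_0.bin");

const llama = new LLama(LLamaCpp);

const config: LoadConfig = {
    path: model,
    enableLogging: true,
    nCtx: 1024,
    nParts: -1,
    seed: 0,
    f16Kv: false,
    logitsAll: false,
    vocabOnly: false,
    useMlock: false,
    embedding: true, // assumed: embeddings must be enabled at load time
    useMmap: true,
};

llama.load(config);

const run = async () => {
    const embeddings = new LLamaEmbeddings({ maxConcurrency: 1 }, llama);
    // embedQuery delegates to llama.getDefaultEmbeddings(text)
    const vector = await embeddings.embedQuery("Hello, llama!");
    console.log(vector.length);
};

run();

Because LLamaEmbeddings satisfies LangChain's Embeddings interface, the instance can be handed to any LangChain component that expects an embedder (vector stores, retrievers, and so on).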

src/index.ts

Lines changed: 14 additions & 5 deletions
@@ -2,11 +2,11 @@ import { CompletionCallback } from "./llm";
 import type { LLM } from "./llm";
 
 export class LLama<
-    Instance,
-    LoadConfig,
-    LLMInferenceArguments,
-    LLMEmbeddingArguments,
-    TokenizeArguments
+    Instance = any,
+    LoadConfig = any,
+    LLMInferenceArguments = any,
+    LLMEmbeddingArguments = any,
+    TokenizeArguments = any
 > {
     llm: LLM<
         Instance,
@@ -48,6 +48,15 @@ export class LLama<
         }
     }
 
+    async getDefaultEmbeddings(text: string): Promise<number[]> {
+        if (!this.llm.getDefaultEmbedding) {
+            console.warn("getDefaultEmbedding not implemented for current LLM");
+            return [];
+        } else {
+            return this.llm.getDefaultEmbedding(text);
+        }
+    }
+
     async tokenize(content: TokenizeArguments): Promise<number[]> {
         if (!this.llm.tokenize) {
             console.warn("tokenize not implemented for current LLM");

src/llm.d.ts

Lines changed: 3 additions & 1 deletion
@@ -7,7 +7,7 @@ export interface LLM<
     LoadConfig,
     LLMInferenceArguments,
     LLMEmbeddingArguments,
-    LLMTokenizeArguments,
+    LLMTokenizeArguments
 > {
     readonly instance: Instance;
 
@@ -20,5 +20,7 @@
 
     getEmbedding?(params: LLMEmbeddingArguments): Promise<number[]>;
 
+    getDefaultEmbedding?(text: string): Promise<number[]>;
+
     tokenize?(content: LLMTokenizeArguments): Promise<number[]>;
 }
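Both embedding hooks are optional, so a backend that implements neither simply triggers the warn-and-return-[] fallback in LLama.getDefaultEmbeddings above; the two bundled backends implement getDefaultEmbedding below.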

src/llm/llama-cpp.ts

Lines changed: 19 additions & 1 deletion
@@ -21,7 +21,13 @@ export interface TokenizeArguments {
 
 export class LLamaCpp
     implements
-        LLM<LLama, LoadConfig, LlamaInvocation, LlamaInvocation, TokenizeArguments>
+        LLM<
+            LLama,
+            LoadConfig,
+            LlamaInvocation,
+            LlamaInvocation,
+            TokenizeArguments
+        >
 {
     instance!: LLama;
 
@@ -82,6 +88,18 @@ export class LLamaCpp
         });
     }
 
+    async getDefaultEmbedding(text: string): Promise<number[]> {
+        return this.getEmbedding({
+            nThreads: 4,
+            nTokPredict: 1024,
+            topK: 40,
+            topP: 0.1,
+            temp: 0.1,
+            repeatPenalty: 1,
+            prompt: text,
+        });
+    }
+
     async tokenize(params: TokenizeArguments): Promise<number[]> {
         return new Promise<number[]>((res, rej) => {
             this.instance.tokenize(params.content, params.nCtx, (response) => {
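getEmbedding takes a full LlamaInvocation, so the sampling-oriented fields (nTokPredict, topK, topP, temp, repeatPenalty) must be supplied even though embedding extraction presumably never samples; the values here read as arbitrary defaults rather than tuned settings.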

src/llm/llama-rs.ts

Lines changed: 12 additions & 0 deletions
@@ -75,6 +75,18 @@ export class LLamaRS
         });
     }
 
+    async getDefaultEmbedding(text: string): Promise<number[]> {
+        return this.getEmbedding({
+            nThreads: 4,
+            numPredict: 1024,
+            topK: 40,
+            topP: 0.1,
+            temp: 0.1,
+            repeatPenalty: 1,
+            prompt: text,
+        });
+    }
+
     async tokenize(params: string): Promise<number[]> {
         return new Promise<number[]>((res) => {
             this.instance.tokenize(params, (response) => {

tsup.config.ts

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 import { defineConfig } from "tsup";
 
 export default defineConfig({
-    entry: ["src/index.ts", "src/llm/*.ts"],
+    entry: ["src/index.ts", "src/llm/*.ts", "src/extensions/*.ts"],
+    external: ["langchain"],
     target: ["es2015"],
     format: ["cjs", "esm"],
     dts: true,
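Listing langchain as external keeps it out of the bundled output: the new src/extensions entry is compiled (with type declarations via dts), but consumers only need langchain installed if they actually import the extension. Note that the extension's deep import of langchain/dist/util/async_caller couples it to langchain's internal build layout, so it is sensitive to langchain version drift.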
