
Commit dede259

Merge pull request #1115 from vinodkiran/FEATURE/output-parsers
New Feature - Output Parsers
2 parents 130f24c + ec76b3c commit dede259

File tree

27 files changed, +2174 -768 lines

packages/components/nodes/chains/LLMChain/LLMChain.ts

Lines changed: 47 additions & 14 deletions
@@ -3,6 +3,10 @@ import { getBaseClasses, handleEscapeCharacters } from '../../../src/utils'
 import { LLMChain } from 'langchain/chains'
 import { BaseLanguageModel } from 'langchain/base_language'
 import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
+import { BaseOutputParser } from 'langchain/schema/output_parser'
+import { formatResponse, injectOutputParser } from '../../outputparsers/OutputParserHelpers'
+import { BaseLLMOutputParser } from 'langchain/schema/output_parser'
+import { OutputFixingParser } from 'langchain/output_parsers'
 
 class LLMChain_Chains implements INode {
     label: string
@@ -15,11 +19,12 @@ class LLMChain_Chains implements INode {
     description: string
     inputs: INodeParams[]
     outputs: INodeOutputsValue[]
+    outputParser: BaseOutputParser
 
     constructor() {
         this.label = 'LLM Chain'
         this.name = 'llmChain'
-        this.version = 1.0
+        this.version = 3.0
         this.type = 'LLMChain'
         this.icon = 'chain.svg'
         this.category = 'Chains'
@@ -36,6 +41,12 @@ class LLMChain_Chains implements INode {
                 name: 'prompt',
                 type: 'BasePromptTemplate'
             },
+            {
+                label: 'Output Parser',
+                name: 'outputParser',
+                type: 'BaseLLMOutputParser',
+                optional: true
+            },
             {
                 label: 'Chain Name',
                 name: 'chainName',
@@ -63,12 +74,29 @@ class LLMChain_Chains implements INode {
         const prompt = nodeData.inputs?.prompt
         const output = nodeData.outputs?.output as string
         const promptValues = prompt.promptValues as ICommonObject
-
+        const llmOutputParser = nodeData.inputs?.outputParser as BaseOutputParser
+        this.outputParser = llmOutputParser
+        if (llmOutputParser) {
+            let autoFix = (llmOutputParser as any).autoFix
+            if (autoFix === true) {
+                this.outputParser = OutputFixingParser.fromLLM(model, llmOutputParser)
+            }
+        }
         if (output === this.name) {
-            const chain = new LLMChain({ llm: model, prompt, verbose: process.env.DEBUG === 'true' ? true : false })
+            const chain = new LLMChain({
+                llm: model,
+                outputParser: this.outputParser as BaseLLMOutputParser<string | object>,
+                prompt,
+                verbose: process.env.DEBUG === 'true'
+            })
             return chain
         } else if (output === 'outputPrediction') {
-            const chain = new LLMChain({ llm: model, prompt, verbose: process.env.DEBUG === 'true' ? true : false })
+            const chain = new LLMChain({
+                llm: model,
+                outputParser: this.outputParser as BaseLLMOutputParser<string | object>,
+                prompt,
+                verbose: process.env.DEBUG === 'true'
+            })
             const inputVariables = chain.prompt.inputVariables as string[] // ["product"]
             const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData)
             // eslint-disable-next-line no-console
@@ -84,10 +112,15 @@ class LLMChain_Chains implements INode {
         }
     }
 
-    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
+    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
         const inputVariables = nodeData.instance.prompt.inputVariables as string[] // ["product"]
         const chain = nodeData.instance as LLMChain
-        const promptValues = nodeData.inputs?.prompt.promptValues as ICommonObject
+        let promptValues: ICommonObject | undefined = nodeData.inputs?.prompt.promptValues as ICommonObject
+        const outputParser = nodeData.inputs?.outputParser as BaseOutputParser
+        if (!this.outputParser && outputParser) {
+            this.outputParser = outputParser
+        }
+        promptValues = injectOutputParser(this.outputParser, chain, promptValues)
         const res = await runPrediction(inputVariables, chain, input, promptValues, options, nodeData)
         // eslint-disable-next-line no-console
         console.log('\x1b[93m\x1b[1m\n*****FINAL RESULT*****\n\x1b[0m\x1b[0m')
@@ -99,9 +132,9 @@ class LLMChain_Chains implements INode {
 
 const runPrediction = async (
     inputVariables: string[],
-    chain: LLMChain,
+    chain: LLMChain<string | object>,
     input: string,
-    promptValuesRaw: ICommonObject,
+    promptValuesRaw: ICommonObject | undefined,
     options: ICommonObject,
     nodeData: INodeData
 ) => {
@@ -135,10 +168,10 @@ const runPrediction = async (
         if (isStreaming) {
             const handler = new CustomChainHandler(socketIO, socketIOClientId)
             const res = await chain.call(options, [loggerHandler, handler, ...callbacks])
-            return res?.text
+            return formatResponse(res?.text)
         } else {
             const res = await chain.call(options, [loggerHandler, ...callbacks])
-            return res?.text
+            return formatResponse(res?.text)
         }
     } else if (seen.length === 1) {
         // If one inputVariable is not specify, use input (user's question) as value
@@ -151,10 +184,10 @@ const runPrediction = async (
         if (isStreaming) {
             const handler = new CustomChainHandler(socketIO, socketIOClientId)
             const res = await chain.call(options, [loggerHandler, handler, ...callbacks])
-            return res?.text
+            return formatResponse(res?.text)
         } else {
             const res = await chain.call(options, [loggerHandler, ...callbacks])
-            return res?.text
+            return formatResponse(res?.text)
         }
     } else {
         throw new Error(`Please provide Prompt Values for: ${seen.join(', ')}`)
@@ -163,10 +196,10 @@ const runPrediction = async (
        if (isStreaming) {
            const handler = new CustomChainHandler(socketIO, socketIOClientId)
            const res = await chain.run(input, [loggerHandler, handler, ...callbacks])
-           return res
+           return formatResponse(res)
        } else {
            const res = await chain.run(input, [loggerHandler, ...callbacks])
-           return res
+           return formatResponse(res)
        }
    }
}
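For orientation (not part of the diff above): when the Autofix option on a connected parser node is set, the LLM Chain wraps that parser in LangChain's OutputFixingParser, which re-asks the model to repair a completion that the underlying parser cannot parse. A minimal standalone sketch of that wiring, assuming langchain is installed and an OpenAI API key is configured; the model settings and sample text are arbitrary:

// Illustrative sketch only — not part of this commit.
import { ChatOpenAI } from 'langchain/chat_models/openai'
import { CommaSeparatedListOutputParser, OutputFixingParser } from 'langchain/output_parsers'

const main = async () => {
    const model = new ChatOpenAI({ temperature: 0 })
    const baseParser = new CommaSeparatedListOutputParser()

    // Equivalent to ticking "Autofix" on the parser node: if baseParser.parse()
    // throws on a malformed completion, the fixing parser asks the model to
    // repair it and parses again; otherwise the parsed value is returned as-is.
    const fixingParser = OutputFixingParser.fromLLM(model, baseParser)

    const parsed = await fixingParser.parse('red, green, blue')
    console.log(parsed) // ['red', 'green', 'blue']
}

main()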
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+import { getBaseClasses, INode, INodeData, INodeParams } from '../../../src'
+import { BaseOutputParser } from 'langchain/schema/output_parser'
+import { CommaSeparatedListOutputParser } from 'langchain/output_parsers'
+import { CATEGORY } from '../OutputParserHelpers'
+
+class CSVListOutputParser implements INode {
+    label: string
+    name: string
+    version: number
+    description: string
+    type: string
+    icon: string
+    category: string
+    baseClasses: string[]
+    inputs: INodeParams[]
+    credential: INodeParams
+
+    constructor() {
+        this.label = 'CSV Output Parser'
+        this.name = 'csvOutputParser'
+        this.version = 1.0
+        this.type = 'CSVListOutputParser'
+        this.description = 'Parse the output of an LLM call as a comma-separated list of values'
+        this.icon = 'csv.png'
+        this.category = CATEGORY
+        this.baseClasses = [this.type, ...getBaseClasses(BaseOutputParser)]
+        this.inputs = [
+            {
+                label: 'Autofix',
+                name: 'autofixParser',
+                type: 'boolean',
+                optional: true,
+                description: 'In the event that the first call fails, will make another call to the model to fix any errors.'
+            }
+        ]
+    }
+
+    async init(nodeData: INodeData): Promise<any> {
+        const autoFix = nodeData.inputs?.autofixParser as boolean
+
+        const commaSeparatedListOutputParser = new CommaSeparatedListOutputParser()
+        Object.defineProperty(commaSeparatedListOutputParser, 'autoFix', {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: autoFix
+        })
+        return commaSeparatedListOutputParser
+    }
+}
+
+module.exports = { nodeClass: CSVListOutputParser }
Binary file added (8.3 KB) — preview not included.
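For reference (not part of the commit): the node above returns a LangChain CommaSeparatedListOutputParser with an extra autoFix flag attached, so downstream it behaves roughly as in this sketch; the sample completion text is invented:

// Illustrative sketch only — not part of this commit.
import { CommaSeparatedListOutputParser } from 'langchain/output_parsers'

const parser = new CommaSeparatedListOutputParser()

// These instructions are what injectOutputParser() later appends to the prompt
// through the {format_instructions} partial variable.
console.log(parser.getFormatInstructions())

// Parsing a model completion yields a string[].
parser.parse('red, green, blue').then((items) => {
    console.log(items) // ['red', 'green', 'blue']
})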
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+import { getBaseClasses, INode, INodeData, INodeParams } from '../../../src'
+import { BaseOutputParser } from 'langchain/schema/output_parser'
+import { CustomListOutputParser as LangchainCustomListOutputParser } from 'langchain/output_parsers'
+import { CATEGORY } from '../OutputParserHelpers'
+
+class CustomListOutputParser implements INode {
+    label: string
+    name: string
+    version: number
+    description: string
+    type: string
+    icon: string
+    category: string
+    baseClasses: string[]
+    inputs: INodeParams[]
+    credential: INodeParams
+
+    constructor() {
+        this.label = 'Custom List Output Parser'
+        this.name = 'customListOutputParser'
+        this.version = 1.0
+        this.type = 'CustomListOutputParser'
+        this.description = 'Parse the output of an LLM call as a list of values.'
+        this.icon = 'list.png'
+        this.category = CATEGORY
+        this.baseClasses = [this.type, ...getBaseClasses(BaseOutputParser)]
+        this.inputs = [
+            {
+                label: 'Length',
+                name: 'length',
+                type: 'number',
+                default: 5,
+                step: 1,
+                description: 'Number of values to return'
+            },
+            {
+                label: 'Separator',
+                name: 'separator',
+                type: 'string',
+                description: 'Separator between values',
+                default: ','
+            },
+            {
+                label: 'Autofix',
+                name: 'autofixParser',
+                type: 'boolean',
+                optional: true,
+                description: 'In the event that the first call fails, will make another call to the model to fix any errors.'
+            }
+        ]
+    }
+
+    async init(nodeData: INodeData): Promise<any> {
+        const separator = nodeData.inputs?.separator as string
+        const lengthStr = nodeData.inputs?.length as string
+        const autoFix = nodeData.inputs?.autofixParser as boolean
+        let length = 5
+        if (lengthStr) length = parseInt(lengthStr, 10)
+
+        const parser = new LangchainCustomListOutputParser({ length: length, separator: separator })
+        Object.defineProperty(parser, 'autoFix', {
+            enumerable: true,
+            configurable: true,
+            writable: true,
+            value: autoFix
+        })
+        return parser
+    }
+}
+
+module.exports = { nodeClass: CustomListOutputParser }
Binary file added (4.88 KB) — preview not included.
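Similarly, a short sketch of the LangChain CustomListOutputParser this node instantiates (not part of the commit). The node defaults to 5 values separated by ','; the 3-value, ';'-separated example below is invented for illustration:

// Illustrative sketch only — not part of this commit.
import { CustomListOutputParser } from 'langchain/output_parsers'

const parser = new CustomListOutputParser({ length: 3, separator: ';' })

// getFormatInstructions() asks the model for exactly 3 values separated by ';'.
console.log(parser.getFormatInstructions())

// A completion matching those instructions parses into a string[].
parser.parse('alpha;beta;gamma').then((items) => {
    console.log(items) // ['alpha', 'beta', 'gamma']
})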
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import { BaseOutputParser } from 'langchain/schema/output_parser'
+import { LLMChain } from 'langchain/chains'
+import { BaseLanguageModel } from 'langchain/base_language'
+import { ICommonObject } from '../../src'
+import { ChatPromptTemplate, FewShotPromptTemplate, PromptTemplate, SystemMessagePromptTemplate } from 'langchain/prompts'
+
+export const CATEGORY = 'Output Parsers'
+
+export const formatResponse = (response: string | object): string | object => {
+    if (typeof response === 'object') {
+        return { json: response }
+    }
+    return response
+}
+
+export const injectOutputParser = (
+    outputParser: BaseOutputParser<unknown>,
+    chain: LLMChain<string, BaseLanguageModel>,
+    promptValues: ICommonObject | undefined = undefined
+) => {
+    if (outputParser && chain.prompt) {
+        const formatInstructions = outputParser.getFormatInstructions()
+        if (chain.prompt instanceof PromptTemplate) {
+            let pt = chain.prompt
+            pt.template = pt.template + '\n{format_instructions}'
+            chain.prompt.partialVariables = { format_instructions: formatInstructions }
+        } else if (chain.prompt instanceof ChatPromptTemplate) {
+            let pt = chain.prompt
+            pt.promptMessages.forEach((msg) => {
+                if (msg instanceof SystemMessagePromptTemplate) {
+                    ;(msg.prompt as any).partialVariables = { format_instructions: outputParser.getFormatInstructions() }
+                    ;(msg.prompt as any).template = ((msg.prompt as any).template + '\n{format_instructions}') as string
+                }
+            })
+        } else if (chain.prompt instanceof FewShotPromptTemplate) {
+            chain.prompt.examplePrompt.partialVariables = { format_instructions: formatInstructions }
+            chain.prompt.examplePrompt.template = chain.prompt.examplePrompt.template + '\n{format_instructions}'
+        }
+
+        chain.prompt.inputVariables.push('format_instructions')
+        if (promptValues) {
+            promptValues = { ...promptValues, format_instructions: outputParser.getFormatInstructions() }
+        }
+    }
+    return promptValues
+}
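To see what injectOutputParser achieves, here is a minimal sketch (not part of the commit) of the same idea applied to a plain PromptTemplate outside Flowise: the parser's format instructions are supplied through the {format_instructions} partial variable so they land at the end of the rendered prompt. The template text is invented:

// Illustrative sketch only — not part of this commit.
import { PromptTemplate } from 'langchain/prompts'
import { CommaSeparatedListOutputParser } from 'langchain/output_parsers'

const run = async () => {
    const parser = new CommaSeparatedListOutputParser()

    const prompt = new PromptTemplate({
        template: 'List some {product} brands.\n{format_instructions}',
        inputVariables: ['product'],
        partialVariables: { format_instructions: parser.getFormatInstructions() }
    })

    // The rendered prompt now carries the parser's instructions, so the model is
    // asked to answer as a comma-separated list that parser.parse() can handle.
    console.log(await prompt.format({ product: 'coffee' }))
}

run()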
