@@ -3,6 +3,10 @@ import { getBaseClasses, handleEscapeCharacters } from '../../../src/utils'
33import { LLMChain } from 'langchain/chains'
44import { BaseLanguageModel } from 'langchain/base_language'
55import { ConsoleCallbackHandler , CustomChainHandler , additionalCallbacks } from '../../../src/handler'
6+ import { BaseOutputParser } from 'langchain/schema/output_parser'
7+ import { formatResponse , injectOutputParser } from '../../outputparsers/OutputParserHelpers'
8+ import { BaseLLMOutputParser } from 'langchain/schema/output_parser'
9+ import { OutputFixingParser } from 'langchain/output_parsers'
610
711class LLMChain_Chains implements INode {
812 label : string
@@ -15,11 +19,12 @@ class LLMChain_Chains implements INode {
1519 description : string
1620 inputs : INodeParams [ ]
1721 outputs : INodeOutputsValue [ ]
22+ outputParser : BaseOutputParser
1823
1924 constructor ( ) {
2025 this . label = 'LLM Chain'
2126 this . name = 'llmChain'
22- this . version = 1 .0
27+ this . version = 3 .0
2328 this . type = 'LLMChain'
2429 this . icon = 'chain.svg'
2530 this . category = 'Chains'
@@ -36,6 +41,12 @@ class LLMChain_Chains implements INode {
3641 name : 'prompt' ,
3742 type : 'BasePromptTemplate'
3843 } ,
44+ {
45+ label : 'Output Parser' ,
46+ name : 'outputParser' ,
47+ type : 'BaseLLMOutputParser' ,
48+ optional : true
49+ } ,
3950 {
4051 label : 'Chain Name' ,
4152 name : 'chainName' ,
@@ -63,12 +74,29 @@ class LLMChain_Chains implements INode {
6374 const prompt = nodeData . inputs ?. prompt
6475 const output = nodeData . outputs ?. output as string
6576 const promptValues = prompt . promptValues as ICommonObject
66-
77+ const llmOutputParser = nodeData . inputs ?. outputParser as BaseOutputParser
78+ this . outputParser = llmOutputParser
79+ if ( llmOutputParser ) {
80+ let autoFix = ( llmOutputParser as any ) . autoFix
81+ if ( autoFix === true ) {
82+ this . outputParser = OutputFixingParser . fromLLM ( model , llmOutputParser )
83+ }
84+ }
6785 if ( output === this . name ) {
68- const chain = new LLMChain ( { llm : model , prompt, verbose : process . env . DEBUG === 'true' ? true : false } )
86+ const chain = new LLMChain ( {
87+ llm : model ,
88+ outputParser : this . outputParser as BaseLLMOutputParser < string | object > ,
89+ prompt,
90+ verbose : process . env . DEBUG === 'true'
91+ } )
6992 return chain
7093 } else if ( output === 'outputPrediction' ) {
71- const chain = new LLMChain ( { llm : model , prompt, verbose : process . env . DEBUG === 'true' ? true : false } )
94+ const chain = new LLMChain ( {
95+ llm : model ,
96+ outputParser : this . outputParser as BaseLLMOutputParser < string | object > ,
97+ prompt,
98+ verbose : process . env . DEBUG === 'true'
99+ } )
72100 const inputVariables = chain . prompt . inputVariables as string [ ] // ["product"]
73101 const res = await runPrediction ( inputVariables , chain , input , promptValues , options , nodeData )
74102 // eslint-disable-next-line no-console
@@ -84,10 +112,15 @@ class LLMChain_Chains implements INode {
84112 }
85113 }
86114
87- async run ( nodeData : INodeData , input : string , options : ICommonObject ) : Promise < string > {
115+ async run ( nodeData : INodeData , input : string , options : ICommonObject ) : Promise < string | object > {
88116 const inputVariables = nodeData . instance . prompt . inputVariables as string [ ] // ["product"]
89117 const chain = nodeData . instance as LLMChain
90- const promptValues = nodeData . inputs ?. prompt . promptValues as ICommonObject
118+ let promptValues : ICommonObject | undefined = nodeData . inputs ?. prompt . promptValues as ICommonObject
119+ const outputParser = nodeData . inputs ?. outputParser as BaseOutputParser
120+ if ( ! this . outputParser && outputParser ) {
121+ this . outputParser = outputParser
122+ }
123+ promptValues = injectOutputParser ( this . outputParser , chain , promptValues )
91124 const res = await runPrediction ( inputVariables , chain , input , promptValues , options , nodeData )
92125 // eslint-disable-next-line no-console
93126 console . log ( '\x1b[93m\x1b[1m\n*****FINAL RESULT*****\n\x1b[0m\x1b[0m' )
@@ -99,9 +132,9 @@ class LLMChain_Chains implements INode {
99132
100133const runPrediction = async (
101134 inputVariables : string [ ] ,
102- chain : LLMChain ,
135+ chain : LLMChain < string | object > ,
103136 input : string ,
104- promptValuesRaw : ICommonObject ,
137+ promptValuesRaw : ICommonObject | undefined ,
105138 options : ICommonObject ,
106139 nodeData : INodeData
107140) => {
@@ -135,10 +168,10 @@ const runPrediction = async (
135168 if ( isStreaming ) {
136169 const handler = new CustomChainHandler ( socketIO , socketIOClientId )
137170 const res = await chain . call ( options , [ loggerHandler , handler , ...callbacks ] )
138- return res ?. text
171+ return formatResponse ( res ?. text )
139172 } else {
140173 const res = await chain . call ( options , [ loggerHandler , ...callbacks ] )
141- return res ?. text
174+ return formatResponse ( res ?. text )
142175 }
143176 } else if ( seen . length === 1 ) {
144177 // If one inputVariable is not specify, use input (user's question) as value
@@ -151,10 +184,10 @@ const runPrediction = async (
151184 if ( isStreaming ) {
152185 const handler = new CustomChainHandler ( socketIO , socketIOClientId )
153186 const res = await chain . call ( options , [ loggerHandler , handler , ...callbacks ] )
154- return res ?. text
187+ return formatResponse ( res ?. text )
155188 } else {
156189 const res = await chain . call ( options , [ loggerHandler , ...callbacks ] )
157- return res ?. text
190+ return formatResponse ( res ?. text )
158191 }
159192 } else {
160193 throw new Error ( `Please provide Prompt Values for: ${ seen . join ( ', ' ) } ` )
@@ -163,10 +196,10 @@ const runPrediction = async (
163196 if ( isStreaming ) {
164197 const handler = new CustomChainHandler ( socketIO , socketIOClientId )
165198 const res = await chain . run ( input , [ loggerHandler , handler , ...callbacks ] )
166- return res
199+ return formatResponse ( res )
167200 } else {
168201 const res = await chain . run ( input , [ loggerHandler , ...callbacks ] )
169- return res
202+ return formatResponse ( res )
170203 }
171204 }
172205}
0 commit comments