From d67864dedca6bb5283900d010d47bfd586746f33 Mon Sep 17 00:00:00 2001 From: Kun Chen Date: Sat, 19 Aug 2023 22:20:01 -0700 Subject: [PATCH] support object fields that are array of primitives --- .changeset/new-meals-wash.md | 5 +++ README.md | 4 +- src/decorators.ts | 24 +++++++---- tests/decorators.test.ts | 81 ++++++++++++++++++++++++++++++++++++ tests/session.test.ts | 46 -------------------- 5 files changed, 104 insertions(+), 56 deletions(-) create mode 100644 .changeset/new-meals-wash.md create mode 100644 tests/decorators.test.ts diff --git a/.changeset/new-meals-wash.md b/.changeset/new-meals-wash.md new file mode 100644 index 0000000..40f3551 --- /dev/null +++ b/.changeset/new-meals-wash.md @@ -0,0 +1,5 @@ +--- +'function-gpt': minor +--- + +support object fields that are array of primitives diff --git a/README.md b/README.md index 11e9651..db296c0 100644 --- a/README.md +++ b/README.md @@ -33,8 +33,8 @@ class BrowseSession extends ChatGPTSession { // Define the type of the input parameter for functions above. class BrowseParams { // Decorate each field with @gptObjectField to provide necessary metadata. - @gptObjectField('string', 'url of the web page to browse', true) - public url: string = ''; + @gptObjectField('string', 'url of the web page to browse') + public url!: string; } const session = new BrowseSession(); diff --git a/src/decorators.ts b/src/decorators.ts index e263249..f3d8eba 100644 --- a/src/decorators.ts +++ b/src/decorators.ts @@ -31,9 +31,9 @@ export function gptFunction(description: string, inputType: new () => unknown) { } export function gptObjectField( - type: 'string' | 'number' | 'boolean' | (new () => unknown) | [new () => unknown], + type: 'string' | 'number' | 'boolean' | ['string' | 'number' | 'boolean'] | (new () => unknown) | [new () => unknown], description: string, - required = true, + optional = false, ) { return function (target: object, propertyKey: string) { const ctor = target.constructor as new () => unknown; @@ -51,38 +51,46 @@ export function gptObjectField( name: propertyKey, description, type: { type: 'string' }, - required, + required: !optional, }); } else if (type === 'number') { metadata.fields.push({ name: propertyKey, description, type: { type: 'number' }, - required, + required: !optional, }); } else if (type === 'boolean') { metadata.fields.push({ name: propertyKey, description, type: { type: 'boolean' }, - required, + required: !optional, }); } else if (Array.isArray(type)) { + const elementType = type[0]; metadata.fields.push({ name: propertyKey, description, type: { type: 'array', - elementType: GPT_TYPE_METADATA.get(type[0]) as GPTTypeMetadata, + elementType: + elementType === 'string' + ? { type: 'string' } + : elementType === 'number' + ? { type: 'number' } + : elementType === 'boolean' + ? { type: 'boolean' } + : (GPT_TYPE_METADATA.get(elementType) as GPTTypeMetadata), }, - required, + required: !optional, }); } else if (typeof type === 'function') { metadata.fields.push({ name: propertyKey, description, type: GPT_TYPE_METADATA.get(type) as GPTTypeMetadata, - required, + required: !optional, }); } diff --git a/tests/decorators.test.ts b/tests/decorators.test.ts new file mode 100644 index 0000000..4c46bcc --- /dev/null +++ b/tests/decorators.test.ts @@ -0,0 +1,81 @@ +import { expect, test } from 'vitest'; + +import { ChatGPTSession, gptFunction, gptObjectField } from '../index.js'; + +process.env.OPENAI_API_KEY = 'test'; + +test('basic function schema is generated correctly', async () => { + class TestFuncInput { + @gptObjectField('string', 'this is a test string', true) + public testString!: string; + + @gptObjectField('number', 'this is a test number', false) + public testNumber!: number; + } + + class TestSession extends ChatGPTSession { + @gptFunction('this is a test function', TestFuncInput) + testFunc(params: TestFuncInput) { + return params; + } + } + + const testSession = new TestSession(); + const schema = testSession.getFunctionSchema(); + expect(schema).toEqual([ + { + name: 'testFunc', + description: 'this is a test function', + parameters: { + type: 'object', + properties: { + testString: { + type: 'string', + description: 'this is a test string', + }, + testNumber: { + type: 'number', + description: 'this is a test number', + }, + }, + required: ['testNumber'], + }, + }, + ]); +}); + +test('input parameter can be an array of strings', () => { + class TestParam { + @gptObjectField(['string'], 'test words') + words!: string[]; + } + + class TestSession extends ChatGPTSession { + @gptFunction('this is a test function', TestParam) + testFunc(params: TestParam) { + return params; + } + } + + const testSession = new TestSession(); + const schema = testSession.getFunctionSchema(); + expect(schema).toEqual([ + { + name: 'testFunc', + description: 'this is a test function', + parameters: { + type: 'object', + properties: { + words: { + type: 'array', + items: { + type: 'string', + }, + description: 'test words', + }, + }, + required: ['words'], + }, + }, + ]); +}); diff --git a/tests/session.test.ts b/tests/session.test.ts index d868ed7..8fd47b3 100644 --- a/tests/session.test.ts +++ b/tests/session.test.ts @@ -57,52 +57,6 @@ afterEach(() => { vi.clearAllMocks(); }); -test('function schema is generated correctly', async () => { - class TestFuncInput { - @gptObjectField('string', 'this is a test string', true) - public testString: string = ''; - - @gptObjectField('number', 'this is a test number', false) - public testNumber: number = 0; - } - - class TestSession extends ChatGPTSession { - constructor() { - super({ - apiKey: 'test', - }); - } - - @gptFunction('this is a test function', TestFuncInput) - testFunc(params: TestFuncInput) { - return params; - } - } - - const testSession = new TestSession(); - const schema = testSession.getFunctionSchema(); - expect(schema).toEqual([ - { - name: 'testFunc', - description: 'this is a test function', - parameters: { - type: 'object', - properties: { - testString: { - type: 'string', - description: 'this is a test string', - }, - testNumber: { - type: 'number', - description: 'this is a test number', - }, - }, - required: ['testString'], - }, - }, - ]); -}); - const fetch = vi.fn().mockImplementation(() => Promise.resolve()); test('function calling should work', async () => {