Commit
Optimize and fix format (#9)
* Update structure and fix tool functions

* Fix format

* Add test scripts

* Convert unused code to temporary comments

* Optimize allocation in methods

* Fix scripts
timursevimli authored Aug 25, 2024
1 parent 9ec3c56 commit c85579f
Showing 5 changed files with 3,892 additions and 3,904 deletions.
32 changes: 11 additions & 21 deletions files/openai/tools/test-library.js
@@ -2,32 +2,29 @@

 const locations = [
   {
-    location: 'San Francisco',
+    location: 'San Francisco, CA',
     temperature: '72',
     unit: 'fahrenheit',
     age: '72',
   },
   {
-    location: 'Paris',
+    location: 'Paris, France',
     temperature: '22',
     unit: 'fahrenheit',
     age: '22',
   },
   {
-    location: 'Tokyo',
+    location: 'Tokyo, Japan',
     temperature: '10',
     unit: 'celsius',
     age: '10',
   },
 ];

-const getCurrentWeather = ({ location /*,unit = 'fahrenheit'*/ }) => {
-  const targetLocation = location.toLowerCase();
-  for (const expectedLocation of locations) {
-    const { location } = expectedLocation;
-    if (location.toLowerCase() === targetLocation) {
-      const result = { ...expectedLocation };
-      delete result.age;
+const getCurrentWeather = ({ location: targetLocation }) => {
+  for (const { location, temperature, unit } of locations) {
+    if (location === targetLocation) {
+      const result = { location, temperature, unit };
       return JSON.stringify(result);
     }
   }
@@ -38,13 +35,10 @@ const getCurrentWeather = ({ location /*,unit = 'fahrenheit'*/ }) => {
   });
 };

-const getCurrentAge = ({ location /*unit = 'years'*/ }) => {
-  const targetLocation = location.toLowerCase();
-  for (const expectedLocation of locations) {
-    const { location } = expectedLocation;
-    if (location.toLowerCase() === targetLocation) {
-      const result = { ...expectedLocation };
-      delete result.temperature;
+const getCurrentAge = ({ location: targetLocation }) => {
+  for (const { location, age, unit } of locations) {
+    if (location === targetLocation) {
+      const result = { location, age, unit };
       return JSON.stringify(result);
     }
   }
@@ -57,8 +51,6 @@ const getCurrentAge = ({ location /*unit = 'years'*/ }) => {

 const tools = [
   {
-    // name: "get_current_weather",
-
     fn: getCurrentWeather,
     scope: this,
     description: 'Get the current weather in a given location',
@@ -72,8 +64,6 @@ const tools = [
     required: ['location'],
   },
   {
-    // name: "get_current_weather",
-
     fn: getCurrentAge,
     scope: this,
     description: 'Get the current population average age in a given location',
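The heart of the "Optimize allocation in methods" change is visible above: instead of copying the whole record with a spread and then deleting the unwanted key, each lookup now destructures only the fields it returns and builds the result object in one step. Below is a minimal before/after sketch using one abbreviated record from the test data; the not-found fallback shape is an assumption, since the diff truncates before that branch.

'use strict';

const locations = [
  { location: 'Tokyo, Japan', temperature: '10', unit: 'celsius', age: '10' },
];

// Before: spread the whole record, then delete the unwanted key.
// This allocates a throwaway copy and mutates it afterwards.
const weatherBefore = ({ location }) => {
  const target = location.toLowerCase();
  for (const record of locations) {
    if (record.location.toLowerCase() === target) {
      const result = { ...record };
      delete result.age;
      return JSON.stringify(result);
    }
  }
  return JSON.stringify({ error: 'Location not found' }); // assumed fallback
};

// After: destructure only the needed fields while iterating and build
// the result directly; note the comparison is now exact-match.
const weatherAfter = ({ location: targetLocation }) => {
  for (const { location, temperature, unit } of locations) {
    if (location === targetLocation) {
      return JSON.stringify({ location, temperature, unit });
    }
  }
  return JSON.stringify({ error: 'Location not found' }); // assumed fallback
};

console.log(weatherAfter({ location: 'Tokyo, Japan' }));
// {"location":"Tokyo, Japan","temperature":"10","unit":"celsius"}

One side effect worth noting: dropping toLowerCase() makes the lookup case-sensitive, which is presumably why the test data gained the fuller 'City, Region' strings.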
45 changes: 18 additions & 27 deletions lib/huggingface/connector.js
@@ -1,11 +1,10 @@
 'use strict';

 const { HfInference } = require('@huggingface/inference');

 const utils = require('./utils');
 const { DEFAULT_MODELS } = require('./config.json');

-const { tokens, custom } = utils;
+const { /* tokens, */ language } = utils;

 class Chat {
-  //temperature = 0.7, topP = 1, frequencyPenalty = 0
@@ -15,47 +14,39 @@ class Chat {
     system,
     model = DEFAULT_MODELS.completions,
     tools,
-    maxTokens = 1000,
+    // maxTokens = 1000,
     // maxPrice = 0.1,
   }) {
     this.hf = new HfInference(apiKey);
     this.system = system;
     this.model = model;
     this.tools = tools;
-    this.maxTokens = maxTokens;
+    this.language = language(this.hf);
+    // this.maxTokens = maxTokens;
     // this.maxPrice = maxPrice;

-    this.messages = [];
-    this.tokens = 0;
-    this.price = 0;
+    // this.messages = [];
+    // this.tokens = 0;
+    // this.price = 0;

+    // throw new Error(`Max ${maxTokens} tokens exceeded`);
   }

   async message({ text }) {
-    const tokenCount = await tokens.count({ text, model: this.model });
-    const { maxTokens, model } = this;
+    const { /* maxTokens, */ model } = this;
+    // const tokenCount = await tokens.count({ text, model });

-    const increaseMaxTokens = tokens + tokenCount > maxTokens;
-    if (increaseMaxTokens) {
-      throw new Error(`Max ${this.maxTokens} tokens exceeded`);
-    }
-
-    const res = await custom(this.hf).generate({
-      text,
-      model,
-      messages: this.messages,
-      system: this.system,
-      tools: this.tools,
-    });
-
-    if (res.error) return res.error.message;
+    // const increaseMaxTokens = tokens + tokenCount > maxTokens;
+    // if (increaseMaxTokens) {
+    //   throw new Error(`Max ${this.maxTokens} tokens exceeded`);
+    // }
+    const res = await this.language.textGeneration(text, model);

-    this.messages = res.messages;
-    this.tokens += res.usage.total_tokens;
-    this.price += res.usage.total_price;
+    // this.messages = res.messages;
+    // this.tokens += res.usage.total_tokens;
+    // this.price += res.usage.total_price;

-    return res.message;
+    return res;
   }

 /*
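For context, here is a hypothetical usage sketch of the refactored Hugging Face connector. Only the constructor option names and the message() signature come from the diff; the export shape, environment variable, and model id are assumptions.

'use strict';

// Assumes the module exports Chat; adjust the path and export to taste.
const { Chat } = require('./lib/huggingface/connector');

const main = async () => {
  const chat = new Chat({
    apiKey: process.env.HUGGINGFACE_API_KEY, // assumed env var name
    system: 'You are a helpful assistant.',
    model: 'mistralai/Mistral-7B-Instruct-v0.2', // illustrative model id
  });

  // message() now delegates to the language helper built once in the
  // constructor and returns its result directly; token and price
  // accounting is commented out as of this commit.
  const reply = await chat.message({ text: 'Hello!' });
  console.log(reply);
};

main().catch(console.error);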
8 changes: 5 additions & 3 deletions lib/openai/connector.js
@@ -23,6 +23,8 @@ class Chat {
     this.system = system;
     this.model = model;
     this.tools = tools;
+    this.language = language(this.openai);
+    this.speech = speech(this.openai);
     this.maxTokens = maxTokens;
     // this.maxPrice = maxPrice;

@@ -42,7 +44,7 @@
       throw new Error(`Max ${this.maxTokens} tokens exceeded`);
     }

-    const res = await language(this.openai).generate({
+    const res = await this.language.generate({
       text,
       model,
       messages: this.messages,
@@ -73,7 +75,7 @@ class Chat {
     let inputText = text;

     if (inputFilePath) {
-      inputText = await speech(this.openai).speechToText({
+      inputText = await this.speech.speechToText({
         pathToFile: inputFilePath,
       });
     }
@@ -106,7 +108,7 @@ class Chat {
     voice = DEFAULT_VOICE,
   }) {
     const start = measureTime();
-    await speech(this.openai).textToSpeech({
+    await this.speech.textToSpeech({
       text: outputText,
       pathToFile: outputFilePath,
       voice,
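The OpenAI connector gets the same treatment as the Hugging Face one: the language() and speech() wrappers are now built once in the constructor instead of being re-created on every call. A standalone sketch of the pattern follows, with stand-in names, since the real factories live outside this diff; client.transcribe is invented purely for illustration.

'use strict';

// makeSpeech stands in for the speech() factory used in the diff.
const makeSpeech = (client) => ({
  speechToText: async ({ pathToFile }) => client.transcribe(pathToFile),
});

class Recognizer {
  constructor(client) {
    // Before this commit the wrapper was rebuilt on each call:
    //   await speech(this.openai).speechToText({ ... })
    // Building it once here avoids a fresh wrapper allocation per call.
    this.speech = makeSpeech(client);
  }

  async listen(pathToFile) {
    return this.speech.speechToText({ pathToFile });
  }
}

// Usage with a stub client:
const recognizer = new Recognizer({ transcribe: async (p) => `text of ${p}` });
recognizer.listen('audio.mp3').then(console.log); // "text of audio.mp3"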
