From bddc8d7f0fe58904e0064debb49fc340732b8e4a Mon Sep 17 00:00:00 2001 From: Dmitriy Tselinko Date: Wed, 15 May 2024 10:54:47 +0200 Subject: [PATCH] Huggingface refactor PR-URL: https://github.com/Universal-Code-Modules/universal-llm/pull/3 --- lib/huggingface/config.json | 42 ++ lib/huggingface/connector.js | 543 ++++-------------------- lib/huggingface/utils/audio.js | 50 +++ lib/huggingface/utils/computerVision.js | 94 ++++ lib/huggingface/utils/custom.js | 59 +++ lib/huggingface/utils/index.js | 10 + lib/huggingface/utils/language.js | 158 +++++++ lib/huggingface/utils/multimodal.js | 53 +++ lib/huggingface/utils/tabular.js | 59 +++ test/huggingface.js | 439 ------------------- test/huggingface/audio.js | 66 +++ test/huggingface/computerVision.js | 129 ++++++ test/huggingface/custom.js | 49 +++ test/huggingface/language.js | 181 ++++++++ test/huggingface/multimodal.js | 68 +++ test/huggingface/tabular.js | 62 +++ 16 files changed, 1165 insertions(+), 897 deletions(-) create mode 100644 lib/huggingface/config.json create mode 100644 lib/huggingface/utils/audio.js create mode 100644 lib/huggingface/utils/computerVision.js create mode 100644 lib/huggingface/utils/custom.js create mode 100644 lib/huggingface/utils/index.js create mode 100644 lib/huggingface/utils/language.js create mode 100644 lib/huggingface/utils/multimodal.js create mode 100644 lib/huggingface/utils/tabular.js delete mode 100644 test/huggingface.js create mode 100644 test/huggingface/audio.js create mode 100644 test/huggingface/computerVision.js create mode 100644 test/huggingface/custom.js create mode 100644 test/huggingface/language.js create mode 100644 test/huggingface/multimodal.js create mode 100644 test/huggingface/tabular.js diff --git a/lib/huggingface/config.json b/lib/huggingface/config.json new file mode 100644 index 0000000..7566a45 --- /dev/null +++ b/lib/huggingface/config.json @@ -0,0 +1,42 @@ +{ + "DEFAULT_MODELS": { + "language": { + "fillMask": "bert-base-uncased", + "summarization": "facebook/bart-large-cnn", + "questionAnswering": "deepset/roberta-base-squad2", + "tableQuestionAnswering": "google/tapas-base-finetuned-wtq", + "textClassification": "distilbert-base-uncased-finetuned-sst-2-english", + "textGeneration": "gpt2", + "textGenerationStream": "google/flan-t5-xxl", + "tokenClassification": "dbmdz/bert-large-cased-finetuned-conll03-english", + "translation": "t5-base", + "zeroShotClassification": "facebook/bart-large-mnli", + "sentenceSimilarity": "sentence-transformers/paraphrase-xlm-r-multilingual-v1" + }, + "audio": { + "automaticSpeechRecognition": "facebook/wav2vec2-large-960h-lv60-self", + "audioClassification": "superb/hubert-large-superb-er", + "textToSpeech": "espnet/kan-bayashi_ljspeech_vits", + "audioToAudio": "speechbrain/sepformer-wham" + }, + "computerVision": { + "imageClassification": "google/vit-base-patch16-224", + "objectDetection": "facebook/detr-resnet-50", + "imageSegmentation": "facebook/detr-resnet-50-panoptic", + "imageToText": "nlpconnect/vit-gpt2-image-captioning", + "textToImage": "stabilityai/stable-diffusion-2", + "imageToImage": "lllyasviel/sd-controlnet-depth", + "zeroShotImageClassification": "openai/clip-vit-large-patch14-336" + }, + "multimodal": { + "featureExtraction": "sentence-transformers/distilbert-base-nli-mean-tokens", + "visualQuestionAnswering": "dandelin/vilt-b32-finetuned-vqa", + "documentQuestionAnswering": "impira/layoutlm-document-qa" + }, + "tabular": { + "tabularRegression": "scikit-learn/Fish-Weight", + "tabularClassification": "vvmnnnkv/wine-quality" + } + }, + "DEFAULT_VOICE": "onyx" +} diff --git a/lib/huggingface/connector.js b/lib/huggingface/connector.js index 451ec31..67ec55b 100644 --- a/lib/huggingface/connector.js +++ b/lib/huggingface/connector.js @@ -1,478 +1,105 @@ 'use strict'; const { HfInference } = require('@huggingface/inference'); -const { callAPI } = require('../common.js'); -const HUGGINGFACE_TOKEN = process.env.HUGGINGFACE_TOKEN; - -const hf = new HfInference(HUGGINGFACE_TOKEN); - -// You can also omit "model" to use the recommended model for the task - -// constructor() { - -// } - -//......Natural Language Processing -// inputs = '[MASK] world!' - -const FillMask = async (inputs, model = 'bert-base-uncased') => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.fillMask', args); - return res; -}; - -/* - inputs = `The tower is 324 metres (1,063 ft) tall, about the same height - as an 81-storey building, and the tallest structure in Paris. - Its base is square, measuring 125 metres (410 ft) on each side. - During its construction, the Eiffel Tower surpassed the Washington - Monument to become the tallest`, - model = 'facebook/bart-large-cnn' -*/ - -const Summarization = async ( - inputs, - parameters = { max_length: 100 }, - model = 'facebook/bart-large-cnn', -) => { - const args = { inputs, parameters, model }; - const res = await callAPI(hf, 'hf.summarization', args); - return res; -}; +const utils = require('./utils'); +const { DEFAULT_MODELS } = require('./config.json'); + +const { tokens, custom } = utils; + +class Chat { + //temperature = 0.7, topP = 1, frequencyPenalty = 0 + //presencePenalty = 0, stop = ["\n", ""] + constructor({ + apiKey, + system, + model = DEFAULT_MODELS.completions, + tools, + maxTokens = 1000, + // maxPrice = 0.1, + }) { + this.hf = new HfInference(apiKey); + this.system = system; + this.model = model; + this.tools = tools; + this.maxTokens = maxTokens; + // this.maxPrice = maxPrice; + + this.messages = []; + this.tokens = 0; + this.price = 0; + + // throw new Error(`Max ${maxTokens} tokens exceeded`); + } -/* - inputs = { - question: 'What is the capital of France?', - context: 'The capital of France is Paris.' - }, -*/ + async message({ text }) { + const tokenCount = await tokens.count({ text, model: this.model }); + const { maxTokens, model } = this; -const QuestionAnswering = async ( - inputs, - model = 'deepset/roberta-base-squad2', -) => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.questionAnswering', args); - return res; -}; - -/* - inputs = { - query: 'How many stars does the transformers repository have?', - table: { - Repository: ['Transformers', 'Datasets', 'Tokenizers'], - Stars: ['36542', '4512', '3934'], - Contributors: ['651', '77', '34'], - 'Programming language': ['Python', 'Python', 'Rust, Python and NodeJS'] + const increaseMaxTokens = tokens + tokenCount > maxTokens; + if (increaseMaxTokens) { + throw new Error(`Max ${this.maxTokens} tokens exceeded`); } - }, -*/ - -const TableQuestionAnswering = async ( - inputs, - model = 'google/tapas-base-finetuned-wtq', -) => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.tableQuestionAnswering', args); - return res; -}; - -// inputs = 'I like you. I love you.' - -const TextClassification = async ( - inputs, - model = 'distilbert-base-uncased-finetuned-sst-2-english', -) => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.textClassification', args); - return res; -}; - -// inputs = 'The answer to the universe is' - -const TextGeneration = async (inputs, model = 'gpt2') => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.textGeneration', args); - return res; -}; -// inputs = 'repeat "one two three four"' -// parameters = { max_new_tokens: 250 } - -const TextGenerationStream = async ( - inputs, - parameters = {}, - model = 'google/flan-t5-xxl', -) => { - const args = { inputs, parameters, model }; - const res = await callAPI(hf, 'hf.textGenerationStream', args); - return res; -}; -// inputs = 'My name is Sarah Jessica Parker but you can call me Jessica' - -const TokenClassification = async ( - inputs, - model = 'dbmdz/bert-large-cased-finetuned-conll03-english', -) => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.tokenClassification', args); - return res; -}; + const res = await custom(this.hf).generate({ + text, + model, + messages: this.messages, + system: this.system, + tools: this.tools, + }); -// inputs = 'My name is Wolfgang and I live in Amsterdam', -// parameters = {"src_lang": "en_XX", "tgt_lang": "fr_XX"} + if (res.error) return res.error.message; -const Translation = async (inputs, parameters = {}, model = 't5-base') => { - const args = { inputs, parameters, model }; - const res = await callAPI(hf, 'hf.translation', args); - return res; -}; -/* - inputs = [ - 'Hi, I recently bought a device from your company but it is not working' + - ' as advertised and I would like to get reimbursed!' - ], - parameters = { candidate_labels: ['refund', 'legal', 'faq'] } -*/ -const ZeroShotClassification = async ( - inputs, - parameters = {}, - model = 'facebook/bart-large-mnli', -) => { - const args = { inputs, parameters, model }; - const res = await callAPI(hf, 'hf.zeroShotClassification', args); - return res; -}; + this.messages = res.messages; + this.tokens += res.usage.total_tokens; + this.price += res.usage.total_price; -/* - inputs = { - source_sentence: 'That is a happy person', - sentences: [ - 'That is a happy dog', - 'That is a very happy person', - 'Today is a sunny day' - ] + return res.message; } -*/ - -const SentenceSimilarity = async ( - inputs, - model = 'sentence-transformers/paraphrase-xlm-r-multilingual-v1', -) => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.sentenceSimilarity', args); - return res; -}; - -//.........Audio -// data = readFileSync('test/sample1.flac') - -const AutomaticSpeechRecognition = async ( - data, - model = 'facebook/wav2vec2-large-960h-lv60-self', -) => { - const args = { data, model }; - const res = await callAPI(hf, 'hf.automaticSpeechRecognition', args); - return res; -}; - -// data = readFileSync('test/sample1.flac') - -const AudioClassification = async ( - data, - model = 'superb/hubert-large-superb-er', -) => { - const args = { data, model }; - const res = await callAPI(hf, 'hf.audioClassification', args); - return res; -}; - -// inputs = 'Hello world!' - -const TextToSpeech = async ( - inputs, - model = 'espnet/kan-bayashi_ljspeech_vits', -) => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.textToSpeech', args); - return res; -}; -/* - data = readFileSync('test/sample1.flac') - */ -const AudioToAudio = async (data, model = 'speechbrain/sepformer-wham') => { - const args = { data, model }; - const res = await callAPI(hf, 'hf.audioToAudio', args); - return res; -}; - -//........Computer Vision -// data = readFileSync('test/cheetah.png') - -const ImageClassification = async ( - data, - model = 'google/vit-base-patch16-224', -) => { - const args = { data, model }; - const res = await callAPI(hf, 'hf.imageClassification', args); - return res; -}; -/* - data = readFileSync('test/cats.png') - */ -const ObjectDetection = async (data, model = 'facebook/detr-resnet-50') => { - const args = { data, model }; - const res = await callAPI(hf, 'hf.objectDetection', args); - return res; -}; - -// data = readFileSync('test/cats.png') - -const ImageSegmentation = async ( - data, - model = 'facebook/detr-resnet-50-panoptic', -) => { - const args = { data, model }; - const res = await callAPI(hf, 'hf.imageSegmentation', args); - return res; -}; - -// data = await (await fetch('https://picsum.photos/300/300')).blob() - -const ImageToText = async ( - data, - model = 'nlpconnect/vit-gpt2-image-captioning', -) => { - const args = { data, model }; - const res = await callAPI(hf, 'hf.imageToText', args); - return res; -}; - -/* - inputs = 'award winning high resolution photo of a giant' + - ' tortoise/((ladybird)) hybrid, [trending on artstation]', - parameters = {negative_prompt: 'blurry'}, -*/ - -const TextToImage = async ( - inputs, - parameters = {}, - model = 'stabilityai/stable-diffusion-2', -) => { - const args = { inputs, parameters, model }; - const res = await callAPI(hf, 'hf.textToImage', args); - return res; -}; - -/* - inputs = new Blob([readFileSync("test/stormtrooper_depth.png")]), - parameters = {prompt: "elmo's lecture"}, -*/ - -const ImageToImage = async ( - inputs, - parameters = {}, - model = 'lllyasviel/sd-controlnet-depth', -) => { - const args = { inputs, parameters, model }; - const res = await callAPI(hf, 'hf.imageToImage', args); - return res; -}; - -/* - inputs = { image: await (await fetch('https://placekitten.com/300/300')).blob() }, - parameters = { candidate_labels: ['cat', 'dog'] }, -*/ - -const ZeroShotImageClassification = async ( - inputs, - parameters = {}, - model = 'openai/clip-vit-large-patch14-336', -) => { - const args = { inputs, parameters, model }; - const res = await callAPI(hf, 'hf.zeroShotImageClassification', args); - return res; -}; - -//......Multimodal -// inputs = "That is a happy person", - -const FeatureExtraction = async ( - inputs, - model = 'sentence-transformers/distilbert-base-nli-mean-tokens', -) => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.featureExtraction', args); - return res; -}; - -/* - inputs = { - question: 'How many cats are lying down?', - image: await (await fetch('https://placekitten.com/300/300')).blob() - }, -*/ - -const VisualQuestionAnswering = async ( - inputs, - model = 'dandelin/vilt-b32-finetuned-vqa', -) => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.visualQuestionAnswering', args); - return res; -}; - -/* - inputs = { - question: 'Invoice number?', - image: await (await fetch('https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png')).blob(), - }, -*/ -const DocumentQuestionAnswering = async ( - inputs, - model = 'impira/layoutlm-document-qa', -) => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.documentQuestionAnswering', args); - return res; -}; - -//.....Tabular -/* - inputs = { - data: { - "Height": ["11.52", "12.48", "12.3778"], - "Length1": ["23.2", "24", "23.9"], - "Length2": ["25.4", "26.3", "26.5"], - "Length3": ["30", "31.2", "31.1"], - "Species": ["Bream", "Bream", "Bream"], - "Width": ["4.02", "4.3056", "4.6961"] - }, - }, -*/ - -const TabularRegression = async ( - inputs, - model = 'scikit-learn/Fish-Weight', -) => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.tabularRegression', args); - return res; -}; - -/* - inputs = { - data: { - "fixed_acidity": ["7.4", "7.8", "10.3"], - "volatile_acidity": ["0.7", "0.88", "0.32"], - "citric_acid": ["0", "0", "0.45"], - "residual_sugar": ["1.9", "2.6", "6.4"], - "chlorides": ["0.076", "0.098", "0.073"], - "free_sulfur_dioxide": ["11", "25", "5"], - "total_sulfur_dioxide": ["34", "67", "13"], - "density": ["0.9978", "0.9968", "0.9976"], - "pH": ["3.51", "3.2", "3.23"], - "sulphates": ["0.56", "0.68", "0.82"], - "alcohol": ["9.4", "9.8", "12.6"] - }, - }, -*/ - -const TabularClassification = async ( - inputs, - model = 'vvmnnnkv/wine-quality', -) => { - const args = { inputs, model }; - const res = await callAPI(hf, 'hf.tabularClassification', args); - return res; -}; - -//........Custom -/* - inputs = "hello world", - parameters = { - custom_param: 'some magic', + /* + "text" argument - in case we do conversion on front end + */ + async voiceMessage() { + throw new Error('Not Implemented'); } -*/ -const CustomCall = async ( - inputs, - parameters = {}, - model = 'my-custom-model', -) => { - const args = { inputs, parameters, model }; - const res = await callAPI(hf, 'hf.request', args); - return res; -}; - -/* - inputs = "hello world", - parameters = { - custom_param: 'some magic', + async voiceAnswer() { + throw new Error('Not Implemented'); } -*/ - -const CustomCallStreaming = async ( - inputs, - parameters = {}, - model = 'my-custom-model', -) => { - const args = { inputs, parameters, model }; - return this._makeApiCall('streamingRequest', args); -}; - -/* - inputs = 'The answer to the universe is', - endpoint = 'https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2' -*/ - -const CustomInferenceEndpoint = async (inputs, endpoint) => { - const args = { inputs }; - const hfEndpoint = hf.endpoint(endpoint); - try { - const res = await hfEndpoint.textGeneration(args); - return res; - } catch (err) { - throw new Error('Error occured while triggering "textGeneration" method', { - cause: err, - }); +} + +class Assistant { + constructor({ + assistant_id, + thread_id, + /* + model = DEFAULT_MODELS.completions, + maxTokens = 1000, + maxPrice = 0.1, + */ + }) { + this.id = assistant_id; + this.thread_id = thread_id; + // this.model = model; + // this.maxTokens = maxTokens; + // this.maxPrice = maxPrice; + + this.messages = []; + // this.tokens = 0; + // this.price = 0; } -}; -module.exports = { - FillMask, - Summarization, - QuestionAnswering, - TableQuestionAnswering, - TextClassification, - TextGeneration, - TextGenerationStream, - TokenClassification, - Translation, - ZeroShotClassification, - SentenceSimilarity, - - AutomaticSpeechRecognition, - AudioClassification, - TextToSpeech, - AudioToAudio, - - ImageClassification, - ObjectDetection, - ImageSegmentation, - ImageToText, - TextToImage, - ImageToImage, - ZeroShotImageClassification, - FeatureExtraction, + // message({ text }) {} +} - VisualQuestionAnswering, - DocumentQuestionAnswering, - TabularRegression, - TabularClassification, - CustomCall, - CustomCallStreaming, - CustomInferenceEndpoint, +module.exports = { + Chat, + Assistant, }; + +// Class chat +// chat.message({text}) => {message, messages, usages} => string +// chat.voiceMessage({inputFilePath, outputFilePath, voice}) => +// {inputText, outputText, outputFilePath} diff --git a/lib/huggingface/utils/audio.js b/lib/huggingface/utils/audio.js new file mode 100644 index 0000000..34d0597 --- /dev/null +++ b/lib/huggingface/utils/audio.js @@ -0,0 +1,50 @@ +'use strict'; + +const { callAPI } = require('../../common.js'); +const { DEFAULT_MODELS } = require('../config.json'); + +const defaultModels = DEFAULT_MODELS.audio; + +//......Audio....... + +const audio = (hf) => ({ + + // data = readFileSync('test/sample1.flac') + async automaticSpeechRecognition( + data, + model = defaultModels.automaticSpeechRecognition, + ) { + const args = { data, model }; + const res = await callAPI(hf, 'hf.automaticSpeechRecognition', args); + return res; + }, + + // data = readFileSync('test/sample1.flac') + async audioClassification( + data, + model = defaultModels.audioClassification, + ) { + const args = { data, model }; + const res = await callAPI(hf, 'hf.audioClassification', args); + return res; + }, + + // inputs = 'Hello world!' + async textToSpeech( + inputs, + model = defaultModels.textToSpeech, + ) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.textToSpeech', args); + return res; + }, + + // data = readFileSync('test/sample1.flac') + async audioToAudio(data, model = defaultModels.audioToAudio) { + const args = { data, model }; + const res = await callAPI(hf, 'hf.audioToAudio', args); + return res; + } +}); + +module.exports = { audio }; diff --git a/lib/huggingface/utils/computerVision.js b/lib/huggingface/utils/computerVision.js new file mode 100644 index 0000000..63655ad --- /dev/null +++ b/lib/huggingface/utils/computerVision.js @@ -0,0 +1,94 @@ +'use strict'; + +const { callAPI } = require('../../common.js'); +const { DEFAULT_MODELS } = require('../config.json'); + +const defaultModels = DEFAULT_MODELS.computerVision; + +//......ComputerVision....... + +const computerVision = (hf) => ({ + // data = readFileSync('test/cheetah.png') + async imageClassification( + data, + model = defaultModels.imageClassification, + ) { + const args = { data, model }; + const res = await callAPI(hf, 'hf.imageClassification', args); + return res; + }, + + /* + data = readFileSync('test/cats.png') + */ + async objectDetection(data, model = defaultModels.objectDetection) { + const args = { data, model }; + const res = await callAPI(hf, 'hf.objectDetection', args); + return res; + }, + + // data = readFileSync('test/cats.png') + async imageSegmentation( + data, + model = defaultModels.imageSegmentation, + ) { + const args = { data, model }; + const res = await callAPI(hf, 'hf.imageSegmentation', args); + return res; + }, + + // data = await (await fetch('https://picsum.photos/300/300')).blob() + async imageToText( + data, + model = defaultModels.imageToText, + ) { + const args = { data, model }; + const res = await callAPI(hf, 'hf.imageToText', args); + return res; + }, + + /* + inputs = 'award winning high resolution photo of a giant' + + ' tortoise/((ladybird)) hybrid, [trending on artstation]', + parameters = {negative_prompt: 'blurry'}, + */ + async textToImage( + inputs, + parameters = {}, + model = defaultModels.textToImage, + ) { + const args = { inputs, parameters, model }; + const res = await callAPI(hf, 'hf.textToImage', args); + return res; + }, + + /* + inputs = new Blob([readFileSync("test/stormtrooper_depth.png")]), + parameters = {prompt: "elmo's lecture"}, + */ + async imageToImage( + inputs, + parameters = {}, + model = defaultModels.imageToImage, + ) { + const args = { inputs, parameters, model }; + const res = await callAPI(hf, 'hf.imageToImage', args); + return res; + }, + + /* + inputs = { image: await (await fetch('https://placekitten.com/300/300')).blob() }, + parameters = { candidate_labels: ['cat', 'dog'] }, + */ + async zeroShotImageClassification( + inputs, + parameters = {}, + model = defaultModels.zeroShotImageClassification, + ) { + const args = { inputs, parameters, model }; + const res = await callAPI(hf, 'hf.zeroShotImageClassification', args); + return res; + }, +}); + +module.exports = { computerVision }; diff --git a/lib/huggingface/utils/custom.js b/lib/huggingface/utils/custom.js new file mode 100644 index 0000000..95dc5d5 --- /dev/null +++ b/lib/huggingface/utils/custom.js @@ -0,0 +1,59 @@ +'use strict'; + +const { callAPI } = require('../../common.js'); + +//......Custom....... + +const custom = (hf) => ({ + /* + inputs = "hello world", + parameters = { + custom_param: 'some magic', + } + */ + async customCall( + inputs, + parameters = {}, + model, + ) { + const args = { inputs, parameters, model }; + const res = await callAPI(hf, 'hf.request', args); + return res; + }, + + /* + inputs = "hello world", + parameters = { + custom_param: 'some magic', + } + */ + async customCallStreaming( + inputs, + parameters = {}, + model, + ) { + const args = { inputs, parameters, model }; + return callAPI(hf, 'hf.streamingRequest', args); + }, + + /* + inputs = 'The answer to the universe is', + endpoint = 'https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2' + */ + async customInferenceEndpoint(inputs, endpoint) { + const args = { inputs }; + const hfEndpoint = hf.endpoint(endpoint); + try { + const res = await hfEndpoint.textGeneration(args); + return res; + } catch (err) { + throw new Error( + 'Error occured while triggering "textGeneration" method', + { + cause: err, + }); + } + } +}); + +module.exports = { custom }; diff --git a/lib/huggingface/utils/index.js b/lib/huggingface/utils/index.js new file mode 100644 index 0000000..83452a7 --- /dev/null +++ b/lib/huggingface/utils/index.js @@ -0,0 +1,10 @@ +'use strict'; + +module.exports = { + ...require('./audio'), + ...require('./custom'), + ...require('./computerVision'), + ...require('./language'), + ...require('./multimodal'), + ...require('./tabular'), +}; diff --git a/lib/huggingface/utils/language.js b/lib/huggingface/utils/language.js new file mode 100644 index 0000000..d6340dc --- /dev/null +++ b/lib/huggingface/utils/language.js @@ -0,0 +1,158 @@ +'use strict'; + +const { callAPI } = require('../../common.js'); +const { DEFAULT_MODELS } = require('../config.json'); + +const defaultModels = DEFAULT_MODELS.language; + +//......Natural Language Processing (language)....... + +const language = (hf) => ({ + // inputs = '[MASK] world!' + async fillMask(inputs, model = defaultModels.fillMask) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.fillMask', args); + return res; + }, + + /* + inputs = `The tower is 324 metres (1,063 ft) tall, about the same height + as an 81-storey building, and the tallest structure in Paris. + Its base is square, measuring 125 metres (410 ft) on each side. + During its construction, the Eiffel Tower surpassed the Washington + Monument to become the tallest`, + model = 'facebook/bart-large-cnn' + */ + async summarization( + inputs, + parameters = { max_length: 100 }, + model = defaultModels.summarization, + ) { + const args = { inputs, parameters, model }; + const res = await callAPI(hf, 'hf.summarization', args); + return res; + }, + + /* + inputs = { + question: 'What is the capital of France?', + context: 'The capital of France is Paris.' + }, + */ + async questionAnswering( + inputs, + model = defaultModels.questionAnswering, + ) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.questionAnswering', args); + return res; + }, + + /* + inputs = { + query: 'How many stars does the transformers repository have?', + table: { + Repository: ['Transformers', 'Datasets', 'Tokenizers'], + Stars: ['36542', '4512', '3934'], + Contributors: ['651', '77', '34'], + 'Programming language': ['Python', 'Python', 'Rust, Python and NodeJS'] + } + }, + */ + async tableQuestionAnswering( + inputs, + model = defaultModels.tableQuestionAnswering, + ) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.tableQuestionAnswering', args); + return res; + }, + + // inputs = 'I like you. I love you.' + async textClassification( + inputs, + model = defaultModels.textClassification, + ) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.textClassification', args); + return res; + }, + + // inputs = 'The answer to the universe is' + async textGeneration(inputs, model = defaultModels.textGeneration) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.textGeneration', args); + return res; + }, + + // inputs = 'repeat "one two three four"' + // parameters = { max_new_tokens: 250 } + async textGenerationStream( + inputs, + parameters = {}, + model = defaultModels.textGenerationStream, + ) { + const args = { inputs, parameters, model }; + const res = await callAPI(hf, 'hf.textGenerationStream', args); + return res; + }, + + // inputs = 'My name is Sarah Jessica Parker but you can call me Jessica' + async tokenClassification( + inputs, + model = defaultModels.tokenClassification, + ) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.tokenClassification', args); + return res; + }, + + // inputs = 'My name is Wolfgang and I live in Amsterdam', + // parameters = {"src_lang": "en_XX", "tgt_lang": "fr_XX"} + async translation( + inputs, parameters = {}, model = defaultModels.translation + ) { + const args = { inputs, parameters, model }; + const res = await callAPI(hf, 'hf.translation', args); + return res; + }, + + /* + inputs = [ + 'Hi, I recently bought a device from your company but it is not working' + + ' as advertised and I would like to get reimbursed!' + ], + parameters = { candidate_labels: ['refund', 'legal', 'faq'] } + */ + async zeroShotClassification( + inputs, + parameters = {}, + model = defaultModels.zeroShotClassification, + ) { + const args = { inputs, parameters, model }; + const res = await callAPI(hf, 'hf.zeroShotClassification', args); + return res; + }, + + /* + inputs = { + source_sentence: 'That is a happy person', + sentences: [ + 'That is a happy dog', + 'That is a very happy person', + 'Today is a sunny day' + ] + } + */ + + async sentenceSimilarity( + inputs, + model = defaultModels.sentenceSimilarity, + ) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.sentenceSimilarity', args); + return res; + } +}); + +module.exports = { language }; diff --git a/lib/huggingface/utils/multimodal.js b/lib/huggingface/utils/multimodal.js new file mode 100644 index 0000000..cd94e86 --- /dev/null +++ b/lib/huggingface/utils/multimodal.js @@ -0,0 +1,53 @@ +'use strict'; + +const { callAPI } = require('../../common.js'); +const { DEFAULT_MODELS } = require('../config.json'); + +const defaultModels = DEFAULT_MODELS.multimodal; + +//......Multimodal....... + +const multimodal = (hf) => ({ + + // inputs = "That is a happy person", + async featureExtraction( + inputs, + model = defaultModels.featureExtraction, + ) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.featureExtraction', args); + return res; + }, + + /* + inputs = { + question: 'How many cats are lying down?', + image: await (await fetch('https://placekitten.com/300/300')).blob() + }, + */ + async visualQuestionAnswering( + inputs, + model = defaultModels.visualQuestionAnswering, + ) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.visualQuestionAnswering', args); + return res; + }, + + /* + inputs = { + question: 'Invoice number?', + image: await (await fetch('https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png')).blob(), + }, + */ + async documentQuestionAnswering( + inputs, + model = defaultModels.documentQuestionAnswering, + ) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.documentQuestionAnswering', args); + return res; + } +}); + +module.exports = { multimodal }; diff --git a/lib/huggingface/utils/tabular.js b/lib/huggingface/utils/tabular.js new file mode 100644 index 0000000..a4f13ef --- /dev/null +++ b/lib/huggingface/utils/tabular.js @@ -0,0 +1,59 @@ +'use strict'; + +const { callAPI } = require('../../common.js'); +const { DEFAULT_MODELS } = require('../config.json'); + +const defaultModels = DEFAULT_MODELS.tabular; + +//......Tabular....... + +const tabular = (hf) => ({ + /* + inputs = { + data: { + "Height": ["11.52", "12.48", "12.3778"], + "Length1": ["23.2", "24", "23.9"], + "Length2": ["25.4", "26.3", "26.5"], + "Length3": ["30", "31.2", "31.1"], + "Species": ["Bream", "Bream", "Bream"], + "Width": ["4.02", "4.3056", "4.6961"] + }, + }, + */ + async tabularRegression( + inputs, + model = defaultModels.tabularRegression, + ) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.tabularRegression', args); + return res; + }, + + /* + inputs = { + data: { + "fixed_acidity": ["7.4", "7.8", "10.3"], + "volatile_acidity": ["0.7", "0.88", "0.32"], + "citric_acid": ["0", "0", "0.45"], + "residual_sugar": ["1.9", "2.6", "6.4"], + "chlorides": ["0.076", "0.098", "0.073"], + "free_sulfur_dioxide": ["11", "25", "5"], + "total_sulfur_dioxide": ["34", "67", "13"], + "density": ["0.9978", "0.9968", "0.9976"], + "pH": ["3.51", "3.2", "3.23"], + "sulphates": ["0.56", "0.68", "0.82"], + "alcohol": ["9.4", "9.8", "12.6"] + }, + }, + */ + async tabularClassification( + inputs, + model = defaultModels.tabularClassification, + ) { + const args = { inputs, model }; + const res = await callAPI(hf, 'hf.tabularClassification', args); + return res; + } +}); + +module.exports = { tabular }; diff --git a/test/huggingface.js b/test/huggingface.js deleted file mode 100644 index 921ecd7..0000000 --- a/test/huggingface.js +++ /dev/null @@ -1,439 +0,0 @@ -'use strict'; - -const test = require('node:test'); -const assert = require('node:assert'); -const path = require('node:path'); -const { readFileSync } = require('node:fs'); -const { huggingface } = require('../lib'); - -const { - FillMask, - Summarization, - QuestionAnswering, - TableQuestionAnswering, - TextClassification, - TextGeneration, - TextGenerationStream, - TokenClassification, - Translation, - ZeroShotClassification, - SentenceSimilarity, - - AutomaticSpeechRecognition, - AudioClassification, - TextToSpeech, - // AudioToAudio, - - ImageClassification, - ObjectDetection, - ImageSegmentation, - ImageToText, - TextToImage, - ImageToImage, - ZeroShotImageClassification, - FeatureExtraction, - - VisualQuestionAnswering, - DocumentQuestionAnswering, - // TabularRegression, - // TabularClassification, - // CustomCall, - // CustomInferenceEndpoint, - CustomCallStreaming, -} = huggingface; - -const FILES = process.cwd() + '/files/huggingface/'; -const AUDIOS = FILES + 'audios'; -const IMAGES = FILES + 'images'; - -const testAudioFile = readFileSync(path.join(AUDIOS, 'speech.mp3')); -const testCatFile = readFileSync(path.join(IMAGES, 'cat.jpg')); -const testSeeFile = readFileSync(path.join(IMAGES, 'see.jpeg')); -const testInvoiceFile = readFileSync(path.join(IMAGES, 'invoice.png')); - -test('HuggingFace Connector', async (t) => { - await t.test('FillMask', async () => { - const masks = await FillMask('[MASK] world!'); - - assert.ok(Array.isArray(masks)); - - for (const mask of masks) { - assert.ok(typeof mask === 'object'); - assert.ok('score' in mask); - assert.ok('sequence' in mask); - assert.ok('token' in mask); - assert.ok('token_str' in mask); - } - }); - - await t.test('Summarization', async () => { - const res = await Summarization( - 'The tower is 324 metres (1,063 ft) tall about the same height as an' + - '81-storey building, and the tallest structure in Paris. Its base is' + - 'square, measuring 125 metres (410 ft) on each side. During its' + - 'construction, the Eiffel Tower surpassed the Washington Monument to' + - 'become the tallest', - ); - - assert.ok(typeof res === 'object'); - assert.ok('summary_text' in res); - }); - - await t.test('QuestionAnswering', async () => { - const res = await QuestionAnswering({ - question: 'What is the capital of France?', - context: 'The capital of France is Paris.', - }); - - assert.ok(typeof res === 'object'); - assert.ok('score' in res); - assert.ok('start' in res); - assert.ok('end' in res); - assert.ok('answer' in res); - assert.strictEqual(res.answer, 'Paris'); - }); - - await t.test('TableQuestionAnswering', async () => { - const res = await TableQuestionAnswering({ - query: 'How many stars does the transformers repository have?', - table: { - Repository: ['Transformers', 'Datasets', 'Tokenizers'], - Stars: ['36542', '4512', '3934'], - Contributors: ['651', '77', '34'], - 'Programming language': ['Python', 'Python', 'Rust, Python and NodeJS'], - }, - }); - - assert.ok(typeof res === 'object'); - assert.ok('answer' in res); - assert.ok('coordinates' in res); - assert.ok('cells' in res); - assert.ok('aggregator' in res); - assert.deepEqual(res.coordinates, [[0, 1]]); - assert.strictEqual(res.aggregator, 'AVERAGE'); - }); - - await t.test('TextClassification', async () => { - const res = await TextClassification('I like you. I love you.'); - - assert.ok(Array.isArray(res)); - for (const item of res) { - assert.ok('label' in item); - assert.ok('score' in item); - } - }); - - await t.test('TextGeneration', async () => { - const res = await TextGeneration('The answer to the universe is'); - - assert.ok(typeof res === 'object'); - assert.ok('generated_text' in res); - }); - - await t.test('TextGenerationStream', async () => { - const res = await TextGenerationStream('repeat "one two three four"', { - max_new_tokens: 250, - }); - - assert.ok(typeof res === 'object'); - }); - - await t.test('TokenClassification', async () => { - const res = await TokenClassification( - 'My name is Sarah Jessica Parker but you can call me Jessica', - ); - - // console.log(res); - assert.ok(Array.isArray(res)); - - for (const item of res) { - assert.ok(typeof item === 'object'); - assert.ok('start' in item); - assert.ok('end' in item); - assert.ok('entity_group' in item); - assert.ok('score' in item); - assert.ok('word' in item); - } - }); - - await t.test('Translation', async () => { - const res = await Translation( - 'My name is Wolfgang and I live in Amsterdam', - { src_lang: 'en_XX', tgt_lang: 'fr_XX' }, - ); - - assert.ok(typeof res === 'object'); - assert.ok('translation_text' in res); - assert.ok(typeof res.translation_text === 'string'); - }); - - await t.test('ZeroShotClassification', async () => { - const res = await ZeroShotClassification( - [ - 'Hi, I recently bought a device from your company but' + - ' it is not working as advertised and I would like to' + - ' get reimbursed!', - ], - { - candidate_labels: ['refund', 'legal', 'faq'], - }, - ); - - assert.ok(Array.isArray(res)); - assert.ok(res.length === 1); - - const [item] = res; - - assert.ok(typeof item === 'object'); - assert.ok('sequence' in item); - assert.ok('labels' in item); - assert.ok('scores' in item); - assert.ok(typeof item.sequence === 'string'); - assert.ok(Array.isArray(item.labels)); - assert.ok(Array.isArray(item.scores)); - }); - - await t.test('SentenceSimilarity', async () => { - const res = await SentenceSimilarity({ - source_sentence: 'That is a happy person', - sentences: [ - 'That is a happy dog', - 'That is a very happy person', - 'Today is a sunny day', - ], - }); - - assert.ok(Array.isArray(res)); - assert.ok(res.every((item) => typeof item === 'number')); - }); - - await t.test('AutomaticSpeechRecognition', async () => { - const res = await AutomaticSpeechRecognition(testAudioFile); - - assert.ok(typeof res === 'object'); - assert.ok('text' in res); - assert.ok(typeof res.text === 'string'); - }); - - await t.test('AudioClassification', async () => { - const res = await AudioClassification(testAudioFile); - - assert.ok(Array.isArray(res)); - - for (const item of res) { - assert.ok(typeof item === 'object'); - assert.ok('label' in item); - assert.ok('score' in item); - } - }); - - await t.test('TextToSpeech', async () => { - const res = await TextToSpeech('Hello world!'); - - assert.ok(res instanceof Blob); - assert.ok(typeof res.size === 'number'); - assert.strictEqual(res.type, 'audio/flac'); - }); - - // TODO: fix test, getting an error "interface not in config.json" - // test.skip('AudioToAudio', async () => { - // const res = await AudioToAudio(testAudioFile); - // - // console.log(res); - // }); - - await t.test('ImageClassification', async () => { - const res = await ImageClassification(testCatFile); - - assert.ok(Array.isArray(res)); - - for (const item of res) { - assert.ok(typeof item === 'object'); - assert.ok('label' in item); - assert.ok('score' in item); - } - }); - - await t.test('ObjectDetection', async () => { - const res = await ObjectDetection(testCatFile); - - assert.ok(Array.isArray(res)); - assert.ok(res.length === 1); - - const [item] = res; - - assert.ok(typeof item === 'object'); - assert.ok('box' in item); - assert.ok('label' in item); - assert.ok('score' in item); - assert.ok(typeof item.box === 'object'); - assert.ok(typeof item.label === 'string'); - assert.ok(typeof item.score === 'number'); - assert.strictEqual(item.label, 'cat'); - - const { box } = item; - - assert.ok(typeof box === 'object'); - assert.ok(typeof box.xmin === 'number'); - assert.ok(typeof box.ymin === 'number'); - assert.ok(typeof box.xmax === 'number'); - assert.ok(typeof box.ymax === 'number'); - }); - - await t.test('ImageSegmentation', async () => { - const res = await ImageSegmentation(testCatFile); - - assert.ok(Array.isArray(res)); - - for (const item of res) { - assert.ok(typeof item === 'object'); - assert.ok('score' in item); - assert.ok('label' in item); - assert.ok('mask' in item); - } - }); - - await t.test('ImageToText', async () => { - const res = await ImageToText(new Blob([testSeeFile])); - - assert.ok(typeof res === 'object'); - assert.ok('generated_text' in res); - assert.ok(typeof res.generated_text === 'string'); - }); - - await t.test('TextToImage', async () => { - const inputs = - 'award winning high resolution photo of a giant tortoise' + - '/((ladybird)) hybrid, [trending on artstation]'; - const res = await TextToImage(inputs, { negative_prompt: 'blurry' }); - - assert.ok(res instanceof Blob); - assert.ok(typeof res.size === 'number'); - assert.ok(res.type === 'image/jpeg'); - }); - - await t.test('ImageToImage', async () => { - const res = await ImageToImage(new Blob([testSeeFile]), { - prompt: 'test picture', - }); - - assert.ok(res instanceof Blob); - assert.ok(typeof res.size === 'number'); - assert.ok(res.type === 'image/jpeg'); - }); - - await t.test('ZeroShotImageClassification', async () => { - const inputs = { image: new Blob([testCatFile]) }; - const res = await ZeroShotImageClassification(inputs, { - candidate_labels: ['cat', 'dog'], - }); - - assert.ok(Array.isArray(res)); - - for (const item of res) { - assert.ok(typeof item === 'object'); - assert.ok('score' in item); - assert.ok('label' in item); - } - }); - - await t.test('FeatureExtraction', async () => { - const res = await FeatureExtraction('That is a happy person'); - - assert.ok(Array.isArray(res)); - assert.ok(res.every((el) => typeof el === 'number')); - }); - - await t.test('VisualQuestionAnswering', async () => { - const inputs = { - question: 'How many cats are lying down?', - image: new Blob([testCatFile]), - }; - const res = await VisualQuestionAnswering(inputs); - - assert.ok(typeof res === 'object'); - assert.ok('score' in res); - assert.ok('answer' in res); - assert.ok(typeof res.score === 'number'); - assert.strictEqual(res.answer, '1'); - }); - - await t.test('DocumentQuestionAnswering', async () => { - const inputs = { - question: 'Invoice number?', - image: new Blob([testInvoiceFile]), - }; - const res = await DocumentQuestionAnswering(inputs); - - assert.ok(typeof res === 'object'); - assert.ok('score' in res && typeof res.score === 'number'); - assert.ok('start' in res && typeof res.start === 'number'); - assert.ok('end' in res && typeof res.end === 'number'); - assert.ok('answer' in res && typeof res.answer === 'string'); - assert.strictEqual(res.answer, 'us-001'); - }); - - // TODO: fix test, timeout - // test.skip('TabularRegression', async () => { - // const inputs = { - // data: { - // Height: ['11.52', '12.48', '12.3778'], - // Length1: ['23.2', '24', '23.9'], - // Length2: ['25.4', '26.3', '26.5'], - // Length3: ['30', '31.2', '31.1'], - // Species: ['Bream', 'Bream', 'Bream'], - // Width: ['4.02', '4.3056', '4.6961'], - // }, - // }; - // const res = await TabularRegression(inputs); - // - // console.log(res); - // }, 60000); - - // TODO: fix test, timeout - // test.skip('TabularClassification', async () => { - // const inputs = { - // data: { - // fixed_acidity: ['7.4', '7.8', '10.3'], - // volatile_acidity: ['0.7', '0.88', '0.32'], - // citric_acid: ['0', '0', '0.45'], - // residual_sugar: ['1.9', '2.6', '6.4'], - // chlorides: ['0.076', '0.098', '0.073'], - // free_sulfur_dioxide: ['11', '25', '5'], - // total_sulfur_dioxide: ['34', '67', '13'], - // density: ['0.9978', '0.9968', '0.9976'], - // pH: ['3.51', '3.2', '3.23'], - // sulphates: ['0.56', '0.68', '0.82'], - // alcohol: ['9.4', '9.8', '12.6'], - // }, - // }; - // const res = await TabularClassification(inputs); - // - // console.log(res); - // }, 60000); - - // TODO: fix test, response is undefined for some reason - // test.skip('CustomCall', async () => { - // const res = await CustomCall('hello world'); - // - // console.log(res); - // }); - - await t.test('CustomCallStreaming', async () => { - const res = await CustomCallStreaming('hello world'); - - assert.ok(typeof res === 'object'); - }); - - // TODO: To test this one we need to have own inference endpoint - // test.skip('CustomInferenceEndpoint', async () => { - // const endpoint = - // 'https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2'; - // const res = await CustomInferenceEndpoint( - // 'The answer to the universe is', - // endpoint, - // ); - // - // console.log(res); - // }); -}); diff --git a/test/huggingface/audio.js b/test/huggingface/audio.js new file mode 100644 index 0000000..5ce7bfe --- /dev/null +++ b/test/huggingface/audio.js @@ -0,0 +1,66 @@ +'use strict'; + +const { beforeEach, it, describe } = require('node:test'); +const assert = require('node:assert'); +const path = require('node:path'); +const { readFileSync } = require('node:fs'); + +const { huggingface } = require('../../lib'); +const utils = require('../../lib/huggingface/utils'); + +const { Chat } = huggingface; + +const API_KEY = process.env.HUGGINGFACE_API_KEY; + +const { + audio: uAudio, +} = utils; + +const FILES = process.cwd() + '/files/huggingface/'; +const AUDIOS = FILES + 'audios'; + +const testAudioFile = readFileSync(path.join(AUDIOS, 'speech.mp3')); + +describe('audio', () => { + let audio; + beforeEach(() => { + const chat = new Chat({ apiKey: API_KEY }); + audio = uAudio(chat.hf); + }); + + it('automaticSpeechRecognition', async () => { + const res = await audio.automaticSpeechRecognition(testAudioFile); + + console.log(res); + assert.ok(typeof res === 'object'); + assert.ok('text' in res); + assert.ok(typeof res.text === 'string'); + }); + + it('audioClassification', async () => { + const res = await audio.audioClassification(testAudioFile); + + assert.ok(Array.isArray(res)); + + for (const item of res) { + assert.ok(typeof item === 'object'); + assert.ok('label' in item); + assert.ok('score' in item); + } + }); + + it('textToSpeech', async () => { + const res = await audio.textToSpeech('Hello world!'); + + assert.ok(res instanceof Blob); + assert.ok(typeof res.size === 'number'); + assert.strictEqual(res.type, 'audio/flac'); + }); + + // TODO: fix test, getting an error "interface not in config.json" + it.skip('audioToAudio', async () => { + const res = await audio.audioToAudio(testAudioFile); + + console.log(res); + }); +}); diff --git a/test/huggingface/computerVision.js b/test/huggingface/computerVision.js new file mode 100644 index 0000000..50ed7f1 --- /dev/null +++ b/test/huggingface/computerVision.js @@ -0,0 +1,129 @@ +'use strict'; + +const { beforeEach, it, describe } = require('node:test'); +const assert = require('node:assert'); +const path = require('node:path'); +const { readFileSync } = require('node:fs'); + +const { huggingface } = require('../../lib'); +const utils = require('../../lib/huggingface/utils'); + +const { Chat } = huggingface; + +const API_KEY = process.env.HUGGINGFACE_API_KEY; + +const { + computerVision: uComputerVision, +} = utils; + +const FILES = process.cwd() + '/files/huggingface/'; +const IMAGES = FILES + 'images'; + +const testCatFile = readFileSync(path.join(IMAGES, 'cat.jpg')); +const testSeeFile = readFileSync(path.join(IMAGES, 'see.jpeg')); + +describe('computerVision', () => { + let computerVision; + + beforeEach(() => { + const chat = new Chat({ apiKey: API_KEY }); + computerVision = uComputerVision(chat.hf); + }); + + it('imageClassification', async () => { + const res = await computerVision.imageClassification(testCatFile); + + assert.ok(Array.isArray(res)); + + for (const item of res) { + assert.ok(typeof item === 'object'); + assert.ok('label' in item); + assert.ok('score' in item); + } + }); + + it('objectDetection', async () => { + const res = await computerVision.objectDetection(testCatFile); + + assert.ok(Array.isArray(res)); + assert.ok(res.length === 1); + + const [item] = res; + + assert.ok(typeof item === 'object'); + assert.ok('box' in item); + assert.ok('label' in item); + assert.ok('score' in item); + assert.ok(typeof item.box === 'object'); + assert.ok(typeof item.label === 'string'); + assert.ok(typeof item.score === 'number'); + assert.strictEqual(item.label, 'cat'); + + const { box } = item; + + assert.ok(typeof box === 'object'); + assert.ok(typeof box.xmin === 'number'); + assert.ok(typeof box.ymin === 'number'); + assert.ok(typeof box.xmax === 'number'); + assert.ok(typeof box.ymax === 'number'); + }); + + it('imageSegmentation', async () => { + const res = await computerVision.imageSegmentation(testCatFile); + + assert.ok(Array.isArray(res)); + + for (const item of res) { + assert.ok(typeof item === 'object'); + assert.ok('score' in item); + assert.ok('label' in item); + assert.ok('mask' in item); + } + }); + + it('imageToText', async () => { + const res = await computerVision.imageToText(testCatFile); + + assert.ok(typeof res === 'object'); + assert.ok('generated_text' in res); + assert.ok(typeof res.generated_text === 'string'); + }); + + it('textToImage', async () => { + const inputs = + 'award winning high resolution photo of a giant tortoise' + + '/((ladybird)) hybrid, [trending on artstation]'; + const res = await computerVision.textToImage(inputs, { + negative_prompt: 'blurry' + }); + + assert.ok(res instanceof Blob); + assert.ok(typeof res.size === 'number'); + assert.ok(res.type === 'image/jpeg'); + }); + + it('imageToImage', async () => { + const res = await computerVision.imageToImage(new Blob([testSeeFile]), { + prompt: 'test picture', + }); + + assert.ok(res instanceof Blob); + assert.ok(typeof res.size === 'number'); + assert.ok(res.type === 'image/jpeg'); + }); + + it('ZeroShotImageClassification', async () => { + const inputs = { image: new Blob([testCatFile]) }; + const res = await computerVision.zeroShotImageClassification(inputs, { + candidate_labels: ['cat', 'dog'], + }); + + assert.ok(Array.isArray(res)); + + for (const item of res) { + assert.ok(typeof item === 'object'); + assert.ok('score' in item); + assert.ok('label' in item); + } + }); +}); diff --git a/test/huggingface/custom.js b/test/huggingface/custom.js new file mode 100644 index 0000000..82f170a --- /dev/null +++ b/test/huggingface/custom.js @@ -0,0 +1,49 @@ +'use strict'; + +const { beforeEach, it, describe } = require('node:test'); +const assert = require('node:assert'); + +const { huggingface } = require('../../lib'); +const utils = require('../../lib/huggingface/utils'); + +const { Chat } = huggingface; + +const API_KEY = process.env.HUGGINGFACE_API_KEY; + +const { + custom: uCustom, +} = utils; + +describe('custom', () => { + let custom; + + beforeEach(() => { + const chat = new Chat({ apiKey: API_KEY }); + custom = uCustom(chat.hf); + }); + + // TODO: fix test, response is undefined for some reason + it.skip('customCall', async () => { + const res = await custom.customCall('hello world'); + + console.log(res); + }); + + it('customCallStreaming', async () => { + const res = await custom.customCallStreaming('hello world'); + + assert.ok(typeof res === 'object'); + }); + + // TODO: To test this one we need to have own inference endpoint + it.skip('customInferenceEndpoint', async () => { + const endpoint = + 'https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2'; + const res = await custom.customInferenceEndpoint( + 'The answer to the universe is', + endpoint, + ); + + console.log(res); + }); +}); diff --git a/test/huggingface/language.js b/test/huggingface/language.js new file mode 100644 index 0000000..8e8b168 --- /dev/null +++ b/test/huggingface/language.js @@ -0,0 +1,181 @@ +'use strict'; + +const { beforeEach, it, describe } = require('node:test'); +const assert = require('node:assert'); + +const { huggingface } = require('../../lib'); +const utils = require('../../lib/huggingface/utils'); + +const { Chat } = huggingface; + +const API_KEY = process.env.HUGGINGFACE_API_KEY; + +const { + language: uLanguage, +} = utils; + +describe('language', () => { + let language; + + beforeEach(() => { + const chat = new Chat({ apiKey: API_KEY }); + language = uLanguage(chat.hf); + }); + + it('fillMask', async () => { + const masks = await language.fillMask('[MASK] world!'); + + assert.ok(Array.isArray(masks)); + + for (const mask of masks) { + assert.ok(typeof mask === 'object'); + assert.ok('score' in mask); + assert.ok('sequence' in mask); + assert.ok('token' in mask); + assert.ok('token_str' in mask); + } + }); + + it('summarization', async () => { + const res = await language.summarization( + 'The tower is 324 metres (1,063 ft) tall about the same height as an' + + '81-storey building, and the tallest structure in Paris. Its base is' + + 'square, measuring 125 metres (410 ft) on each side. During its' + + 'construction, the Eiffel Tower surpassed the Washington Monument to' + + 'become the tallest', + ); + + assert.ok(typeof res === 'object'); + assert.ok('summary_text' in res); + }); + + it('questionAnswering', async () => { + const res = await language.questionAnswering({ + question: 'What is the capital of France?', + context: 'The capital of France is Paris.', + }); + + assert.ok(typeof res === 'object'); + assert.ok('score' in res); + assert.ok('start' in res); + assert.ok('end' in res); + assert.ok('answer' in res); + assert.strictEqual(res.answer, 'Paris'); + }); + + it('tableQuestionAnswering', async () => { + const res = await language.tableQuestionAnswering({ + query: 'How many stars does the transformers repository have?', + table: { + Repository: ['Transformers', 'Datasets', 'Tokenizers'], + Stars: ['36542', '4512', '3934'], + Contributors: ['651', '77', '34'], + 'Programming language': ['Python', 'Python', 'Rust, Python and NodeJS'], + }, + }); + + assert.ok(typeof res === 'object'); + assert.ok('answer' in res); + assert.ok('coordinates' in res); + assert.ok('cells' in res); + assert.ok('aggregator' in res); + assert.deepEqual(res.coordinates, [[0, 1]]); + assert.strictEqual(res.aggregator, 'AVERAGE'); + }); + + it('textClassification', async () => { + const res = await language.textClassification('I like you. I love you.'); + + assert.ok(Array.isArray(res)); + for (const item of res) { + assert.ok('label' in item); + assert.ok('score' in item); + } + }); + + it('textGeneration', async () => { + const input = 'The answer to the universe is'; + const res = await language.textGeneration(input); + + assert.ok(typeof res === 'object'); + assert.ok('generated_text' in res); + }); + + it('textGenerationStream', async () => { + const input = 'repeat "one two three four"'; + const res = await language.textGenerationStream(input, { + max_new_tokens: 250, + }); + + assert.ok(typeof res === 'object'); + }); + + it('tokenClassification', async () => { + const res = await language.tokenClassification( + 'My name is Sarah Jessica Parker but you can call me Jessica', + ); + + // console.log(res); + assert.ok(Array.isArray(res)); + + for (const item of res) { + assert.ok(typeof item === 'object'); + assert.ok('start' in item); + assert.ok('end' in item); + assert.ok('entity_group' in item); + assert.ok('score' in item); + assert.ok('word' in item); + } + }); + + it('translation', async () => { + const res = await language.translation( + 'My name is Wolfgang and I live in Amsterdam', + { src_lang: 'en_XX', tgt_lang: 'fr_XX' }, + ); + + assert.ok(typeof res === 'object'); + assert.ok('translation_text' in res); + assert.ok(typeof res.translation_text === 'string'); + }); + + it('zeroShotClassification', async () => { + const res = await language.zeroShotClassification( + [ + 'Hi, I recently bought a device from your company but' + + ' it is not working as advertised and I would like to' + + ' get reimbursed!', + ], + { + candidate_labels: ['refund', 'legal', 'faq'], + }, + ); + + assert.ok(Array.isArray(res)); + assert.ok(res.length === 1); + + const [item] = res; + + assert.ok(typeof item === 'object'); + assert.ok('sequence' in item); + assert.ok('labels' in item); + assert.ok('scores' in item); + assert.ok(typeof item.sequence === 'string'); + assert.ok(Array.isArray(item.labels)); + assert.ok(Array.isArray(item.scores)); + }); + + it('sentenceSimilarity', async () => { + const res = await language.sentenceSimilarity({ + source_sentence: 'That is a happy person', + sentences: [ + 'That is a happy dog', + 'That is a very happy person', + 'Today is a sunny day', + ], + }); + + assert.ok(Array.isArray(res)); + assert.ok(res.every((item) => typeof item === 'number')); + }); +}); diff --git a/test/huggingface/multimodal.js b/test/huggingface/multimodal.js new file mode 100644 index 0000000..85212dc --- /dev/null +++ b/test/huggingface/multimodal.js @@ -0,0 +1,68 @@ +'use strict'; + +const { beforeEach, it, describe } = require('node:test'); +const assert = require('node:assert'); +const path = require('node:path'); +const { readFileSync } = require('node:fs'); + +const { huggingface } = require('../../lib'); +const utils = require('../../lib/huggingface/utils'); + +const { Chat } = huggingface; + +const API_KEY = process.env.HUGGINGFACE_API_KEY; + +const { + multimodal: uMultimodal, +} = utils; + +const FILES = process.cwd() + '/files/huggingface/'; +const IMAGES = FILES + 'images'; + +const testCatFile = readFileSync(path.join(IMAGES, 'cat.jpg')); +const testInvoiceFile = readFileSync(path.join(IMAGES, 'invoice.png')); + +describe('multimodal', () => { + let multimodal; + + beforeEach(() => { + const chat = new Chat({ apiKey: API_KEY }); + multimodal = uMultimodal(chat.hf); + }); + + it('featureExtraction', async () => { + const res = await multimodal.featureExtraction('That is a happy person'); + + assert.ok(Array.isArray(res)); + assert.ok(res.every((el) => typeof el === 'number')); + }); + + it('visualQuestionAnswering', async () => { + const inputs = { + question: 'How many cats are lying down?', + image: new Blob([testCatFile]), + }; + const res = await multimodal.visualQuestionAnswering(inputs); + + assert.ok(typeof res === 'object'); + assert.ok('score' in res); + assert.ok('answer' in res); + assert.ok(typeof res.score === 'number'); + assert.strictEqual(res.answer, '1'); + }); + + it('documentQuestionAnswering', async () => { + const inputs = { + question: 'Invoice number?', + image: new Blob([testInvoiceFile]), + }; + const res = await multimodal.documentQuestionAnswering(inputs); + + assert.ok(typeof res === 'object'); + assert.ok('score' in res && typeof res.score === 'number'); + assert.ok('start' in res && typeof res.start === 'number'); + assert.ok('end' in res && typeof res.end === 'number'); + assert.ok('answer' in res && typeof res.answer === 'string'); + assert.strictEqual(res.answer, 'us-001'); + }); +}); diff --git a/test/huggingface/tabular.js b/test/huggingface/tabular.js new file mode 100644 index 0000000..f5110b0 --- /dev/null +++ b/test/huggingface/tabular.js @@ -0,0 +1,62 @@ +'use strict'; + +const { beforeEach, it, describe } = require('node:test'); + +const { huggingface } = require('../../lib'); +const utils = require('../../lib/huggingface/utils'); + +const { Chat } = huggingface; + +const API_KEY = process.env.HUGGINGFACE_API_KEY; + +const { + tabular: uTabular, +} = utils; + +describe('tabular', () => { + let tabular; + + beforeEach(() => { + const chat = new Chat({ apiKey: API_KEY }); + tabular = uTabular(chat.hf); + }); + + // TODO: fix test, timeout + it.skip('TabularRegression', async () => { + const inputs = { + data: { + Height: ['11.52', '12.48', '12.3778'], + Length1: ['23.2', '24', '23.9'], + Length2: ['25.4', '26.3', '26.5'], + Length3: ['30', '31.2', '31.1'], + Species: ['Bream', 'Bream', 'Bream'], + Width: ['4.02', '4.3056', '4.6961'], + }, + }; + const res = await tabular.tabularRegression(inputs); + + console.log(res); + }, 60000); + + // TODO: fix test, timeout + it.skip('TabularClassification', async () => { + const inputs = { + data: { + fixed_acidity: ['7.4', '7.8', '10.3'], + volatile_acidity: ['0.7', '0.88', '0.32'], + citric_acid: ['0', '0', '0.45'], + residual_sugar: ['1.9', '2.6', '6.4'], + chlorides: ['0.076', '0.098', '0.073'], + free_sulfur_dioxide: ['11', '25', '5'], + total_sulfur_dioxide: ['34', '67', '13'], + density: ['0.9978', '0.9968', '0.9976'], + pH: ['3.51', '3.2', '3.23'], + sulphates: ['0.56', '0.68', '0.82'], + alcohol: ['9.4', '9.8', '12.6'], + }, + }; + const res = await tabular.TabularClassification(inputs); + + console.log(res); + }, 60000); +});