-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #179 from yandex-cloud/foundation-models-sdk
[WIP] [DO NOT MERGE] Foundation Models Sdk
- Loading branch information
Showing
15 changed files
with
865 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
export * from './fileSdk'; | ||
export * from '..'; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import { Client } from 'nice-grpc'; | ||
import { embeddingService } from '..'; | ||
import { | ||
EmbeddingsServiceService, | ||
TextEmbeddingRequest, | ||
} from '../generated/yandex/cloud/ai/foundation_models/v1/embedding/embedding_service'; | ||
import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types'; | ||
|
||
export type TextEmbeddingProps = Omit<TypeFromProtoc<TextEmbeddingRequest, 'text'>, 'modelUri'> & { | ||
modelId: string; | ||
folderId: string; | ||
}; | ||
|
||
export class EmbeddingSdk { | ||
private embeddingClient: Client<typeof EmbeddingsServiceService, ClientCallArgs>; | ||
|
||
static ENDPOINT = 'llm.api.cloud.yandex.net:443'; | ||
|
||
constructor(session: SessionArg, endpoint = EmbeddingSdk.ENDPOINT) { | ||
this.embeddingClient = session.client(embeddingService.EmbeddingsServiceClient, endpoint); | ||
} | ||
|
||
textEmbedding(params: TextEmbeddingProps, args?: ClientCallArgs) { | ||
const { modelId, folderId, ...restParams } = params; | ||
const modelUri = `gpt://${folderId}/${modelId}`; | ||
|
||
return this.embeddingClient.textEmbedding( | ||
embeddingService.TextEmbeddingRequest.fromPartial({ ...restParams, modelUri }), | ||
args, | ||
); | ||
} | ||
} | ||
|
||
export const initEmbeddingSdk = (session: SessionArg, endpoint = EmbeddingSdk.ENDPOINT) => { | ||
return new EmbeddingSdk(session, endpoint); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import { Client } from 'nice-grpc'; | ||
import { imageGenerationService } from '..'; | ||
import { | ||
ImageGenerationAsyncServiceService, | ||
ImageGenerationRequest, | ||
} from '../generated/yandex/cloud/ai/foundation_models/v1/image_generation/image_generation_service'; | ||
import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types'; | ||
|
||
export type GenerateImageProps = Omit< | ||
TypeFromProtoc<ImageGenerationRequest, 'messages'>, | ||
'modelUri' | ||
> & { | ||
modelId: string; | ||
folderId: string; | ||
}; | ||
|
||
export class ImageGenerationSdk { | ||
private imageGenerationClient: Client< | ||
typeof ImageGenerationAsyncServiceService, | ||
ClientCallArgs | ||
>; | ||
|
||
static ENDPOINT = 'llm.api.cloud.yandex.net:443'; | ||
|
||
constructor(session: SessionArg, endpoint = ImageGenerationSdk.ENDPOINT) { | ||
this.imageGenerationClient = session.client( | ||
imageGenerationService.ImageGenerationAsyncServiceClient, | ||
endpoint, | ||
); | ||
} | ||
|
||
generateImage(params: GenerateImageProps, args?: ClientCallArgs) { | ||
const { modelId, folderId, ...restParams } = params; | ||
const modelUri = `art://${folderId}/${modelId}`; | ||
|
||
return this.imageGenerationClient.generate( | ||
imageGenerationService.ImageGenerationRequest.fromPartial({ ...restParams, modelUri }), | ||
args, | ||
); | ||
} | ||
} | ||
|
||
export const initImageGenerationSdk = ( | ||
session: SessionArg, | ||
endpoint = ImageGenerationSdk.ENDPOINT, | ||
) => { | ||
return new ImageGenerationSdk(session, endpoint); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
export * from './embeddingSdk'; | ||
export * from './imageGenerationSdk'; | ||
export * from './textClassificationSdk'; | ||
export * from './textGenerationSdk'; | ||
export * from '..'; |
73 changes: 73 additions & 0 deletions
73
clients/ai-foundation_models-v1/sdk/textClassificationSdk.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import { Client } from 'nice-grpc'; | ||
import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types'; | ||
import { | ||
FewShotTextClassificationRequest, | ||
TextClassificationRequest, | ||
TextClassificationServiceService, | ||
} from '../generated/yandex/cloud/ai/foundation_models/v1/text_classification/text_classification_service'; | ||
import { textClassificationService } from '..'; | ||
|
||
export type TextClassificationProps = Omit< | ||
TypeFromProtoc<TextClassificationRequest, 'text'>, | ||
'modelUri' | ||
> & { | ||
modelId: string; | ||
folderId: string; | ||
}; | ||
|
||
export type FewShotTextClassificationProps = Omit< | ||
TypeFromProtoc<FewShotTextClassificationRequest, 'text'>, | ||
'modelUri' | ||
> & { | ||
modelId: string; | ||
folderId: string; | ||
}; | ||
|
||
export class TextClassificationSdk { | ||
private textClassificationClient: Client< | ||
typeof TextClassificationServiceService, | ||
ClientCallArgs | ||
>; | ||
|
||
static ENDPOINT = 'llm.api.cloud.yandex.net:443'; | ||
|
||
constructor(session: SessionArg, endpoint = TextClassificationSdk.ENDPOINT) { | ||
this.textClassificationClient = session.client( | ||
textClassificationService.TextClassificationServiceClient, | ||
endpoint, | ||
); | ||
} | ||
|
||
classifyText(params: TextClassificationProps, args?: ClientCallArgs) { | ||
const { modelId, folderId, ...restParams } = params; | ||
const modelUri = `gpt://${folderId}/${modelId}`; | ||
|
||
return this.textClassificationClient.classify( | ||
textClassificationService.TextClassificationRequest.fromPartial({ | ||
...restParams, | ||
modelUri, | ||
}), | ||
args, | ||
); | ||
} | ||
|
||
classifyTextFewShort(params: FewShotTextClassificationProps, args?: ClientCallArgs) { | ||
const { modelId, folderId, ...restParams } = params; | ||
const modelUri = `gpt://${folderId}/${modelId}`; | ||
|
||
return this.textClassificationClient.fewShotClassify( | ||
textClassificationService.FewShotTextClassificationRequest.fromPartial({ | ||
...restParams, | ||
modelUri, | ||
}), | ||
args, | ||
); | ||
} | ||
} | ||
|
||
export const initTextClassificationSdk = ( | ||
session: SessionArg, | ||
endpoint = TextClassificationSdk.ENDPOINT, | ||
) => { | ||
return new TextClassificationSdk(session, endpoint); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import { Client } from 'nice-grpc'; | ||
import { textGenerationService } from '..'; | ||
|
||
import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types'; | ||
import { | ||
CompletionRequest, | ||
TextGenerationAsyncServiceService, | ||
TextGenerationServiceService, | ||
TokenizeRequest, | ||
TokenizerServiceService, | ||
} from '../generated/yandex/cloud/ai/foundation_models/v1/text_generation/text_generation_service'; | ||
|
||
export type CompletionProps = Omit<TypeFromProtoc<CompletionRequest, 'messages'>, 'modelUri'> & { | ||
modelId: string; | ||
folderId: string; | ||
}; | ||
|
||
export type TokenizeProps = Omit<TypeFromProtoc<TokenizeRequest, 'text'>, 'modelUri'> & { | ||
modelId: string; | ||
folderId: string; | ||
}; | ||
|
||
export class TextGenerationSdk { | ||
private textGenerationClient: Client<typeof TextGenerationServiceService, ClientCallArgs>; | ||
private tokenizerClient: Client<typeof TokenizerServiceService, ClientCallArgs>; | ||
private textGenerationAsyncClient: Client< | ||
typeof TextGenerationAsyncServiceService, | ||
ClientCallArgs | ||
>; | ||
|
||
static ENDPOINT = 'llm.api.cloud.yandex.net:443'; | ||
|
||
constructor(session: SessionArg, endpoint = TextGenerationSdk.ENDPOINT) { | ||
this.textGenerationClient = session.client( | ||
textGenerationService.TextGenerationServiceClient, | ||
endpoint, | ||
); | ||
|
||
this.tokenizerClient = session.client( | ||
textGenerationService.TokenizerServiceClient, | ||
endpoint, | ||
); | ||
|
||
this.textGenerationAsyncClient = session.client( | ||
textGenerationService.TextGenerationAsyncServiceClient, | ||
endpoint, | ||
); | ||
} | ||
|
||
tokenize(params: TokenizeProps, args?: ClientCallArgs) { | ||
const { modelId, folderId, ...restParams } = params; | ||
const modelUri = `gpt://${folderId}/${modelId}`; | ||
|
||
return this.tokenizerClient.tokenize( | ||
textGenerationService.TokenizeRequest.fromPartial({ ...restParams, modelUri }), | ||
args, | ||
); | ||
} | ||
|
||
tokenizeCompletion(params: CompletionProps, args?: ClientCallArgs) { | ||
const { modelId, folderId, ...restParams } = params; | ||
const modelUri = `gpt://${folderId}/${modelId}`; | ||
|
||
return this.tokenizerClient.tokenizeCompletion( | ||
textGenerationService.CompletionRequest.fromPartial({ ...restParams, modelUri }), | ||
args, | ||
); | ||
} | ||
|
||
completion(params: CompletionProps, args?: ClientCallArgs) { | ||
const { modelId, folderId, ...restParams } = params; | ||
const modelUri = `gpt://${folderId}/${modelId}`; | ||
|
||
return this.textGenerationClient.completion( | ||
textGenerationService.CompletionRequest.fromPartial({ ...restParams, modelUri }), | ||
args, | ||
); | ||
} | ||
|
||
completionAsOperation(params: CompletionProps, args?: ClientCallArgs) { | ||
const { modelId, folderId, ...restParams } = params; | ||
const modelUri = `gpt://${folderId}/${modelId}`; | ||
|
||
const operationP = this.textGenerationAsyncClient.completion( | ||
textGenerationService.CompletionRequest.fromPartial({ ...restParams, modelUri }), | ||
args, | ||
); | ||
|
||
return operationP; | ||
} | ||
} | ||
|
||
export const initTextGenerationSdk = ( | ||
session: SessionArg, | ||
endpoint = TextGenerationSdk.ENDPOINT, | ||
) => { | ||
return new TextGenerationSdk(session, endpoint); | ||
}; |
Oops, something went wrong.