Merge pull request #179 from yandex-cloud/foundation-models-sdk
[WIP] [DO NOT MERGE] Foundation Models Sdk
GermanVor authored Jan 1, 2025
2 parents f982538 + 54bb7c7 commit 2d8b293
Showing 15 changed files with 865 additions and 2 deletions.
1 change: 1 addition & 0 deletions clients/ai-assistants-v1/sdk/index.ts
@@ -5,3 +5,4 @@ export * from './threadSdk';
export * from './searchIndexFileSdk';
export * from './searchIndexSdk';
export * from './userSdk';
export * from '..';
1 change: 1 addition & 0 deletions clients/ai-files-v1/sdk/index.ts
@@ -1 +1,2 @@
export * from './fileSdk';
export * from '..';
36 changes: 36 additions & 0 deletions clients/ai-foundation_models-v1/sdk/embeddingSdk.ts
@@ -0,0 +1,36 @@
import { Client } from 'nice-grpc';
import { embeddingService } from '..';
import {
EmbeddingsServiceService,
TextEmbeddingRequest,
} from '../generated/yandex/cloud/ai/foundation_models/v1/embedding/embedding_service';
import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types';

export type TextEmbeddingProps = Omit<TypeFromProtoc<TextEmbeddingRequest, 'text'>, 'modelUri'> & {
modelId: string;
folderId: string;
};

export class EmbeddingSdk {
private embeddingClient: Client<typeof EmbeddingsServiceService, ClientCallArgs>;

static ENDPOINT = 'llm.api.cloud.yandex.net:443';

constructor(session: SessionArg, endpoint = EmbeddingSdk.ENDPOINT) {
this.embeddingClient = session.client(embeddingService.EmbeddingsServiceClient, endpoint);
}

textEmbedding(params: TextEmbeddingProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.embeddingClient.textEmbedding(
embeddingService.TextEmbeddingRequest.fromPartial({ ...restParams, modelUri }),
args,
);
}
}

export const initEmbeddingSdk = (session: SessionArg, endpoint = EmbeddingSdk.ENDPOINT) => {
return new EmbeddingSdk(session, endpoint);
};
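
A minimal usage sketch of the new EmbeddingSdk (not part of the diff). The Session import, its auth option, the local import path, and the folder/model IDs below are assumptions or placeholders; field names follow the generated foundation_models protobuf types.

import { Session } from '@yandex-cloud/nodejs-sdk'; // assumed core Session helper of this repo
import { initEmbeddingSdk } from './clients/ai-foundation_models-v1/sdk'; // illustrative path

async function main() {
    const session = new Session({ iamToken: '<IAM token>' }); // assumed auth option
    const embeddings = initEmbeddingSdk(session);

    // textEmbedding() assembles modelUri as `gpt://<folderId>/<modelId>` and calls
    // EmbeddingsService.TextEmbedding with the remaining request fields.
    const response = await embeddings.textEmbedding({
        folderId: '<folder-id>',      // placeholder
        modelId: 'text-search-query', // placeholder model name
        text: 'What is a foundation model?',
    });

    console.log(response.embedding); // numeric vector per the generated response type
}

main();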
48 changes: 48 additions & 0 deletions clients/ai-foundation_models-v1/sdk/imageGenerationSdk.ts
@@ -0,0 +1,48 @@
import { Client } from 'nice-grpc';
import { imageGenerationService } from '..';
import {
ImageGenerationAsyncServiceService,
ImageGenerationRequest,
} from '../generated/yandex/cloud/ai/foundation_models/v1/image_generation/image_generation_service';
import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types';

export type GenerateImageProps = Omit<
TypeFromProtoc<ImageGenerationRequest, 'messages'>,
'modelUri'
> & {
modelId: string;
folderId: string;
};

export class ImageGenerationSdk {
private imageGenerationClient: Client<
typeof ImageGenerationAsyncServiceService,
ClientCallArgs
>;

static ENDPOINT = 'llm.api.cloud.yandex.net:443';

constructor(session: SessionArg, endpoint = ImageGenerationSdk.ENDPOINT) {
this.imageGenerationClient = session.client(
imageGenerationService.ImageGenerationAsyncServiceClient,
endpoint,
);
}

generateImage(params: GenerateImageProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `art://${folderId}/${modelId}`;

return this.imageGenerationClient.generate(
imageGenerationService.ImageGenerationRequest.fromPartial({ ...restParams, modelUri }),
args,
);
}
}

export const initImageGenerationSdk = (
session: SessionArg,
endpoint = ImageGenerationSdk.ENDPOINT,
) => {
return new ImageGenerationSdk(session, endpoint);
};
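
A usage sketch for ImageGenerationSdk, reusing the session from the EmbeddingSdk sketch above (not part of the diff; the message fields and IDs are placeholders based on the generated image_generation types).

const images = initImageGenerationSdk(session);

// generateImage() builds modelUri as `art://<folderId>/<modelId>` and starts generation via
// ImageGenerationAsyncService.Generate, which resolves to a long-running Operation.
const operation = await images.generateImage({
    folderId: '<folder-id>', // placeholder
    modelId: 'yandex-art',   // placeholder model name
    messages: [{ text: 'a watercolor lighthouse at dawn', weight: 1 }],
});

// The Operation must then be polled until `done`; the polling helper is outside this diff.
console.log(operation.id);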
5 changes: 5 additions & 0 deletions clients/ai-foundation_models-v1/sdk/index.ts
@@ -0,0 +1,5 @@
export * from './embeddingSdk';
export * from './imageGenerationSdk';
export * from './textClassificationSdk';
export * from './textGenerationSdk';
export * from '..';
73 changes: 73 additions & 0 deletions clients/ai-foundation_models-v1/sdk/textClassificationSdk.ts
@@ -0,0 +1,73 @@
import { Client } from 'nice-grpc';
import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types';
import {
FewShotTextClassificationRequest,
TextClassificationRequest,
TextClassificationServiceService,
} from '../generated/yandex/cloud/ai/foundation_models/v1/text_classification/text_classification_service';
import { textClassificationService } from '..';

export type TextClassificationProps = Omit<
TypeFromProtoc<TextClassificationRequest, 'text'>,
'modelUri'
> & {
modelId: string;
folderId: string;
};

export type FewShotTextClassificationProps = Omit<
TypeFromProtoc<FewShotTextClassificationRequest, 'text'>,
'modelUri'
> & {
modelId: string;
folderId: string;
};

export class TextClassificationSdk {
private textClassificationClient: Client<
typeof TextClassificationServiceService,
ClientCallArgs
>;

static ENDPOINT = 'llm.api.cloud.yandex.net:443';

constructor(session: SessionArg, endpoint = TextClassificationSdk.ENDPOINT) {
this.textClassificationClient = session.client(
textClassificationService.TextClassificationServiceClient,
endpoint,
);
}

classifyText(params: TextClassificationProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.textClassificationClient.classify(
textClassificationService.TextClassificationRequest.fromPartial({
...restParams,
modelUri,
}),
args,
);
}

    classifyTextFewShot(params: FewShotTextClassificationProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.textClassificationClient.fewShotClassify(
textClassificationService.FewShotTextClassificationRequest.fromPartial({
...restParams,
modelUri,
}),
args,
);
}
}

export const initTextClassificationSdk = (
session: SessionArg,
endpoint = TextClassificationSdk.ENDPOINT,
) => {
return new TextClassificationSdk(session, endpoint);
};
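
A usage sketch for TextClassificationSdk, again reusing the session from the first sketch (not part of the diff; IDs are placeholders and the response field name comes from the generated text_classification types).

const classifier = initTextClassificationSdk(session);

// classifyText() targets a tuned classifier: modelUri becomes `gpt://<folderId>/<modelId>`
// and the request is sent to TextClassificationService.Classify.
const result = await classifier.classifyText({
    folderId: '<folder-id>',          // placeholder
    modelId: '<tuned-classifier-id>', // placeholder
    text: 'The parcel arrived two weeks late.',
});

console.log(result.predictions); // label/confidence pairs per the generated response type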
98 changes: 98 additions & 0 deletions clients/ai-foundation_models-v1/sdk/textGenerationSdk.ts
@@ -0,0 +1,98 @@
import { Client } from 'nice-grpc';
import { textGenerationService } from '..';

import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types';
import {
CompletionRequest,
TextGenerationAsyncServiceService,
TextGenerationServiceService,
TokenizeRequest,
TokenizerServiceService,
} from '../generated/yandex/cloud/ai/foundation_models/v1/text_generation/text_generation_service';

export type CompletionProps = Omit<TypeFromProtoc<CompletionRequest, 'messages'>, 'modelUri'> & {
modelId: string;
folderId: string;
};

export type TokenizeProps = Omit<TypeFromProtoc<TokenizeRequest, 'text'>, 'modelUri'> & {
modelId: string;
folderId: string;
};

export class TextGenerationSdk {
private textGenerationClient: Client<typeof TextGenerationServiceService, ClientCallArgs>;
private tokenizerClient: Client<typeof TokenizerServiceService, ClientCallArgs>;
private textGenerationAsyncClient: Client<
typeof TextGenerationAsyncServiceService,
ClientCallArgs
>;

static ENDPOINT = 'llm.api.cloud.yandex.net:443';

constructor(session: SessionArg, endpoint = TextGenerationSdk.ENDPOINT) {
this.textGenerationClient = session.client(
textGenerationService.TextGenerationServiceClient,
endpoint,
);

this.tokenizerClient = session.client(
textGenerationService.TokenizerServiceClient,
endpoint,
);

this.textGenerationAsyncClient = session.client(
textGenerationService.TextGenerationAsyncServiceClient,
endpoint,
);
}

tokenize(params: TokenizeProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.tokenizerClient.tokenize(
textGenerationService.TokenizeRequest.fromPartial({ ...restParams, modelUri }),
args,
);
}

tokenizeCompletion(params: CompletionProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.tokenizerClient.tokenizeCompletion(
textGenerationService.CompletionRequest.fromPartial({ ...restParams, modelUri }),
args,
);
}

completion(params: CompletionProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.textGenerationClient.completion(
textGenerationService.CompletionRequest.fromPartial({ ...restParams, modelUri }),
args,
);
}

completionAsOperation(params: CompletionProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

const operationP = this.textGenerationAsyncClient.completion(
textGenerationService.CompletionRequest.fromPartial({ ...restParams, modelUri }),
args,
);

return operationP;
}
}

export const initTextGenerationSdk = (
session: SessionArg,
endpoint = TextGenerationSdk.ENDPOINT,
) => {
return new TextGenerationSdk(session, endpoint);
};
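
A usage sketch for TextGenerationSdk, continuing in the same async context as the sketches above (not part of the diff; message and response field names follow the generated text_generation types, and the model ID is a placeholder). Completion is a server-streaming RPC, so the call is consumed as an async iterable.

const textGeneration = initTextGenerationSdk(session);

// completion() builds modelUri as `gpt://<folderId>/<modelId>` and streams partial
// CompletionResponse messages from TextGenerationService.Completion.
const stream = textGeneration.completion({
    folderId: '<folder-id>',   // placeholder
    modelId: 'yandexgpt-lite', // placeholder model name
    messages: [{ role: 'user', text: 'Write a haiku about gRPC.' }],
});

for await (const chunk of stream) {
    console.log(chunk.alternatives[0]?.message?.text);
}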
