9 changes: 9 additions & 0 deletions mediapipe/tasks/web/genai/llm_inference/BUILD
@@ -21,6 +21,7 @@ mediapipe_ts_library(
    deps = [
        ":llm_inference_types",
        ":model_loading_utils",
        ":efficient_model_loader",
        "//mediapipe/framework:calculator_jspb_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
        "//mediapipe/tasks/cc/genai/inference/calculators:detokenizer_calculator_jspb_proto",
@@ -58,3 +59,11 @@ mediapipe_ts_library(
    ],
    visibility = ["//visibility:public"],
)

mediapipe_ts_library(
    name = "efficient_model_loader",
    srcs = [
        "efficient_model_loader.ts",
    ],
    deps = [],
)
68 changes: 68 additions & 0 deletions mediapipe/tasks/web/genai/llm_inference/efficient_model_loader.ts
@@ -0,0 +1,68 @@
Author comment: I have created a new model class for the handler.

/**
* Copyright 2025 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* Optimized model loading utilities for LLM inference.
*/

/**
* Creates a streaming model loader with proper resource management.
*/
export async function createModelStream(
  modelAssetPath: string,
  signal?: AbortSignal
): Promise<ReadableStreamDefaultReader<Uint8Array>> {
  const response = await fetch(modelAssetPath, { signal });

  if (!response.ok) {
    throw new Error(
      `Failed to fetch model: ${modelAssetPath} (${response.status})`
    );
  }

  if (!response.body) {
    throw new Error(
      `Failed to fetch model: ${modelAssetPath} (no body)`
    );
  }

  return response.body.getReader();
}

/**
* Model loader with cancellation support.
*/
export class ModelLoader {
  private abortController?: AbortController;

  async loadModel(
    modelAssetPath: string
  ): Promise<ReadableStreamDefaultReader<Uint8Array>> {
    this.cancel();
    this.abortController = new AbortController();

    try {
      return await createModelStream(modelAssetPath, this.abortController.signal);
    } catch (error) {
      // Clear the loading state on failure so isLoading() does not stay true
      // after a rejected load (matches the error-handling test below).
      this.abortController = undefined;
      throw error;
    }
  }

  cancel(): void {
    this.abortController?.abort();
    this.abortController = undefined;
  }

  isLoading(): boolean {
    return !!this.abortController;
  }
}
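
For context, a minimal sketch of how the new loader is meant to be consumed from application code. The readModelBytes helper and the model URL are illustrative only and not part of this change; they simply drain the streamed model and report its size.

// Illustrative usage sketch (not part of this diff).
import { ModelLoader } from './efficient_model_loader';

async function readModelBytes(modelUrl: string): Promise<number> {
  const loader = new ModelLoader();
  // loadModel() aborts any in-flight request before starting a new one.
  const reader = await loader.loadModel(modelUrl);
  let totalBytes = 0;
  while (true) {
    const { done, value } = await reader.read();
    if (done || !value) {
      break;
    }
    totalBytes += value.byteLength;
  }
  return totalBytes;
}

Calling loader.cancel() from another code path aborts the underlying fetch via the shared AbortController, which is what LlmInference.close() relies on below.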
142 changes: 142 additions & 0 deletions mediapipe/tasks/web/genai/llm_inference/efficient_model_loader_test.ts
@@ -0,0 +1,142 @@
/**
* Copyright 2025 The MediaPipe Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import 'jasmine';

import {
  createModelStream,
  ModelLoader,
} from './efficient_model_loader';

describe('EfficientModelLoader', () => {
  let mockFetch: jasmine.Spy;
  let originalFetch: typeof fetch;

  beforeEach(() => {
    originalFetch = globalThis.fetch;
    mockFetch = jasmine.createSpy('fetch');
    globalThis.fetch = mockFetch;
  });

  afterEach(() => {
    globalThis.fetch = originalFetch;
  });

  describe('createModelStream', () => {
    it('should create a stream from a successful fetch', async () => {
      const mockData = new Uint8Array([1, 2, 3, 4, 5]);
      const mockResponse = {
        ok: true,
        status: 200,
        body: new ReadableStream({
          start(controller) {
            controller.enqueue(mockData);
            controller.close();
          },
        }),
      };
      mockFetch.and.returnValue(Promise.resolve(mockResponse));

      const stream = await createModelStream('http://example.com/model.bin');
      const { value } = await stream.read();

      expect(value).toEqual(mockData);
    });

    it('should throw error for failed fetch', async () => {
      const mockResponse = {
        ok: false,
        status: 404,
      };
      mockFetch.and.returnValue(Promise.resolve(mockResponse));

      await expectAsync(
        createModelStream('http://example.com/nonexistent.bin')
      ).toBeRejectedWithError(/Failed to fetch model.*404/);
    });
  });



  describe('ModelLoader', () => {
    let loader: ModelLoader;

    beforeEach(() => {
      loader = new ModelLoader();
    });

    afterEach(() => {
      loader.cancel();
    });

    it('should load a model successfully', async () => {
      const mockData = new Uint8Array([1, 2, 3]);
      const mockResponse = {
        ok: true,
        status: 200,
        body: new ReadableStream({
          start(controller) {
            controller.enqueue(mockData);
            controller.close();
          },
        }),
      };
      mockFetch.and.returnValue(Promise.resolve(mockResponse));

      const stream = await loader.loadModel('http://example.com/model.bin');
      const { value } = await stream.read();

      expect(value).toEqual(mockData);
    });

    it('should track loading state', async () => {
      mockFetch.and.returnValue(new Promise(() => {})); // Never resolves

      expect(loader.isLoading()).toBeFalse();

      const loadPromise = loader.loadModel('http://example.com/model.bin');
      expect(loader.isLoading()).toBeTrue();

      loader.cancel();
      await expectAsync(loadPromise).toBeRejected();
      expect(loader.isLoading()).toBeFalse();
    });

    it('should cancel previous loading when starting new load', async () => {
      mockFetch.and.returnValue(new Promise(() => {})); // Never resolves

      const firstLoad = loader.loadModel('http://example.com/model1.bin');
      expect(loader.isLoading()).toBeTrue();

      loader.cancel();
      expect(loader.isLoading()).toBeFalse();

      await expectAsync(firstLoad).toBeRejected();
    });

    it('should handle loading errors gracefully', async () => {
      mockFetch.and.returnValue(Promise.reject(new Error('Network failure')));

      await expectAsync(
        loader.loadModel('http://example.com/model.bin')
      ).toBeRejected();

      expect(loader.isLoading()).toBeFalse();
    });
  });
});
21 changes: 10 additions & 11 deletions mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
@@ -59,6 +59,7 @@ import {
  tee,
  uint8ArrayToStream,
} from './model_loading_utils';
+import { ModelLoader } from './efficient_model_loader';

export type {
  Audio,
@@ -181,6 +182,7 @@ export class LlmInference extends TaskRunner {
  private streamingReader?: StreamingReader;
  private useLlmEngine = false;
  private isConvertedModel = false;
+ private modelLoader = new ModelLoader();

  // The WebGPU device used for LLM inference.
  private wgpuDevice?: GPUDevice;
@@ -394,7 +396,7 @@
  override async setOptions(options: LlmInferenceOptions): Promise<void> {
    // TODO: b/324482487 - Support customizing config for Web task of LLM
    // Inference.
-   if (this.isProcessing) {
+   if (this.isProcessing || this.modelLoader.isLoading()) {
      throw new Error('Cannot set options while loading or processing.');
    }

@@ -414,20 +416,15 @@

    let modelStream: ReadableStreamDefaultReader<Uint8Array> | undefined;
    if (options.baseOptions?.modelAssetPath) {
-     const request = await fetch(
-       options.baseOptions.modelAssetPath.toString(),
-     );
-     if (!request.ok) {
-       throw new Error(
-         `Failed to fetch model: ${options.baseOptions.modelAssetPath} (${request.status})`,
+     try {
+       modelStream = await this.modelLoader.loadModel(
+         options.baseOptions.modelAssetPath.toString()
        );
-     }
-     if (!request.body) {
+     } catch (error) {
        throw new Error(
-         `Failed to fetch model: ${options.baseOptions.modelAssetPath} (no body)`,
+         `Failed to load model from path: ${options.baseOptions.modelAssetPath}. ${error}`
        );
      }
-     modelStream = request.body.getReader();
    } else if (options.baseOptions?.modelAssetBuffer instanceof Uint8Array) {
      modelStream = uint8ArrayToStream(
        options.baseOptions.modelAssetBuffer,
@@ -1385,6 +1382,8 @@
  }

  override close() {
+   this.modelLoader.cancel();
+
    if (this.useLlmEngine) {
      (
        this.graphRunner as unknown as LlmGraphRunner
25 changes: 25 additions & 0 deletions mediapipe/tasks/web/genai/llm_inference/llm_inference_test.ts
@@ -77,6 +77,27 @@ describe('LlmInference', () => {
    expect(llmInference).toBeDefined();
  });

  it('loads a model from modelAssetPath', async () => {
    const pathOptions = {
      baseOptions: { modelAssetPath: modelUrl },
      numResponses: 1,
    };

    llmInference = await LlmInference.createFromOptions(genaiFileset, pathOptions);
    expect(llmInference).toBeDefined();
  });

  it('handles modelAssetPath loading errors', async () => {
    const invalidPathOptions = {
      baseOptions: { modelAssetPath: 'http://invalid-url/nonexistent.bin' },
      numResponses: 1,
    };

    await expectAsync(
      LlmInference.createFromOptions(genaiFileset, invalidPathOptions)
    ).toBeRejectedWithError(/Failed to load model from path/);
  });

  it('loads a model, deletes it, and then loads it again', async () => {
    llmInference = await load();

@@ -295,6 +316,8 @@ describe('LlmInference', () => {
}).toThrowError(/currently loading or processing/);
expect(typeof (await responsePromise)).toBe('string');
});


});

describe('running', () => {
@@ -387,6 +410,8 @@ describe('LlmInference', () => {
await expectAsync(responsePromise).toBeResolved();
expect(typeof (await responsePromise)).toBe('string');
});


});
});
}