fix: prediction issue resolved
Yash7824 committed Sep 1, 2024
1 parent e689192 commit a9b7b23
Showing 2 changed files with 24 additions and 5 deletions.
src/app/services/chess-predictor.service.ts (29 changes: 24 additions & 5 deletions)
@@ -1,5 +1,6 @@
 import { Injectable } from '@angular/core';
 import * as ort from 'onnxruntime-web';
+import { environment } from 'src/environments/environment';
 
 @Injectable({
   providedIn: 'root'
@@ -9,13 +10,28 @@ export class ChessPredictorService {
   private model: ort.InferenceSession | null = null;
 
   constructor() {
+    this.initializeBackend();
     this.loadModel();
   }
 
+  private async initializeBackend() {
+    try {
+      // Set the correct paths to the WASM files
+      ort.env.wasm.wasmPaths = '/assets/onnxruntime/';
+      console.log('WASM paths set to:', ort.env.wasm.wasmPaths);
+    } catch (err) {
+      console.error('Failed to set WASM paths:', err);
+    }
+  }
+
   // Load the ONNX model from the assets folder
   private async loadModel() {
     try {
-      this.model = await ort.InferenceSession.create('/assets/neural_network_model/chess_move_predictor.onnx');
+      const response = await fetch(`${environment.base_url}/nn_model/getNNModel`);
+      const arrayBuffer = await response.arrayBuffer();
+      const modelTensor = new Uint8Array(arrayBuffer);
+
+      this.model = await ort.InferenceSession.create(modelTensor);
       console.log('ONNX Model loaded successfully.');
     } catch (error) {
       console.error('Error loading ONNX model:', error);
@@ -24,6 +40,7 @@ export class ChessPredictorService {
 
   // Method to preprocess input data and make predictions
   public async predict(boardState: Float32Array): Promise<number | null> {
+    debugger;
     if (!this.model) {
       console.error('Model is not loaded yet.');
       return null;
@@ -35,10 +52,12 @@
     try {
       const feeds: ort.InferenceSession.OnnxValueMapType = { input: inputTensor };
       const results = await this.model.run(feeds);
-      const output = results['output'].data;
-
-      // Assert the type explicitly to number
-      const predictedMoveIndex = output[0] as number; // Treat output[0] as a number
+      const outputTensor = results['output']; // Make sure this matches your ONNX model output name
+      const outputArray = Array.from(outputTensor.data as Float32Array);
+
+      // Find the index with the maximum value
+      const predictedMoveIndex = outputArray.indexOf(Math.max(...outputArray));
+
       return predictedMoveIndex;
     } catch (error) {
       console.error('Error during model inference:', error);
Binary file not shown.
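
A note on the initializeBackend() change above: onnxruntime-web loads its WebAssembly binaries at runtime from ort.env.wasm.wasmPaths, so the .wasm files shipped with the package must actually be served under /assets/onnxruntime/. In an Angular workspace that is typically done with an assets rule in angular.json; a minimal sketch, assuming the standard onnxruntime-web dist layout (the build config is not shown in this commit):

{
  "glob": "*.wasm",
  "input": "node_modules/onnxruntime-web/dist",
  "output": "/assets/onnxruntime/"
}

This entry belongs in the "assets" array of the build options and copies the WASM binaries into the path the service points at.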
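The loadModel() rewrite replaces the static-asset path with a fetch of the raw model bytes from `${environment.base_url}/nn_model/getNNModel`, wrapping the response in a Uint8Array that ort.InferenceSession.create accepts directly. The backend route is not part of this commit; purely as a sketch, an Express handler serving the same chess_move_predictor.onnx file could look like this (the framework, port, and file location are assumptions):

import express from 'express';
import path from 'path';

const app = express();

// Send the ONNX model as raw bytes so the client's
// response.arrayBuffer() receives the file unmodified.
app.get('/nn_model/getNNModel', (_req: express.Request, res: express.Response) => {
  res.type('application/octet-stream');
  res.sendFile(path.resolve('models/chess_move_predictor.onnx'));
});

app.listen(3000);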
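Finally, the predict() fix itself: instead of casting output[0] to a number, the new code materializes the output tensor into a plain array and takes the argmax, i.e. the index of the highest-scoring move. Below is a minimal sketch of a caller; the commit does not show how boardState is encoded or how inputTensor is built, so the component and the 64-float encoding are illustrative assumptions only:

import { Component } from '@angular/core';
import { ChessPredictorService } from 'src/app/services/chess-predictor.service';

@Component({
  selector: 'app-move-suggester',
  template: ''
})
export class MoveSuggesterComponent {
  constructor(private predictor: ChessPredictorService) {}

  async suggestMove(): Promise<void> {
    // Hypothetical encoding: one float per square of an 8x8 board;
    // the real encoding lives outside this diff.
    const boardState = new Float32Array(64);
    const moveIndex = await this.predictor.predict(boardState);
    if (moveIndex !== null) {
      console.log('Predicted move index:', moveIndex);
    }
  }
}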
