"""run_onnx.py: run an ONNX model with ONNX Runtime on sample inputs."""
import onnxruntime as ort
import torch
import numpy as np
def _select_device():
    """Return the torch device to use: CUDA, then MPS, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # ``torch.backends.mps`` does not exist on older PyTorch releases, so look
    # it up defensively instead of letting the attribute access raise.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Define the device (GPU if available, otherwise CPU)
device = _select_device()
def run_onnx_model(onnx_file, input_data):
    """Run one inference pass of an ONNX model with ONNX Runtime.

    A new ``InferenceSession`` is created from ``onnx_file`` on every call.

    Args:
        onnx_file: Model location passed straight to ``ort.InferenceSession``.
        input_data: Model input, either a ``torch.Tensor`` (on any device) or
            any object ``np.asarray`` accepts.

    Returns:
        The first entry of the list returned by ``InferenceSession.run``.
    """
    ort_session = ort.InferenceSession(onnx_file)

    # ONNX Runtime is fed NumPy arrays. Detach first so tensors that track
    # gradients convert cleanly, then move to CPU before conversion.
    if isinstance(input_data, torch.Tensor):
        feed_array = input_data.detach().cpu().numpy()
    else:
        feed_array = np.asarray(input_data)

    # Keep the original "input" feed name whenever the model declares it.
    # If it does not, and the model has exactly one input, use that input's
    # real name. In every other case "input" is passed through unchanged so
    # ONNX Runtime reports the mismatch exactly as it did before.
    declared_names = [node.name for node in ort_session.get_inputs()]
    feed_name = "input"
    if feed_name not in declared_names and len(declared_names) == 1:
        feed_name = declared_names[0]

    outputs = ort_session.run(None, {feed_name: feed_array})
    return outputs[0]
def main():
    """Run the ONNX model on one stored sample and on one random tensor.

    Prints the first element of each result returned by ``run_onnx_model``.
    """
    model_path = "neutron_detection.onnx"
    samples = np.load("data.npy")

    # Slice (rather than index) row 142 so a leading length-1 axis is kept,
    # then insert a second length-1 axis at position 1.
    # NOTE(review): the original comment says this sample should come out
    # positive; that cannot be confirmed from this file.
    row = torch.tensor(samples[142:143], dtype=torch.float32)
    real_batch = row.unsqueeze(1).to(device)
    real_result = run_onnx_model(model_path, real_batch)
    print(f"Real Data Output: {real_result[0]} (Positive for Nuke)")

    # A 1x1x10x10 tensor of random values stands in for invalid data.
    # NOTE(review): the original comment says this should come out negative;
    # whether 10x10 matches the model's expected input is not visible here.
    noise_batch = torch.rand(1, 1, 10, 10).to(device)
    noise_result = run_onnx_model(model_path, noise_batch)
    print(f"Invalid Data Output: {noise_result[0]} (Negative)")


if __name__ == "__main__":
    main()