Skip to content

Commit

Permalink
Fixed order of input parameters for onnx export
Browse files Browse the repository at this point in the history
  • Loading branch information
rolshoven authored and Luca Rolshoven committed Aug 28, 2023
1 parent 4ebee43 commit 1ca3207
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/setfit/exporters/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,12 @@ def export_onnx_setfit_model(setfit_model: OnnxSetFitModel, inputs, output_path,
for output_name in output_names:
dynamic_axes_output[output_name] = {0: "batch_size"}

# Move inputs to the right device
# Move inputs to the right device and put them in the right order
forward_params = tuple(signature(setfit_model.model_body.forward).parameters.keys()) # keys of ordered dict are ordered
ordered_kwargs = sorted(inputs.items(), key=lambda param: forward_params.index(param[0]))
odered_params = [param_value for (_, param_value) in ordered_kwargs]
target = setfit_model.model_body.device
args = tuple(value.to(target) for value in inputs.values())
args = tuple(value.to(target) for value in odered_params)

setfit_model.eval()
with torch.no_grad():
Expand Down

0 comments on commit 1ca3207

Please sign in to comment.