diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 3b32ba9c4f..77ec889ea2 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -974,9 +974,19 @@ impl ParsedOnnxGraph { fn prelu_conversion(node: Node) -> PReluNode { let input = TensorType::from(node.inputs.first().unwrap()); let output = TensorType::from(node.outputs.first().unwrap()); - let weight = extract_data_serialize::(1, &node).unwrap(); + let mut weight = extract_data_serialize::(1, &node).unwrap(); let config = PReluConfig::new(); let name = &node.name; + + if weight.shape.len() > 1 { + if weight.shape[1..].iter().product::() == 1 { + // Burn accepts rank 1 alpha weight + weight.shape = weight.shape[..1].to_vec(); + } else { + panic!("Invalid PRelu weight with shape {:?}", weight.shape); + } + } + PReluNode::new(name, input, output, weight, config) } fn conv_transpose2d_conversion(node: Node) -> ConvTranspose2dNode {