Skip to content

Commit

Permalink
Merge pull request #96 from WorldCereal/infer-classes
Browse files Browse the repository at this point in the history
Infer classes from model
  • Loading branch information
kvantricht authored Jul 2, 2024
2 parents bef1a49 + d2d3e5b commit 3f44093
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions src/worldcereal/openeo/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,13 @@ def predict(self, features: np.ndarray) -> np.ndarray:
# Prepare input data for ONNX model
outputs = self.onnx_session.run(None, {"features": features})

# Apply LUT: TODO:this needs an update!
LUT = {
"barley": 1,
"maize": 2,
"millet_sorghum": 3,
"other_crop": 4,
"rapeseed_rape": 5,
"soy_soybeans": 6,
"sunflower": 7,
"wheat": 8,
}
# Get info on classes from the model
class_params = eval(
self.onnx_session.get_modelmeta().custom_metadata_map["class_params"]
)

# Get classes LUT
LUT = dict(zip(class_params["class_names"], class_params["class_to_label"]))

# Extract classes as INTs and probability of winning class values
labels = np.zeros((len(outputs[0]),), dtype=np.uint16)
Expand Down

0 comments on commit 3f44093

Please sign in to comment.