@@ -48,9 +48,12 @@ def predict(self, features: np.ndarray) -> np.ndarray:
48
48
# Extract all prediction values and convert them to binary labels
49
49
prediction_values = [sublist ["True" ] for sublist in outputs [1 ]]
50
50
binary_labels = np .array (prediction_values ) >= threshold
51
- binary_labels = binary_labels .astype (int )
51
+ binary_labels = binary_labels .astype ("uint8" )
52
52
53
- return np .stack ([binary_labels , prediction_values ], axis = 0 ).astype (np .float32 )
53
+ prediction_values = np .array (prediction_values ) * 100.0
54
+ prediction_values = np .round (prediction_values ).astype ("uint8" )
55
+
56
+ return np .stack ([binary_labels , prediction_values ], axis = 0 )
54
57
55
58
def execute (self , inarr : xr .DataArray ) -> xr .DataArray :
56
59
classifier_url = self ._parameters .get ("classifier_url" , self .CATBOOST_PATH )
@@ -62,9 +65,9 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
62
65
self .onnx_session = self .load_ort_session (classifier_url )
63
66
64
67
# Run catboost classification
65
- self .logger .info (f "Catboost classification with input shape: { inarr .shape } " )
68
+ self .logger .info ("Catboost classification with input shape: %s" , inarr .shape )
66
69
classification = self .predict (inarr .values )
67
- self .logger .info (f "Classification done with shape: { classification .shape } " )
70
+ self .logger .info ("Classification done with shape: %s" , inarr .shape )
68
71
69
72
classification = xr .DataArray (
70
73
classification .reshape ((2 , len (x_coords ), len (y_coords ))),
0 commit comments