Skip to content

Commit

Permalink
saving/loading catboost model
Browse files Browse the repository at this point in the history
  • Loading branch information
aPovidlo committed Sep 5, 2023
1 parent 8b41c3c commit 0bf8d30
Showing 1 changed file with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Optional

import pandas as pd
Expand All @@ -8,6 +9,7 @@
from fedot.core.data.data_split import train_test_data_setup
from fedot.core.operations.evaluation.operation_implementations.implementation_interfaces import ModelImplementation
from fedot.core.operations.operation_parameters import OperationParameters
from fedot.core.utils import default_fedot_data_dir


class FedotCatBoostImplementation(ModelImplementation):
Expand Down Expand Up @@ -73,6 +75,14 @@ def convert_to_pool(data: Optional[InputData]):
feature_names=data.features_names.tolist()
)

def save_model(self, model_name: str = 'catboost'):
save_path = os.path.join(default_fedot_data_dir(), f'catboost/{model_name}.cbm')
self.model.save_model(save_path, format='cbm')

def load_model(self, path):
self.model = CatBoostClassifier()
self.model.load_model(path)


class FedotCatBoostClassificationImplementation(FedotCatBoostImplementation):
def __init__(self, params: Optional[OperationParameters] = None):
Expand Down

0 comments on commit 0bf8d30

Please sign in to comment.