From 9cdd2507c81a83a356c07662c8b864fed81bb8a8 Mon Sep 17 00:00:00 2001 From: vinicvaz Date: Tue, 7 May 2024 14:32:12 -0300 Subject: [PATCH] fix --- pieces/ProphetPredictPiece/models.py | 4 ++-- pieces/ProphetPredictPiece/piece.py | 2 +- pieces/ProphetTrainModelPiece/models.py | 3 ++- pieces/ProphetTrainModelPiece/piece.py | 4 ++-- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pieces/ProphetPredictPiece/models.py b/pieces/ProphetPredictPiece/models.py index fb150b5..bf84de0 100644 --- a/pieces/ProphetPredictPiece/models.py +++ b/pieces/ProphetPredictPiece/models.py @@ -5,8 +5,8 @@ class InputModel(BaseModel): - model_path: str = Field( - title="Model Path", + prophet_model_path: str = Field( + title="Prophet Model Path", description="Path to the file containing the trained model." ) periods: int = Field( diff --git a/pieces/ProphetPredictPiece/piece.py b/pieces/ProphetPredictPiece/piece.py index f5d50e5..6aa4b23 100644 --- a/pieces/ProphetPredictPiece/piece.py +++ b/pieces/ProphetPredictPiece/piece.py @@ -14,7 +14,7 @@ class ProphetPredictPiece(BasePiece): """ def piece_function(self, input_data: InputModel): - with open(input_data.model_path, "rb") as f: + with open(input_data.prophet_model_path, "rb") as f: model = pickle.load(f) future = model.make_future_dataframe(periods=input_data.periods) diff --git a/pieces/ProphetTrainModelPiece/models.py b/pieces/ProphetTrainModelPiece/models.py index b1abb87..ccb586e 100644 --- a/pieces/ProphetTrainModelPiece/models.py +++ b/pieces/ProphetTrainModelPiece/models.py @@ -53,7 +53,8 @@ class InputModel(BaseModel): class OutputModel(BaseModel): - model_file_path: str = Field( + prophet_model_file_path: str = Field( + title='Prophet model path', description="Path to the file containing the trained model." ) # results_figure_file_path: str = Field( diff --git a/pieces/ProphetTrainModelPiece/piece.py b/pieces/ProphetTrainModelPiece/piece.py index 7c83608..6af429b 100644 --- a/pieces/ProphetTrainModelPiece/piece.py +++ b/pieces/ProphetTrainModelPiece/piece.py @@ -24,10 +24,10 @@ def piece_function(self, input_data: InputModel): model.fit(df) # Serialize model - model_file_path = self.results_path / "prophet_model.json" + model_file_path = Path(self.results_path) / "prophet_model.json" with open(str(model_file_path), "wb") as f: pickle.dump(model, f) return OutputModel( - model_file_path=str(model_file_path), + prophet_model_file_path=str(model_file_path), ) \ No newline at end of file