Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
vinicvaz committed May 7, 2024
1 parent 0730cc8 commit 9cdd250
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 6 deletions.
4 changes: 2 additions & 2 deletions pieces/ProphetPredictPiece/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@


class InputModel(BaseModel):
model_path: str = Field(
title="Model Path",
prophet_model_path: str = Field(
title="Prophet Model Path",
description="Path to the file containing the trained model."
)
periods: int = Field(
Expand Down
2 changes: 1 addition & 1 deletion pieces/ProphetPredictPiece/piece.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ProphetPredictPiece(BasePiece):
"""
def piece_function(self, input_data: InputModel):

with open(input_data.model_path, "rb") as f:
with open(input_data.prophet_model_path, "rb") as f:
model = pickle.load(f)

future = model.make_future_dataframe(periods=input_data.periods)
Expand Down
3 changes: 2 additions & 1 deletion pieces/ProphetTrainModelPiece/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ class InputModel(BaseModel):


class OutputModel(BaseModel):
model_file_path: str = Field(
prophet_model_file_path: str = Field(
title='Prophet model path',
description="Path to the file containing the trained model."
)
# results_figure_file_path: str = Field(
Expand Down
4 changes: 2 additions & 2 deletions pieces/ProphetTrainModelPiece/piece.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def piece_function(self, input_data: InputModel):
model.fit(df)

# Serialize model
model_file_path = self.results_path / "prophet_model.json"
model_file_path = Path(self.results_path) / "prophet_model.json"
with open(str(model_file_path), "wb") as f:
pickle.dump(model, f)

return OutputModel(
model_file_path=str(model_file_path),
prophet_model_file_path=str(model_file_path),
)

0 comments on commit 9cdd250

Please sign in to comment.