@@ -90,15 +90,13 @@ def predict(self,
90
90
horizon_timestamps_udf (hist_df .ds ).alias ("ds" ),
91
91
forecast_udf (hist_df .y ).alias ("y" ))
92
92
).toPandas ()
93
-
94
93
forecast_df = forecast_df .reset_index (drop = False ).rename (
95
94
columns = {
96
95
"unique_id" : self .params .group_id ,
97
96
"ds" : self .params .date_col ,
98
97
"y" : self .params .target ,
99
98
}
100
99
)
101
-
102
100
# Todo
103
101
# forecast_df[self.params.target] = forecast_df[self.params.target].clip(0.01)
104
102
return forecast_df , self .model
@@ -165,19 +163,13 @@ def predict_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
165
163
import numpy as np
166
164
import pandas as pd
167
165
# Initialize the ChronosPipeline with a pretrained model from the specified repository
168
- from chronos import BaseChronosPipeline , ChronosBoltPipeline
169
- if "bolt" in self .repo :
170
- pipeline = ChronosBoltPipeline .from_pretrained (
171
- self .repo ,
172
- device_map = self .device ,
173
- torch_dtype = torch .bfloat16 ,
174
- )
175
- else :
176
- pipeline = BaseChronosPipeline .from_pretrained (
177
- self .repo ,
178
- device_map = self .device ,
179
- torch_dtype = torch .bfloat16 ,
180
- )
166
+ from chronos import BaseChronosPipeline
167
+ pipeline = BaseChronosPipeline .from_pretrained (
168
+ self .repo ,
169
+ device_map = 'cuda' ,
170
+ torch_dtype = torch .bfloat16 ,
171
+ )
172
+
181
173
# inference
182
174
for bulk in bulk_iterator :
183
175
median = []
@@ -262,19 +254,12 @@ def __init__(self, repository, prediction_length):
262
254
self .prediction_length = prediction_length
263
255
self .device = "cuda" if torch .cuda .is_available () else "cpu"
264
256
# Initialize the ChronosPipeline with a pretrained model from the specified repository
265
- from chronos import BaseChronosPipeline , ChronosBoltPipeline
266
- if "bolt" in self .repository :
267
- self .pipeline = ChronosBoltPipeline .from_pretrained (
268
- self .repository ,
269
- device_map = self .device ,
270
- torch_dtype = torch .bfloat16 ,
271
- )
272
- else :
273
- self .pipeline = BaseChronosPipeline .from_pretrained (
274
- self .repository ,
275
- device_map = self .device ,
276
- torch_dtype = torch .bfloat16 ,
277
- )
257
+ from chronos import BaseChronosPipeline
258
+ self .pipeline = BaseChronosPipeline .from_pretrained (
259
+ self .repository ,
260
+ device_map = 'cuda' ,
261
+ torch_dtype = torch .bfloat16 ,
262
+ )
278
263
279
264
def predict (self , context , input_data , params = None ):
280
265
history = [torch .tensor (list (series )) for series in input_data ]
0 commit comments