You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
f"Past covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {input_length}), which will be padded with zeros at the beginning."
144
+
)
145
+
pad_size=input_length-cov_value.shape[0]
146
+
past_covariates[cov_key] =F.pad(
147
+
cov_value, (pad_size, 0)
148
+
)
149
+
else:
150
+
raiseValueError(
151
+
f"Individual `past_covariates` must be 1-d with length equal to the length of `target` (= {input_length}), found: {cov_key} with shape {tuple(cov_value.shape)} in element at index {idx}."
152
+
)
128
153
129
154
# Check 'future_covariates' if it exists (optional)
f"Each covariate in 'future_covariates' must have shape ({output_length},), but got shape {cov_value.shape} for key '{cov_key}' at index {idx}."
184
+
f"Individual `future_covariates` must be 1-d, found: {cov_key} with {cov_value.ndim} dimensions in element at index {idx}."
149
185
)
186
+
# If any future_covariate's length is not equal to output_length, process it accordingly.
187
+
ifcov_value.shape[0] !=output_length:
188
+
ifauto_adapt:
189
+
ifcov_value.shape[0] >output_length:
190
+
logger.warning(
191
+
f"Future covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (> {output_length}), which will be truncated from the end."
192
+
)
193
+
future_covariates[cov_key] =cov_value[
194
+
:output_length
195
+
]
196
+
else:
197
+
logger.warning(
198
+
f"Future covariate {cov_key} at index {idx} has length {cov_value.shape[0]} (< {output_length}), which will be padded with zeros at the end."
199
+
)
200
+
pad_size=output_length-cov_value.shape[0]
201
+
future_covariates[cov_key] =F.pad(
202
+
cov_value, (0, pad_size)
203
+
)
204
+
else:
205
+
raiseValueError(
206
+
f"Individual `future_covariates` must be 1-d with length equal to `output_length` (= {output_length}), found: {cov_key} with shape {tuple(cov_value.shape)} in element at index {idx}."
207
+
)
150
208
else:
151
209
raiseValueError(
152
210
f"The inputs must be a list of dictionaries, but got {type(inputs)}."
0 commit comments