Skip to content

Commit 3cdef76

Browse files
lostellashchur
andauthored
Backports v0.15.1 (reprise) (#3191)
*Description of changes:* backporting fixes - #3188 - #3189 By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. **Please tag this pr with at least one of these labels to make our release process faster:** BREAKING, new feature, bug fix, other change, dev setup --------- Co-authored-by: Oleksandr Shchur <[email protected]>
1 parent 0cb0808 commit 3cdef76

File tree

8 files changed

+43
-35
lines changed

8 files changed

+43
-35
lines changed

docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,10 @@ class FeedForwardNetwork(nn.Module):
110110
torch.nn.init.zeros_(lin.bias)
111111
return lin
112112

113-
def forward(self, context):
114-
scale = self.scaling(context)
115-
scaled_context = context / scale
116-
nn_out = self.nn(scaled_context)
113+
def forward(self, past_target):
114+
scale = self.scaling(past_target)
115+
scaled_past_target = past_target / scale
116+
nn_out = self.nn(scaled_past_target)
117117
nn_out_reshaped = nn_out.reshape(-1, self.prediction_length, self.hidden_dimensions[-1])
118118
distr_args = self.args_proj(nn_out_reshaped)
119119
return distr_args, torch.zeros_like(scale), scale
@@ -143,15 +143,15 @@ class LightningFeedForwardNetwork(FeedForwardNetwork, pl.LightningModule):
143143
super().__init__(*args, **kwargs)
144144

145145
def training_step(self, batch, batch_idx):
146-
context = batch["past_target"]
147-
target = batch["future_target"]
146+
past_target = batch["past_target"]
147+
future_target = batch["future_target"]
148148

149-
assert context.shape[-1] == self.context_length
150-
assert target.shape[-1] == self.prediction_length
149+
assert past_target.shape[-1] == self.context_length
150+
assert future_target.shape[-1] == self.prediction_length
151151

152-
distr_args, loc, scale = self(context)
152+
distr_args, loc, scale = self(past_target)
153153
distr = self.distr_output.distribution(distr_args, loc, scale)
154-
loss = -distr.log_prob(target)
154+
loss = -distr.log_prob(future_target)
155155

156156
return loss.mean()
157157

src/gluonts/model/forecast_generator.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,15 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast:
8383

8484

8585
def make_predictions(prediction_net, inputs: dict):
86-
# MXNet predictors only support positional arguments
87-
class_name = prediction_net.__class__.__module__
88-
if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"):
89-
return prediction_net(*inputs.values())
90-
else:
91-
return prediction_net(**inputs)
86+
try:
87+
# Feed inputs as positional arguments for MXNet block predictors
88+
import mxnet as mx
89+
90+
if isinstance(prediction_net, mx.gluon.Block):
91+
return prediction_net(*inputs.values())
92+
except ImportError:
93+
pass
94+
return prediction_net(**inputs)
9295

9396

9497
class ForecastGenerator:

src/gluonts/torch/distributions/distribution_output.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,6 @@ def loss(
9090
nll = nll * (variance.detach() ** self.beta)
9191
return nll
9292

93-
@property
94-
def event_shape(self) -> Tuple:
95-
r"""
96-
Shape of each individual event contemplated by the distributions that
97-
this object constructs.
98-
"""
99-
raise NotImplementedError()
100-
10193
@property
10294
def event_dim(self) -> int:
10395
r"""

src/gluonts/torch/distributions/output.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ def loss(
105105
"""
106106
raise NotImplementedError()
107107

108+
@property
109+
def event_shape(self) -> Tuple:
110+
r"""
111+
Shape of each individual event compatible with the output object.
112+
"""
113+
raise NotImplementedError()
114+
108115
@property
109116
def forecast_generator(self) -> ForecastGenerator:
110117
raise NotImplementedError()

src/gluonts/torch/distributions/quantile_output.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ def __init__(self, quantiles: List[float]) -> None:
3737
def forecast_generator(self) -> ForecastGenerator:
3838
return QuantileForecastGenerator(quantiles=self.quantiles)
3939

40+
@property
41+
def event_shape(self) -> Tuple:
42+
return ()
43+
4044
def domain_map(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
4145
return args
4246

src/gluonts/torch/model/tide/estimator.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from gluonts.dataset.field_names import FieldName
2222
from gluonts.dataset.loader import as_stacked_batches
2323
from gluonts.itertools import Cyclic
24-
from gluonts.model.forecast_generator import DistributionForecastGenerator
2524
from gluonts.time_feature import (
2625
minute_of_hour,
2726
hour_of_day,
@@ -49,10 +48,7 @@
4948

5049
from gluonts.torch.model.estimator import PyTorchLightningEstimator
5150
from gluonts.torch.model.predictor import PyTorchPredictor
52-
from gluonts.torch.distributions import (
53-
DistributionOutput,
54-
StudentTOutput,
55-
)
51+
from gluonts.torch.distributions import Output, StudentTOutput
5652

5753
from .lightning_module import TiDELightningModule
5854

@@ -174,7 +170,7 @@ def __init__(
174170
weight_decay: float = 1e-8,
175171
patience: int = 10,
176172
scaling: Optional[str] = "mean",
177-
distr_output: DistributionOutput = StudentTOutput(),
173+
distr_output: Output = StudentTOutput(),
178174
batch_size: int = 32,
179175
num_batches_per_epoch: int = 50,
180176
trainer_kwargs: Optional[Dict[str, Any]] = None,
@@ -403,9 +399,7 @@ def create_predictor(
403399
input_transform=transformation + prediction_splitter,
404400
input_names=PREDICTION_INPUT_NAMES,
405401
prediction_net=module,
406-
forecast_generator=DistributionForecastGenerator(
407-
self.distr_output
408-
),
402+
forecast_generator=self.distr_output.forecast_generator,
409403
batch_size=self.batch_size,
410404
prediction_length=self.prediction_length,
411405
device="auto",

src/gluonts/torch/model/tide/module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from gluonts.core.component import validated
2020
from gluonts.torch.modules.feature import FeatureEmbedder
2121
from gluonts.model import Input, InputSpec
22-
from gluonts.torch.distributions import DistributionOutput
22+
from gluonts.torch.distributions import Output
2323
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
2424
from gluonts.torch.model.simple_feedforward import make_linear_layer
2525
from gluonts.torch.util import weighted_average
@@ -242,7 +242,7 @@ def __init__(
242242
num_layers_encoder: int,
243243
num_layers_decoder: int,
244244
layer_norm: bool,
245-
distr_output: DistributionOutput,
245+
distr_output: Output,
246246
scaling: str,
247247
) -> None:
248248
super().__init__()

test/torch/model/test_estimators.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,14 @@
148148
num_batches_per_epoch=3,
149149
trainer_kwargs=dict(max_epochs=2),
150150
),
151+
lambda dataset: TiDEEstimator(
152+
freq=dataset.metadata.freq,
153+
prediction_length=dataset.metadata.prediction_length,
154+
distr_output=QuantileOutput(quantiles=[0.1, 0.6, 0.85]),
155+
batch_size=4,
156+
num_batches_per_epoch=3,
157+
trainer_kwargs=dict(max_epochs=2),
158+
),
151159
lambda dataset: WaveNetEstimator(
152160
freq=dataset.metadata.freq,
153161
prediction_length=dataset.metadata.prediction_length,

0 commit comments

Comments
 (0)