PARModel validating arguments in README example #27

Open
sarahmish opened this issue Jul 7, 2021 · 0 comments
Labels
bug Something isn't working

@sarahmish

Environment Details

Please indicate the following details about the environment in which you found the bug:

  • DeepEcho version: 0.2.0
  • Python version: 3.7
  • Operating System: macOS 10.15.7

Error Description

I tried to run the quickstart example from README.md, but it fails for PARModel with pytorch==1.9; BasicGANModel works fine.
It seems that dist.log_prob now validates its argument by default, and the check fails whenever the value lies outside the distribution's support (for NegativeBinomial, the non-negative integers), e.g. for negative values. In previous versions (pytorch==1.7) validation was off by default, so no error was raised when such values were encountered.
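For illustration, here is a minimal pure-Python sketch of what the failing check amounts to. NegativeBinomial's support is the non-negative integers, so with validation on, a single negative (or fractional) entry in the batch is enough to make `log_prob` raise. The helper names below are hypothetical and only mimic the `support.check(value).all()` test seen in the traceback; this is not PyTorch's actual code.

```python
def in_nonnegative_integer_support(value: float) -> bool:
    """True if value is an integer >= 0 (the NegativeBinomial support).

    float.is_integer() returns False for nan and inf, so those are rejected too.
    """
    return value >= 0 and float(value).is_integer()


def validate_sample(values):
    """Raise ValueError if any value is outside the support, otherwise return values unchanged."""
    if not all(in_nonnegative_integer_support(v) for v in values):
        raise ValueError('The value argument must be within the support')
    return values


# Valid counts pass through; a batch containing -1 would raise ValueError.
print(validate_sample([0, 3, 7]))
```

Note that the check is all-or-nothing over the batch, which is why one bad value in `x[:seq_len[i], i]` aborts the whole `fit` call.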

Steps to reproduce

```python
from deepecho import PARModel
from deepecho.demo import load_demo

# Load demo data
data = load_demo()

# Define data types for all the columns
data_types = {
    'region': 'categorical',
    'day_of_week': 'categorical',
    'total_sales': 'continuous',
    'nb_customers': 'count',
}

model = PARModel(cuda=False)

# Learn a model from the data
model.fit(
    data=data,
    entity_columns=['store_id'],
    context_columns=['region'],
    data_types=data_types,
    sequence_index='date'
)
```
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-ca86fcdf00ef> in <module>
     21     context_columns=['region'],
     22     data_types=data_types,
---> 23     sequence_index='date'
     24 )
     25 

~/opt/anaconda3/envs/echo2/lib/python3.7/site-packages/deepecho/models/base.py in fit(self, data, entity_columns, context_columns, data_types, segment_size, sequence_index)
    166         # Validate and fit
    167         self._validate(sequences, context_types, data_types)
--> 168         self.fit_sequences(sequences, context_types, data_types)
    169 
    170         # Store context values

~/opt/anaconda3/envs/echo2/lib/python3.7/site-packages/deepecho/models/par.py in fit_sequences(self, sequences, context_types, data_types)
    330 
    331             optimizer.zero_grad()
--> 332             loss = self._compute_loss(X_padded[1:, :, :], Y_padded[:-1, :, :], seq_len)
    333             loss.backward()
    334             if self.verbose:

~/opt/anaconda3/envs/echo2/lib/python3.7/site-packages/deepecho/models/par.py in _compute_loss(self, X_padded, Y_padded, seq_len)
    387                     dist = torch.distributions.negative_binomial.NegativeBinomial(
    388                         r[:seq_len[i], i], p[:seq_len[i], i])
--> 389                     log_likelihood += torch.sum(dist.log_prob(x[:seq_len[i], i]))
    390 
    391                     p_true = X_padded[:seq_len[i], i, missing_idx]

~/opt/anaconda3/envs/echo2/lib/python3.7/site-packages/torch/distributions/negative_binomial.py in log_prob(self, value)
     90     def log_prob(self, value):
     91         if self._validate_args:
---> 92             self._validate_sample(value)
     93 
     94         log_unnormalized_prob = (self.total_count * F.logsigmoid(-self.logits) +

~/opt/anaconda3/envs/echo2/lib/python3.7/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    275         assert support is not None
    276         if not support.check(value).all():
--> 277             raise ValueError('The value argument must be within the support')
    278 
    279     def _get_checked_instance(self, cls, _instance=None):

ValueError: The value argument must be within the support
```
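A possible stopgap, sketched below under stated assumptions, is to turn validation off for the distribution built in `_compute_loss` via the `validate_args` constructor argument (PyTorch 1.8+ also has a global switch, `torch.distributions.Distribution.set_default_validate_args(False)`). This is not the project's fix. The tensors are made-up toy values standing in for the `r`, `p` and `x` slices in the traceback, and `log_prob_sum` is a hypothetical helper. Disabling validation only hides the symptom: the out-of-support entry still yields a non-finite log-probability, so the underlying fix is probably to ensure the values passed to `log_prob` are valid counts.

```python
import torch
from torch.distributions import NegativeBinomial

# Toy stand-ins for the r (total_count) and p (probs) tensors predicted by PAR.
total_count = torch.tensor([2.0, 2.0, 2.0])
probs = torch.tensor([0.3, 0.3, 0.3])
# The last value (-1.0) lies outside the support {0, 1, 2, ...}.
values = torch.tensor([0.0, 3.0, -1.0])


def log_prob_sum(validate: bool) -> torch.Tensor:
    """Sum of log-probabilities with argument validation switched on or off for this instance."""
    dist = NegativeBinomial(total_count, probs, validate_args=validate)
    return torch.sum(dist.log_prob(values))


# With validation enabled (the default in recent PyTorch), this raises ValueError.
try:
    log_prob_sum(validate=True)
except ValueError as err:
    print('validation error:', err)

# With validation disabled, log_prob returns a tensor instead of raising,
# but the out-of-support entry makes the sum non-finite.
print('unvalidated sum:', log_prob_sum(validate=False))
```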
sarahmish added the bug label on Jul 7, 2021