-
Notifications
You must be signed in to change notification settings - Fork 191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TypeError: add got incompatible shapes for broadcasting: (58,), (54,). #309
Comments
I am also getting this issue and am looking for a solution.
Installing an older version of numpyro resolved the issue for me.
I had the same problem, and version 0.13.2 of numpyro did not work for me either. So I pinned numpyro to an older version while installing lightweight_mmm, matplotlib, etc.: `!pip install numpyro==0.13.1`
I am also facing the same problem. I would appreciate it if anyone has a solution for this. Thanks!
Just install an older version of numpyro, as stated in the comments above.
When I install an older version of numpyro, I get the following import error. Any idea how to solve this? ModuleNotFoundError Traceback (most recent call last) ~\Anaconda3\envs\python3\lib\site-packages\lightweight_mmm\preprocessing.py in ~\Anaconda3\envs\python3\lib\site-packages\lightweight_mmm\core\core_utils.py in ~\Anaconda3\envs\python3\lib\site-packages\numpyro\__init__.py in ~\Anaconda3\envs\python3\lib\site-packages\numpyro\infer\__init__.py in ~\Anaconda3\envs\python3\lib\site-packages\numpyro\infer\elbo.py in ~\Anaconda3\envs\python3\lib\site-packages\numpyro\ops\provenance.py in ModuleNotFoundError: No module named 'jax.extend.linear_util'
Install an older version of jax as well.
Sorry for the breakage. Could you try
|
TypeError Traceback (most recent call last)
in <cell line: 2>()
4 seed=SEED)
5 else:
----> 6 new_predictions = mmm.predict(media=media_scaler.transform(media_data_test),
7 extra_features=extra_features_scaler.transform(extra_features_test),
8 seed=SEED)
17 frames
/usr/local/lib/python3.10/dist-packages/lightweight_mmm/lightweight_mmm.py in predict(self, media, extra_features, media_gap, target_scaler, seed)
518 if seed is None:
519 seed = utils.get_time_seed()
--> 520 prediction = self._predict(
521 rng_key=jax.random.PRNGKey(seed=seed),
522 media_data=full_media,
/usr/local/lib/python3.10/dist-packages/lightweight_mmm/lightweight_mmm.py in _predict(self, rng_key, media_data, extra_features, media_prior, degrees_seasonality, frequency, transform_function, weekday_seasonality, model, posterior_samples, custom_priors)
441 The predictions for the given data.
442 """
--> 443 return infer.Predictive(
444 model=model, posterior_samples=posterior_samples)(
445 rng_key=rng_key,
/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in __call__(self, rng_key, *args, **kwargs)
1009 """
1010 if self.batch_ndims == 0 or self.params == {} or self.guide is None:
-> 1011 return self._call_with_params(rng_key, self.params, args, kwargs)
1012 elif self.batch_ndims == 1: # batch over parameters
1013 batch_size = jnp.shape(tree_flatten(self.params)[0][0])[0]
/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in _call_with_params(self, rng_key, params, args, kwargs)
986 )
987 model = substitute(self.model, self.params)
--> 988 return _predictive(
989 rng_key,
990 model,
/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs)
823 rng_key = rng_key.reshape(batch_shape + key_shape)
824 chunk_size = num_samples if parallel else 1
--> 825 return soft_vmap(
826 single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
827 )
/usr/local/lib/python3.10/dist-packages/numpyro/util.py in soft_vmap(fn, xs, batch_ndims, chunk_size)
417 fn = vmap(fn)
418
--> 419 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
420 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
421 ys = tree_map(
/usr/local/lib/python3.10/dist-packages/numpyro/infer/util.py in single_prediction(val)
796 )
797 else:
--> 798 model_trace = trace(
799 seed(substitute(masked_model, samples), rng_key)
800 ).get_trace(*model_args, **model_kwargs)
/usr/local/lib/python3.10/dist-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
173
/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.10/dist-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
106
107
/usr/local/lib/python3.10/dist-packages/lightweight_mmm/models.py in media_mix_model(media_data, target_data, media_prior, degrees_seasonality, frequency, transform_function, custom_priors, transform_kwargs, weekday_seasonality, extra_features)
410 # expo_trend is B(1, 1) so that the exponent on time is in [.5, 1.5].
411 prediction = (
--> 412 intercept + coef_trend * trend ** expo_trend +
413 seasonality * coef_seasonality +
414 jnp.einsum(media_einsum, media_transformed, coef_media))
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in op(self, *args)
741 def forward_operator_to_aval(name):
742 def op(self, *args):
--> 743 return getattr(self.aval, f"{name}")(self, *args)
744 return op
745
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/array_methods.py in deferring_binary_op(self, other)
269 args = (other, self) if swap else (self, other)
270 if isinstance(other, _accepted_binop_types):
--> 271 return binary_op(*args)
272 # Note: don't use isinstance here, because we don't want to raise for
273 # subclasses, e.g. NamedTuple objects that may override operators.
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/ufuncs.py in fn(x1, x2)
97 def fn(x1, x2, /):
98 x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
---> 99 return lax_fn(x1, x2) if x1.dtype != np.bool else bool_lax_fn(x1, x2)
100 fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
101 fn = jit(fn, inline=True)
/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py in broadcasting_shape_rule(name, *avals)
1597 result_shape.append(non_1s[0])
1598 else:
-> 1599 raise TypeError(f'{name} got incompatible shapes for broadcasting: '
1600 f'{", ".join(map(str, map(tuple, shapes)))}.')
1601
TypeError: add got incompatible shapes for broadcasting: (58,), (54,).
The text was updated successfully, but these errors were encountered: