-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adding sampling functionality to linearregression model #16
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
api | ||
pymc3_models | ||
============ | ||
|
||
.. toctree:: | ||
:maxdepth: 4 | ||
|
||
pymc3_models.models | ||
pymc3_models |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
pymc3\_models package | ||
===================== | ||
|
||
Subpackages | ||
----------- | ||
|
||
.. toctree:: | ||
|
||
pymc3_models.models | ||
|
||
Submodules | ||
---------- | ||
|
||
pymc3\_models\.exc module | ||
------------------------- | ||
|
||
.. automodule:: pymc3_models.exc | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
|
||
|
||
Module contents | ||
--------------- | ||
|
||
.. automodule:: pymc3_models | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,17 +93,17 @@ def fit(self, X, y, inference_type='advi', minibatch_size=None, inference_args=N | |
|
||
return self | ||
|
||
def predict(self, X, return_std=False): | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only need one blank line between methods per PEP8. |
||
def sample(self, X, samples=2000): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you think about calling this something like |
||
""" | ||
Predicts values of new data with a trained Linear Regression model | ||
samples the conditional posterior estimates | ||
|
||
Parameters | ||
---------- | ||
X : numpy array, shape [n_samples, n_features] | ||
|
||
return_std : Boolean flag of whether to return standard deviations with mean values. Defaults to False. | ||
samples : number of draws to make for each point | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add in the default value to this docstring? Like I do in the predict method below with |
||
""" | ||
|
||
if self.trace is None: | ||
raise PyMC3ModelsError('Run fit on the model before predict.') | ||
|
||
|
@@ -116,6 +116,23 @@ def predict(self, X, return_std=False): | |
|
||
ppc = pm.sample_ppc(self.trace, model=self.cached_model, samples=2000) | ||
|
||
return ppc | ||
|
||
def predict(self, X, return_std=False, samples=2000): | ||
""" | ||
Predicts values of new data with a trained Linear Regression model | ||
|
||
Parameters | ||
---------- | ||
X : numpy array, shape [n_samples, n_features] | ||
|
||
return_std : Boolean flag of whether to return standard deviations with mean values. Defaults to False. | ||
|
||
samples: numberof draws to make for each input | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typo: |
||
""" | ||
|
||
ppc = self.sample(X, samples) | ||
|
||
if return_std: | ||
return ppc['y'].mean(axis=0), ppc['y'].std(axis=0) | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, I'm not sure if you're familiar with semantic versioning; the last number is reserved for bug fixes.
This is not a bug fix, so the version should be a minor version change, e.g.
1.2.0