
Evaluating the gradient within the log probability function? #29

Open
nmudur opened this issue Mar 10, 2024 · 0 comments
nmudur commented Mar 10, 2024

I have a neural-network-based log probability function $\log p_{\mathrm{NN}}(\theta \mid x, \vec{t})$. If I increase the size of $\vec{t}$, my code essentially creates a batch of $x$ repeated $\mathrm{len}(\vec{t})$ times. While I am able to refactor my code to compute this log probability and accumulate it in smaller batches of $\mathrm{len}(\vec{t}_{\mathrm{BATCH}})$, I still run into memory issues when computing the gradient.
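For concreteness, the chunked accumulation I mean is something like the sketch below (`log_prob_nn` and the exact batching are placeholders for my actual model, not real hamiltorch API):

```python
import torch

def log_prob_chunked(theta, x, t, log_prob_nn, chunk_size=16):
    # log_prob_nn(theta, x_batch, t_batch) is a hypothetical stand-in for the
    # NN log probability; it returns per-element log probs of shape (len(t_batch),).
    total = theta.new_zeros(())
    for t_chunk in torch.split(t, chunk_size):
        # Repeat x at most chunk_size times, instead of len(t) times at once.
        x_batch = x.unsqueeze(0).expand(len(t_chunk), *x.shape)
        total = total + log_prob_nn(theta, x_batch, t_chunk).sum()
    return total
```

This keeps the forward pass within memory, but the autograd graph for all chunks is still retained until the gradient is taken, which is where I hit the memory issue.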

I further refactored the log probability function to accumulate the gradients while evaluating the log probability, so it returns both the log probability and its gradient with respect to the parameters $\theta$. However, the `pass_grad` argument seems to accommodate only a constant tensor, or a function that returns a tensor of dimension D. The NN-based log probability is also stochastic, so I can't wrap the gradient as a separate function and pass it in separately.
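The refactored version looks roughly like this sketch, which frees each chunk's graph as soon as its gradient is taken (again, `log_prob_nn` is a hypothetical stand-in for my model):

```python
import torch

def log_prob_and_grad(theta, x, t, log_prob_nn, chunk_size=16):
    # Evaluate the log probability in chunks of t and accumulate its gradient
    # w.r.t. theta as we go, so the full computation graph over all of t is
    # never held in memory at once.
    theta = theta.detach().requires_grad_(True)
    total = torch.zeros((), dtype=theta.dtype)
    grad = torch.zeros_like(theta)
    for t_chunk in torch.split(t, chunk_size):
        x_batch = x.unsqueeze(0).expand(len(t_chunk), *x.shape)
        lp = log_prob_nn(theta, x_batch, t_chunk).sum()
        (g,) = torch.autograd.grad(lp, theta)  # frees this chunk's graph
        total = total + lp.detach()
        grad = grad + g
    return total, grad
```

The difficulty is that this returns a (value, gradient) pair, which doesn't fit the `log_prob_func` interface that the sampler differentiates itself.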

Ideally, I would like to restructure the code so that the gradients are evaluated at the same time as the log probability. I was going to modify my local hamiltorch package to do this, but I first thought I'd check whether the package already has a function that handles this, or whether there is a better workaround, in case other users have encountered this before.
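One workaround I was considering, sketched below and not tested against hamiltorch, is to wrap the precomputed (log prob, gradient) pair in a custom `torch.autograd.Function`. A sampler that calls `torch.autograd.grad` on the returned scalar would then receive the chunk-accumulated gradient, without any changes to hamiltorch itself (`logp_and_grad_fn` here is a hypothetical callable returning that pair):

```python
import torch

class LogProbWithGrad(torch.autograd.Function):
    # Wraps a function returning (log_prob, grad) so that a sampler which
    # calls torch.autograd.grad on the returned scalar gets the precomputed
    # gradient back through backward().

    @staticmethod
    def forward(ctx, theta, logp_and_grad_fn):
        # forward() runs with autograd disabled, so re-enable it for the
        # internal chunk-wise gradient computation.
        with torch.enable_grad():
            logp, grad = logp_and_grad_fn(theta.detach())
        ctx.save_for_backward(grad.detach())
        return logp.detach()

    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved_tensors
        # Return the precomputed gradient for theta; None for the callable.
        return grad_output * grad, None
```

The log-prob function passed to the sampler would then just be `lambda theta: LogProbWithGrad.apply(theta, my_logp_and_grad)`. Would something like this be a reasonable approach, or is there a supported path?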
