Any plan to support torch.func? #1796

Closed
xhsonny opened this issue May 14, 2024 · 11 comments

Comments

@xhsonny

xhsonny commented May 14, 2024

I am trying to compute the full Jacobian using jacrev or jacfwd from torch.func. Part of the loss function uses _PointFaceDistance. Out of the box, pytorch3d does not support torch.func. The closest references I have found so far are #1636 and #1533.

The problems I am having are

The only working method is to call torch.autograd.functional.jacobian(vectorize=False), which is very slow. When I turn on vectorize=True, it runs into the same issues as above.
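For a function built only from ordinary PyTorch ops, the three routes mentioned here should agree. A minimal sketch, assuming a toy squared-norm loss (`f` is a made-up stand-in, not the pytorch3d op): `vectorize=False` runs one reverse-mode pass per output element, while `vectorize=True` and `torch.func.jacrev` batch those passes with vmap, which is the point where an opaque compiled kernel starts receiving batched gradients.

```python
import torch
from torch.autograd.functional import jacobian
from torch.func import jacrev


def f(x):
    # Toy stand-in for the loss: one squared norm per point, shape (P,).
    return (x ** 2).sum(-1)


x = torch.randn(4, 3)

# One backward pass per output row: slow, but works with any backward.
J_loop = jacobian(f, x, vectorize=False)
# All rows in a single batched backward pass (uses vmap internally).
J_vec = jacobian(f, x, vectorize=True)
# torch.func equivalent, also vmap-based.
J_func = jacrev(f)(x)

print(J_loop.shape)  # torch.Size([4, 4, 3])
```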

My questions are:

  1. Is there a plan to officially support torch.func? If I can get some guidance from the pytorch3d team, I am happy to collaborate on this.
  2. Any idea how to make this work? Any workarounds?

Thanks!

@bottler
Contributor

bottler commented May 17, 2024

We aren't planning torch.func support. It seems to me that the method in #1533 should work fine for _PointFaceDistance - feel free to post the code you have and maybe we can figure out what's wrong.

@xhsonny
Author

xhsonny commented May 17, 2024

@bottler Thanks for the reply!

The error message "RuntimeError: Cannot access data pointer of Tensor that doesn't have storage" from #1533 happened in the forward pass, at "idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)". I don't know how #1533 computes the Jacobian. In my case, neither forward mode nor backward mode works; they fail with the following errors:

  • With backward mode (jacrev), the same error message is raised when _C.point_face_dist_backward is called inside backward.
  • With forward mode (jacfwd), it complains about a missing jvp method.

I did follow #1533 and modified the class to use setup_context etc. I also added a vmap staticmethod, but it is never called before either of the two errors above is raised.
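For reference, this is roughly the shape torch.func expects from a custom autograd.Function: a `forward` without `ctx`, a separate `setup_context`, a `jvp` staticmethod for jacfwd (saved via `ctx.save_for_forward`), and either a `vmap` staticmethod or `generate_vmap_rule = True`. Note that jacrev on its own never runs `forward` under vmap (only the backward pass is vmapped, as in the `vmap(vjp_fn)(basis)` frame of the traceback further down), which may explain why the vmap staticmethod was not reached. A minimal sketch, assuming a toy op written in plain PyTorch (`SquaredNorm` and `squared_norm` are hypothetical names); `generate_vmap_rule = True` only works here because every method uses ordinary torch ops, which is not the case for calls into `_C`.

```python
import torch
from torch.func import jacfwd, jacrev, vmap


class SquaredNorm(torch.autograd.Function):
    """Toy op d[p] = ||points[p]||^2, standing in for _PointFaceDistance."""

    # Let torch.func derive the vmap rule by running the methods below on
    # batched tensors (only valid because they use plain torch ops).
    generate_vmap_rule = True

    @staticmethod
    def forward(points):  # no ctx argument: required style for torch.func
        return (points * points).sum(-1)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (points,) = inputs
        ctx.save_for_backward(points)  # used by backward (jacrev)
        ctx.save_for_forward(points)   # used by jvp (jacfwd)

    @staticmethod
    def backward(ctx, grad_d):
        (points,) = ctx.saved_tensors
        return 2 * points * grad_d.unsqueeze(-1)

    @staticmethod
    def jvp(ctx, points_t):
        (points,) = ctx.saved_tensors
        return 2 * (points * points_t).sum(-1)


def squared_norm(points):
    return SquaredNorm.apply(points)


x = torch.randn(5, 3)
J_rev = jacrev(squared_norm)(x)  # reverse mode: vmap over the backward
J_fwd = jacfwd(squared_norm)(x)  # forward mode: needs jvp and a vmap rule
batch = torch.randn(2, 5, 3)
d_batch = vmap(squared_norm)(batch)  # vmap over the forward
```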

@xhsonny
Author

xhsonny commented May 22, 2024

@bottler I wrote a toy example that follows this pytorch3d tutorial, with the following modifications:

  1. Only _PointFaceDistance is used in the objective. Since I only care whether the Jacobian can be computed, it does not matter whether the optimization actually runs.
  2. Added vmap and setup_context to _PointFaceDistance.
  3. Used Theseus (th.AutoDiffCostFunction) to compute the Jacobian.

Here is the code. It is self-contained and downloads dolphin.obj as in the pytorch3d tutorial. Sorry that the code is a bit long; it includes the modified _PointFaceDistance.

import os
import urllib.request

import einops
import theseus as th
import torch
from pytorch3d import _C
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from torch.autograd import Function
from torch.autograd.function import once_differentiable

_DEFAULT_MIN_TRIANGLE_AREA: float = 5e-3


# PointFaceDistance
class _PointFaceDistance(Function):
    """
    Torch autograd Function wrapper PointFaceDistance Cuda implementation
    """

    generate_vmap_rule = False

    @staticmethod
    def forward(
        # ctx,
        points,
        points_first_idx,
        tris,
        tris_first_idx,
        max_points,
        min_triangle_area=_DEFAULT_MIN_TRIANGLE_AREA,
    ):
        """
        Args:
            ctx: Context object used to calculate gradients.
            points: FloatTensor of shape `(P, 3)`
            points_first_idx: LongTensor of shape `(N,)` indicating the first point
                index in each example in the batch
            tris: FloatTensor of shape `(T, 3, 3)` of triangular faces. The `t`-th
                triangular face is spanned by `(tris[t, 0], tris[t, 1], tris[t, 2])`
            tris_first_idx: LongTensor of shape `(N,)` indicating the first face
                index in each example in the batch
            max_points: Scalar equal to maximum number of points in the batch
            min_triangle_area: (float, defaulted) Triangles of area less than this
                will be treated as points/lines.
        Returns:
            dists: FloatTensor of shape `(P,)`, where `dists[p]` is the squared
                euclidean distance of `p`-th point to the closest triangular face
                in the corresponding example in the batch
            idxs: LongTensor of shape `(P,)` indicating the closest triangular face
                in the corresponding example in the batch.

            `dists[p]` is
            `d(points[p], tris[idxs[p], 0], tris[idxs[p], 1], tris[idxs[p], 2])`
            where `d(u, v0, v1, v2)` is the distance of point `u` from the triangular
            face `(v0, v1, v2)`

        """
        dists, idxs = _C.point_face_dist_forward(
            points,
            points_first_idx,
            tris,
            tris_first_idx,
            max_points,
            min_triangle_area,
        )
        # ctx.save_for_backward(points, tris, idxs)
        # ctx.min_triangle_area = min_triangle_area
        return dists, idxs

    @staticmethod
    def setup_context(ctx, inputs, output):
        (
            points,
            points_first_idx,
            tris,
            tris_first_idx,
            max_points,
            min_triangle_area,
        ) = inputs
        dists, idxs = output
        ctx.save_for_backward(points, tris, idxs)
        ctx.min_triangle_area = min_triangle_area
        ctx.dists = dists
        ctx.idxs = idxs
        ctx.inputs = inputs

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_dists, grad_idxs):
        grad_dists = grad_dists.contiguous()
        points, tris, idxs = ctx.saved_tensors
        min_triangle_area = ctx.min_triangle_area
        grad_points, grad_tris = _C.point_face_dist_backward(
            points, tris, idxs, grad_dists, min_triangle_area
        )
        return grad_points, None, grad_tris, None, None, None

    @staticmethod
    def vmap(
        info,
        in_dims,
        points,
        points_first_idx,
        tris,
        tris_first_idx,
        max_points,
        min_triangle_area=_DEFAULT_MIN_TRIANGLE_AREA,
    ):
        (
            points_bdim,
            points_first_idx_bdim,
            tris_bdim,
            tris_first_idx_bdim,
            _,
            _,
        ) = in_dims

        points_V, points_P, points_C = points.shape
        points = einops.rearrange(points, "V P C -> (V P) C")

        tris = einops.rearrange(tris, "V T A B -> (V T) A B")

        dists, idx = _PointFaceDistance.forward(
            points,
            points_first_idx,
            tris,
            tris_first_idx,
            max_points,
            min_triangle_area,
        )
        dists = einops.rearrange(dists, "(V P) -> V P", V=points_V)
        idx = einops.rearrange(idx, "(V P) -> V P", V=points_V)
        return (dists, idx), (0, 0)


point_face_distance = _PointFaceDistance.apply


def point_to_mesh_distance(points, mesh_v, mesh_f):

    scale_fac = 100.0  # see explanation above

    # packed representation for pointclouds
    points = points * scale_fac  # (P, 3)
    points_first_idx = torch.zeros([1])
    max_points = points.shape[0]

    # packed representation for faces
    verts_packed = mesh_v * scale_fac
    faces_packed = mesh_f
    tris = verts_packed[faces_packed.to(torch.int)]
    tris_first_idx = torch.zeros([1])

    point_to_face, _ = point_face_distance(
        points.to(torch.float32),
        points_first_idx.to(torch.long),
        tris.to(torch.float32),
        tris_first_idx.to(torch.long),
        max_points,
    )
    point_to_face = point_to_face / (scale_fac**2)
    return torch.sqrt(point_to_face)


device = "cpu"
target_obj_path = "dolphin.obj"
if not os.path.exists(target_obj_path):
    # Reference: https://pytorch3d.org/tutorials/deform_source_mesh_to_target_mesh
    src_url = (
        "https://dl.fbaipublicfiles.com/pytorch3d/data/dolphin/dolphin.obj"
    )
    print(f"Downloading from {src_url}")
    urllib.request.urlretrieve(
        src_url,
        "dolphin.obj",
    )
verts, faces, aux = load_obj(target_obj_path)
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)

center = verts.mean(0)
verts = verts - center
scale = max(verts.abs().max(0)[0])
verts = verts / scale


target_mesh = Meshes(verts=[verts], faces=[faces_idx])
src_mesh = ico_sphere(4, device)

deform_verts = th.Vector(
    tensor=src_mesh.verts_packed().reshape(1, -1),
    name="deform_v",
)

target_v = th.Variable(verts.reshape(1, -1), name="target_v")
faces_idx = faces_idx.to(torch.float32)
target_f = th.Variable(faces_idx.reshape(1, -1), name="target_f")


def error_fn(optim_vars, aux_vars):
    (verts,) = optim_vars
    target_v, target_f = aux_vars
    p2m = point_to_mesh_distance(
        verts.tensor.reshape(-1, 3).to(torch.float32),
        mesh_v=target_v.tensor.reshape(-1, 3).to(torch.float32),
        mesh_f=target_f.tensor.reshape(-1, 3),
    ).to(torch.float64)
    return p2m.unsqueeze(0)


optim_vars = (deform_verts,)
aux_vars = target_v, target_f
cost_function = th.AutoDiffCostFunction(
    optim_vars,
    error_fn,
    deform_verts.shape[1] // 3,  # error dim as an int: one residual per vertex
    aux_vars=aux_vars,
    name="l2",
)


# grad_points, grad_tris = _C.point_face_dist_backward(
# RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
cost_function.jacobians()
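In the `vmap` rule of the code above, the V stacked copies are flattened into one packed tensor, but `points_first_idx` and `tris_first_idx` still describe a single example, so (going by the packed convention in the `forward` docstring) every point would be matched against the triangles of all V copies. A sketch of building per-copy offsets instead, assuming both `points` and `tris` are batched along dim 0 (the same assumption the rule above makes) and that each copy is a single example; `pack_vmapped_inputs` is a hypothetical helper, not a pytorch3d function, and it has not been run against `_C.point_face_dist_forward`.

```python
import torch


def pack_vmapped_inputs(points, tris):
    """Hypothetical helper for a vmap rule: V stacked single-example problems
    become one packed batch of V examples (pytorch3d-style packed layout).

    points: (V, P, 3) tensor, tris: (V, T, 3, 3) tensor, both batched on dim 0.
    """
    V, P, _ = points.shape
    T = tris.shape[1]
    points_packed = points.reshape(V * P, 3)
    tris_packed = tris.reshape(V * T, 3, 3)
    # Example v starts at row v * P of points_packed and row v * T of
    # tris_packed, so each point is only matched to its own copy's triangles.
    points_first_idx = torch.arange(V, device=points.device) * P
    tris_first_idx = torch.arange(V, device=tris.device) * T
    max_points = P  # largest number of points in any single example
    return points_packed, points_first_idx, tris_packed, tris_first_idx, max_points


points = torch.randn(3, 4, 3)   # V=3 copies, P=4 points each
tris = torch.randn(3, 5, 3, 3)  # V=3 copies, T=5 triangles each
packed = pack_vmapped_inputs(points, tris)
print(packed[1].tolist(), packed[3].tolist())  # [0, 4, 8] [0, 5, 10]
```

The returned `dists` and `idxs` of shape `(V * P,)` can then be reshaped back to `(V, P)`, as the einops calls in the rule already do.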

When cost_function.jacobians() is called, it throws an exception. Full error below:

Traceback (most recent call last):
  File "/Users/sonny/jac_theseus.py", line 227, in <module>
    cost_function.jacobians()
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/theseus/core/cost_function.py", line 355, in jacobians
    jacobians_full = self._compute_autograd_jacobian_vmap(
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/theseus/core/cost_function.py", line 341, in _compute_autograd_jacobian_vmap
    return vmap(jacrev(jac_fn, argnums=0))(optim_tensors, aux_tensors)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 278, in vmap_impl
    return _flat_vmap(
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 44, in fn
    return f(*args, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 391, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 609, in wrapper_fn
    flat_jacobians_per_input = compute_jacobian_stacked()
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 540, in compute_jacobian_stacked
    chunked_result = vmap(vjp_fn)(basis)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 278, in vmap_impl
    return _flat_vmap(
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 44, in fn
    return f(*args, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 391, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 336, in wrapper
    result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 124, in _autograd_grad
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/autograd/__init__.py", line 411, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/_functorch/autograd_function.py", line 123, in backward
    result = autograd_function.backward(ctx, *grads)
  File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/torch/autograd/function.py", line 570, in wrapper
    outputs = fn(ctx, *args)
  File "/Users/sonny/jac_theseus.py", line 96, in backward
    grad_points, grad_tris = _C.point_face_dist_backward(
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

I don't know how to start debugging this, since _C.point_face_dist_backward runs in compiled C++/CUDA code. If you have any pointers, please let me know. Thanks a lot!

@TimoRST

TimoRST commented May 23, 2024

Just an intuition based on the traceback, especially this part:

File "/usr/local/anaconda3/envs/pytorch3d_torch2_v2/lib/python3.10/site-packages/theseus/core/cost_function.py", line 341, in _compute_autograd_jacobian_vmap
return vmap(jacrev(jac_fn, argnums=0))(optim_tensors, aux_tensors)

I think jac_fn will be called under vmap, in the same way the forward pass is. So there might be an error in how jac_fn is vmapped, which you could handle in the same way you did for the forward pass.
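Since `_compute_autograd_jacobian_vmap` calls `vmap(jacrev(jac_fn, argnums=0))`, `jac_fn` (and every autograd.Function inside it) runs under an outer vmap over the batch dimension, on top of the vmap that jacrev applies to the backward pass. A small pure-PyTorch check of what that composition computes, assuming a toy elementwise residual `err` (hypothetical, not Theseus' cost function): it should equal one jacrev per batch element, stacked.

```python
import torch
from torch.func import jacrev, vmap


def err(x, a):
    # Toy per-sample residual: shape (D,) in, shape (D,) out.
    return (x - a) ** 2


B, D = 2, 3
x = torch.randn(B, D)
a = torch.randn(B, D)

# One Jacobian per batch element, computed in a single vectorized call.
J_vmapped = vmap(jacrev(err, argnums=0))(x, a)  # shape (B, D, D)

# Same result with an explicit Python loop over the batch.
J_loop = torch.stack([jacrev(err, argnums=0)(x[i], a[i]) for i in range(B)])
```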

@xhsonny
Author

xhsonny commented May 23, 2024

@TimoRST Thanks for the input! I hadn't thought of that and will look into it.

One follow-up: jac_fn is either cost_function (a th.AutoDiffCostFunction) or the error_fn I wrote to compute the loss values. Do you know how to add vmap to those functions? They are not torch.autograd.Function subclasses.

Thanks!

@TimoRST

TimoRST commented May 23, 2024

I don't know. I would first step into that function with a debugger to check whether it really is called with the batched tensors that cause the error.
After verifying that, you might just turn that function into a torch.autograd.Function?

@xhsonny
Author

xhsonny commented May 24, 2024

@TimoRST Thanks. I will try it out.

Can I ask what your use case for knn and Theseus was? Was it also an optimization that needed the full Jacobian?

Thanks!

@TimoRST

TimoRST commented May 28, 2024

I wanted to implement something like WGICP (https://arxiv.org/abs/2209.09777), but I couldn't scale it because my graphics card was too small, so I didn't get comparable results.
I didn't use the Jacobian, so it was enough to ensure correct vmapping in the forward path.

@xhsonny
Author

xhsonny commented May 28, 2024

@TimoRST Thanks for the details. Really appreciate the help!

@bottler could you take a look at the code I posted above? I am going to follow @TimoRST's suggestion and look at the error function. Meanwhile, if you find anything in my code, please let me know. Thanks!

@xhsonny
Author

xhsonny commented May 29, 2024

Here are some findings. Conclusion: _C.point_face_dist_backward cannot accept a grad_dists that is a 2D BatchedTensor, which is what vmap/jacrev creates.

  • torch.autograd.functional.jacobian (with vectorize=False) computes the Jacobian row by row, so grad_dists is a 1D tensor and everything works.
  • jacrev uses vmap internally, so grad_dists becomes a batched (2D) BatchedTensor. This makes sense, since the vectorized version pushes the for loop into C++ code. But when _C.point_face_dist_backward is called with this batched grad_dists, it throws RuntimeError: Cannot access data pointer of Tensor that doesn't have storage.

Here is a hacked version to verify my point, though the math is probably wrong. When using jacrev, it "un-vmaps" grad_dists into v_grad_dists and computes the backward row by row with a for loop. That part works, but when backward returns, pytorch3d complains that the returned grad_points has the wrong shape.

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_dists, grad_idxs):
        grad_dists = grad_dists.contiguous()
        points, tris, idxs = ctx.saved_tensors
        min_triangle_area = ctx.min_triangle_area

        # https://discuss.pytorch.org/t/save-batchedtensor-to-a-pickle-file/170561/4
        v_points = torch._C._functorch.get_unwrapped(points)
        v_tris = torch._C._functorch.get_unwrapped(tris)
        v_idxs = torch._C._functorch.get_unwrapped(idxs)
        v_grad_dists = torch._C._functorch.get_unwrapped(grad_dists)

        grad_points = []
        grad_tris = None
        for v_grad_dists_v in v_grad_dists:
            v_grad_points, v_grad_tris = _C.point_face_dist_backward(
                v_points, v_tris, v_idxs, v_grad_dists_v, min_triangle_area
            )
            grad_points.append(v_grad_points)
            if grad_tris is not None:
                grad_tris = grad_tris + v_grad_tris
            else:
                grad_tris = v_grad_tris

        grad_points = torch.cat(grad_points, dim=1)
        return grad_points, None, grad_tris, None, None, None

@bottler could you confirm whether what I said above is correct? If so, is changing the internals of _C.point_face_dist_backward the only option?

Thanks!
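One possible way to get the same per-row loop without private `torch._C._functorch` calls is to move the kernel call into its own autograd.Function whose `vmap` staticmethod loops over the batch dimension; torch.func calls that staticmethod with unbatched tensors plus `in_dims` instead of handing the kernel a BatchedTensor. A minimal sketch under stated assumptions: `kernel_backward` is a hypothetical plain-PyTorch stand-in for `_C.point_face_dist_backward` (so the example runs without pytorch3d), `_KernelBackward` and `SquaredNormWrapped` are hypothetical names, and whether the real kernel behaves the same way inside this wrapper is not verified here.

```python
import torch
from torch.func import jacrev


def kernel_backward(points, grad_d):
    # Hypothetical stand-in for a compiled kernel such as
    # _C.point_face_dist_backward, written in plain torch so the sketch runs.
    return 2.0 * points * grad_d.unsqueeze(-1)


class _KernelBackward(torch.autograd.Function):
    generate_vmap_rule = False

    @staticmethod
    def forward(points, grad_d):
        return kernel_backward(points, grad_d)

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing saved: double backward is not supported

    @staticmethod
    def backward(ctx, grad):
        raise NotImplementedError("double backward is not supported")

    @staticmethod
    def vmap(info, in_dims, points, grad_d):
        # Move each vmapped dim to the front (expanding unbatched inputs),
        # then call the kernel once per batch element on plain tensors.
        def to_front(t, dim):
            if dim is None:
                return t.expand(info.batch_size, *t.shape)
            return t.movedim(dim, 0)

        points = to_front(points, in_dims[0])
        grad_d = to_front(grad_d, in_dims[1])
        out = torch.stack([kernel_backward(p, g) for p, g in zip(points, grad_d)])
        return out, 0  # output is batched along dim 0


class SquaredNormWrapped(torch.autograd.Function):
    """Toy op d[p] = ||points[p]||^2, standing in for _PointFaceDistance."""

    @staticmethod
    def forward(points):
        return (points * points).sum(-1)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (points,) = inputs
        ctx.save_for_backward(points)

    @staticmethod
    def backward(ctx, grad_d):
        (points,) = ctx.saved_tensors
        # Under jacrev, grad_d is batched; _KernelBackward's vmap rule
        # splits it into per-row calls instead of the kernel seeing it.
        return _KernelBackward.apply(points, grad_d)


def squared_norm_wrapped(points):
    return SquaredNormWrapped.apply(points)


x = torch.randn(5, 3)
J = jacrev(squared_norm_wrapped)(x)  # shape (5, 5, 3)
```

The same Function still works with a plain `.backward()` call, because outside of torch.func `_KernelBackward.apply` just runs `forward` directly.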

@bottler
Contributor

bottler commented Aug 18, 2024

I'm afraid I don't know enough about functorch.

bottler closed this as completed on Aug 18, 2024
3 participants