Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QUESTION] Debugging NaN gradients when try to learn the soft body material like k_mu? #395

Open
zhenqi72 opened this issue Dec 13, 2024 · 1 comment
Labels
question The issue author requires information

Comments

@zhenqi72
Copy link

Hi, I am trying to learn the soft body material, according to the difference between a stand soft body and a learning soft body. The scenario is that a capsule will push a soft body. But I tried lots of ways. Both of them will make the gradient to "nan".

This is my code.

  import math
  import sys
  import os
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  import warp as wp
  import warp.sim
  import warp.sim.render
  import numpy as np
  from warp_test_folder.adam import Adam
  
  import torch   
  
  @wp.kernel
  def loss_calculate(
      pos: wp.array2d(dtype=wp.vec3),
      target_pos: wp.array2d(dtype=wp.vec3),
      loss_element: wp.array2d(dtype=float),
  ):
      i,j = wp.tid()
      loss_element[i][j] = wp.length(target_pos[i][j] - pos[i][j])
  
  @wp.kernel
  def loss_sum(loss_element: wp.array2d(dtype=float), result: wp.array(dtype=float)):
  
      i,j = wp.tid()
      wp.atomic_add(result, 0, loss_element[i][j])
      
  @wp.kernel
  def save_state(particle_q: wp.array(dtype=wp.vec3), write_index: int, last_traj: wp.array2d(dtype=wp.vec3f)):
      i = wp.tid()
      last_traj[write_index][i] = particle_q[i]
      
  @wp.kernel
  def minus(grad: wp.array(dtype=float, ndim=2), orig: wp.array(dtype=float, ndim=2), mid: wp.array(dtype=float, ndim=2)):
      tid = wp.tid()
      mid[tid][0] = orig[tid][0] - grad[tid,0]
      orig[tid][0] = mid[tid][0]
  
  class Example:
      def __init__(self, stage_path="soft_body.usd", num_frames=30):
          self.sim_substeps = 10
          self.num_frames = num_frames
          self.fps = 60
          sim_duration = self.num_frames / self.fps
          self.frame_dt = 1.0 / self.fps
          self.sim_dt = self.frame_dt / self.sim_substeps
          self.sim_time = 0.0
          self.lift_speed = 2.5 / sim_duration * 2.0  # from Smith et al.
          self.rot_speed = math.pi / sim_duration
          self.tape = wp.Tape()
          self.target = torch.load("learn_grad_pos.pt")
          self.ref_traj = wp.from_torch(self.target,requires_grad=True,dtype=wp.vec3f).to('cuda:3')
          self.last_traj = wp.empty_like(self.ref_traj)
          self.loss_element = wp.zeros(shape=(self.ref_traj.shape[0:2]), dtype = float, device='cuda:3', requires_grad=True)
          self.result = wp.zeros(shape=(1,), dtype = float, device='cuda:3', requires_grad=True)
    
          capsule_builder = wp.sim.ModelBuilder() 
          b = capsule_builder.add_body(m=10.0,name="capsule0")#(1.0,1.55, 5.0)
          a = capsule_builder.add_body(m=10.0,name="connect body") # if you has the joint, body position will be decided by joint xform
          capsule_builder.add_shape_capsule(body = b,radius=0.5, half_height=1.0, density=100.0)
  
          capsule_builder.add_joint_prismatic(
              parent=-1,
              child=a,
              axis=(1,0,0),
              target_ke = 400,
              target_kd = 10.0,
              limit_ke = 1000.0,
              limit_kd = 10,
              parent_xform =  wp.transform((0.01,0.50, 0.01), wp.quat_identity()),
              linear_compliance = 0.1,
              #mode = wp.sim.JOINT_MODE_TARGET_POSITION,
              name = "prismatic joint x"
          )
          
          capsule_builder.add_joint_prismatic(
              parent=a,
              child=b,
              axis=(0,0,1),
              target_ke = 400,
              target_kd = 10.0,
              limit_ke = 1000.0,
              limit_kd = 10, 
              parent_xform =  wp.transform((0,-0.01,0), wp.quat_identity()),
              linear_compliance = 0.1,
              #mode = wp.sim.JOINT_MODE_TARGET_POSITION,
              name = "prismatic joint z" 
          )
          
          soft_body_builder = wp.sim.ModelBuilder()
          
          cell_dim = 15
          cell_dim_x = 2
          cell_dim_y = 2 #y is the height
          cell_dim_z = 50
          cell_size = 2.0 / cell_dim
          center = cell_size * cell_dim * 0.5
  
          soft_body_builder.add_soft_grid(
              pos=wp.vec3(-center, cell_size, -center),
              rot=wp.quat_identity(),
              vel=wp.vec3(0.0, 0.0, 0.0),
              dim_x=cell_dim_x,
              dim_y=cell_dim_y,
              dim_z=cell_dim_z,
              cell_x=cell_size,
              cell_y=cell_size,
              cell_z=cell_size,
              fix_bottom=False,
              density=100.0,
              k_mu=45000.0, #, #shear modulus measures the material’s resistance to shear deformation. The volume of the material is conserved
              k_lambda=40000.0, # quantifies the material’s resistance to uniform compression. It relates to the change in volume under pressure
              k_damp=0.0,
          )
  
          builder = wp.sim.ModelBuilder()
          builder.add_builder(capsule_builder)
          builder.add_builder(soft_body_builder)
             
          self.model = builder.finalize(requires_grad=True)
          self.model.ground = True
          self.model.gravity[1] = -9.81
          
          self.integrator = wp.sim.FeatherstoneIntegrator(self.model)
          # self.integrator = wp.sim.SemiImplicitIntegrator()
          self.states = []
          for _ in range(self.num_frames + 1):
              self.states.append(self.model.state())
          # grad array of the whole process
          self.grad_list = []
          # loss of
          self.loss = 0.0
  
          self.optimizer = Adam([self.model.tet_materials], lr=1e2)
  
          if stage_path:
              # Helper to render the physics scene as a USD file.
              self.renderer = warp.sim.render.SimRenderer(self.model, stage_path, fps=self.fps, scaling=10.0)
  
              # Frame number used to render the simulation iterations onto the USD file.
              self.render_frame = 0
          else:
              self.renderer = None
  
  
          #wp.sim.eval_fk(self.model, self.model.particle_q, self.model.particle_qd, None, self.state_0)
  
  
          self.use_cuda_graph = False #wp.get_device().is_cuda
          if self.use_cuda_graph:
              with wp.ScopedCapture() as capture:
                  self.simulate()
              self.graph = capture.graph
  
      def forward(self):
          """
          Advances the system dynamics given the rigid-body state in maximal coordinates and generalized joint torques
          [body_q, body_qd, tau].
          """
  
          self.last_traj.zero_()
  
          for i in range(self.num_frames):
              state = self.states[i]
  
              for _ in range(self.sim_substeps):
                  next_state = self.model.state(requires_grad=True)
                  state.body_f.assign(
                  [
                      [0.0, 0.0, 0.0, 3000.0, 0.0, -3000.0], # here x axis to left, y  to up, z into the screen
                  ]
                  )   
                  wp.sim.collide(self.model, state)
  
                  state = self.integrator.simulate(self.model, state, next_state, self.sim_dt)
  
              self.states[i + 1] = state
  
              # save state
              wp.launch(save_state, dim=self.model.particle_count, inputs=[self.states[i + 1].particle_q, i], outputs=[self.last_traj])
  
          # compute loss
          wp.launch(loss_calculate, 
                    dim=self.last_traj.shape[0:2], 
                    inputs=[self.last_traj, self.ref_traj], 
                    outputs=[self.loss_element])
          wp.launch(loss_sum, 
                    dim=self.last_traj.shape[0:2],  
                    inputs=[self.loss_element], 
                    outputs=[self.result])
          
      def learn(self,train_iter):
          self.tape.zero()
          with self.tape:
              self.forward()
              self.tape.backward(loss = self.result)
          #print(f" {i} th substep loss: {self.result}")  
          print(f"{train_iter} th iteration tet_materials.grad: {self.model.tet_materials.grad}")
          print(f"final loss: {self.result}")
          self.optimizer.step([self.model.tet_materials.grad])
          self.result.zero_()
          
      def step(self,train_iter):
          with wp.ScopedTimer("step"):
              if self.use_cuda_graph:
                  wp.capture_launch(self.graph)
              else:
                  self.learn(train_iter)
  
      def render(self):
          if self.renderer is None:
              return
  
          with wp.ScopedTimer("render"):
              for frame in range(self.num_frames):
                  self.renderer.begin_frame(self.render_frame / self.fps)
                  self.renderer.render(self.states[frame])
                  self.renderer.end_frame()
                  self.render_frame += 1
  
  
  if __name__ == "__main__":
      import argparse
  
      parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
      parser.add_argument("--device", type=str, default='cuda:3', help="Override the default Warp device.")
      parser.add_argument(
          "--stage_path",
          type=lambda x: None if x == "None" else str(x),
          default="soft_cap_grad_learn_copy.usd",
          help="Path to the output USD file."
      )
      parser.add_argument("--num_frames", type=int, default=30, help="Total number of frames.") # 900
      parser.add_argument("--train_iters", type=int, default=150, help="Total number of training iterations.")
      args = parser.parse_known_args()[0]
  
      with wp.ScopedDevice(args.device):
          example = Example(stage_path=args.stage_path, num_frames=args.num_frames)
  
          for iteration in range(args.train_iters):
              print(f"Iteration {iteration}")
              example.step(iteration)
              #example.render()
              example.states = []
              for _ in range(example.num_frames + 1):
                  example.states.append(example.model.state())
  
          if example.renderer:
              example.renderer.save()
@zhenqi72 zhenqi72 added the question The issue author requires information label Dec 13, 2024
@shi-eric shi-eric changed the title [QUESTION] <title> If anyone will get a Nan graident when try to learn the soft body material like k_mu? [QUESTION] Debugging \with NaN gradients when try to learn the soft body material like k_mu? Dec 13, 2024
@shi-eric shi-eric changed the title [QUESTION] Debugging \with NaN gradients when try to learn the soft body material like k_mu? [QUESTION] Debugging NaN gradients when try to learn the soft body material like k_mu? Dec 13, 2024
@Jianghanxiao
Copy link

Some suggestions based on my experience, you'd better make sure all intermediate variables are saved somewhere. Especially in your substep loop, you don't save the next state explicitly. After you modifying the code, you can check the gradient after backward to see which part is broken, which makes life easier

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
question The issue author requires information
Projects
None yet
Development

No branches or pull requests

2 participants