update autobot.py
Alan-LanFeng committed Nov 26, 2024
1 parent 10048b3 commit 24811ce
Showing 1 changed file with 34 additions and 21 deletions.
unitraj/models/autobot/autobot.py (55 changes: 34 additions & 21 deletions)
@@ -73,7 +73,7 @@ def __init__(self, d_k, map_attr=3, dropout=0.1):
    def get_road_pts_mask(self, roads):
        road_segment_mask = torch.sum(roads[:, :, :, -1], dim=2) == 0
        road_pts_mask = (1.0 - roads[:, :, :, -1]).type(torch.BoolTensor).to(roads.device).view(-1, roads.shape[2])
-       road_pts_mask[:, 0][road_pts_mask.sum(-1) == roads.shape[2]] = False  # Ensures no NaNs due to empty rows.
+       road_pts_mask = road_pts_mask.masked_fill((road_pts_mask.sum(-1) == roads.shape[2]).unsqueeze(-1), False)  # Ensures no NaNs due to empty rows.
        return road_segment_mask, road_pts_mask

    def forward(self, roads, agents_emb):
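Note on this hunk: the old in-place write un-masked only position 0 of rows whose road points are all padded; `masked_fill` produces a new tensor and clears the whole row, so fully padded road segments no longer feed all-`-inf` rows into the attention softmax. A standalone sketch of the failure mode and the fix (toy shapes, not the model's):

```python
import torch

# Padding mask for 3 queries over 4 road points; True = masked out.
# Row 1 is fully masked (an empty road segment).
mask = torch.tensor([[False, True, False, True],
                     [True, True, True, True],
                     [False, False, False, False]])
scores = torch.randn(3, 4)

nan_attn = scores.masked_fill(mask, float('-inf')).softmax(dim=-1)
print(nan_attn[1])  # tensor([nan, nan, nan, nan]): every key was -inf

# The commit's pattern: un-mask fully masked rows (their output is
# discarded downstream, so attending to padding there is harmless).
safe_mask = mask.masked_fill((mask.sum(-1) == mask.shape[-1]).unsqueeze(-1), False)
safe_attn = scores.masked_fill(safe_mask, float('-inf')).softmax(dim=-1)
print(safe_attn[1].isnan().any())  # tensor(False)
```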
@@ -272,7 +272,7 @@ def temporal_attn_fn(self, agents_emb, agent_masks, layer):
        B = agent_masks.size(0)
        num_agents = agent_masks.size(2)
        temp_masks = agent_masks.permute(0, 2, 1).reshape(-1, T_obs)
-       temp_masks[:, -1][temp_masks.sum(-1) == T_obs] = False  # Ensure that agents that don't exist don't make NaNs.
+       temp_masks = temp_masks.masked_fill((temp_masks.sum(-1) == T_obs).unsqueeze(-1), False)  # Ensure that agents that don't exist don't make NaNs.
        agents_temp_emb = layer(self.pos_encoder(agents_emb.reshape(T_obs, B * (num_agents), -1)),
                                src_key_padding_mask=temp_masks)
        return agents_temp_emb.view(T_obs, B, num_agents, -1)
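The temporal attention has the same failure mode: an agent that never appears yields an all-True `src_key_padding_mask` row, and `nn.TransformerEncoderLayer` then returns NaNs that poison the whole batch. A self-contained repro with toy shapes (the NaN behavior holds on most PyTorch versions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2)  # expects [T, B, E]
x = torch.randn(4, 2, 8)  # T_obs = 4, two agent sequences

# Agent 1 never appears, so its key-padding row is all True.
pad = torch.tensor([[False, False, True, True],
                    [True, True, True, True]])
print(layer(x, src_key_padding_mask=pad).isnan().any())  # tensor(True) on most versions

# The commit's fix: clear fully padded rows before calling the layer.
pad = pad.masked_fill((pad.sum(-1) == pad.shape[-1]).unsqueeze(-1), False)
print(layer(x, src_key_padding_mask=pad).isnan().any())  # tensor(False)
```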
@@ -346,9 +346,7 @@ def _forward(self, inputs):
        output['predicted_probability'] = mode_probs  # [B, c]
        output['predicted_trajectory'] = out_dists.permute(2, 0, 1,
                                                           3)  # [c, T, B, 5] -> [B, c, T, 5] to be able to parallelize code
        # output['scene_emb'] = mode_params_emb.transpose(0,1).reshape(B,-1)
-       if len(np.argwhere(np.isnan(out_dists.detach().cpu().numpy()))) > 1:
-           breakpoint()

        return output
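The deleted lines were a leftover debugging trap: they synced the whole tensor to CPU through numpy on every forward pass just to reach `breakpoint()`. If such a guard is ever wanted again, an on-device check is far cheaper (a sketch, not part of the commit):

```python
import torch

out_dists = torch.randn(4, 12, 8, 5)  # e.g. [c, T, B, 5]
out_dists[0, 3, 2, 1] = float('nan')

# The removed check round-tripped through numpy:
#     len(np.argwhere(np.isnan(out_dists.detach().cpu().numpy()))) > 1
# An on-device equivalent with no CPU sync of the full tensor:
print(torch.isnan(out_dists).any())  # tensor(True)
```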

@@ -370,8 +368,11 @@ def forward(self, batch):
        model_input['agents_in'] = agents_in
        model_input['roads'] = roads
        output = self._forward(model_input)

        loss = self.get_loss(batch, output)
+       # if self.training:
+       #     loss = self.get_loss(batch, output)
+       # else:
+       #     loss = 0

        return output, loss
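The commented-out block sketches a possible train/eval split in which the loss is skipped at evaluation time. A compact equivalent, purely hypothetical since the commit keeps the unconditional call:

```python
# Hypothetical form of the commented-out toggle; not active in this commit.
loss = self.get_loss(batch, output) if self.training else 0
```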

@@ -407,17 +408,18 @@ def get_BVG_distributions(self, pred):
        sigma_y = pred[:, :, 3]
        rho = pred[:, :, 4]

-       cov = torch.zeros((B, T, 2, 2)).to(pred.device)
-       cov[:, :, 0, 0] = sigma_x ** 2
-       cov[:, :, 1, 1] = sigma_y ** 2
-       cov[:, :, 0, 1] = rho * sigma_x * sigma_y
-       cov[:, :, 1, 0] = rho * sigma_x * sigma_y
+       # Build the batched 2x2 covariance matrices without in-place writes.
+       cov = torch.stack([
+           torch.stack([sigma_x ** 2, rho * sigma_x * sigma_y], dim=-1),
+           torch.stack([rho * sigma_x * sigma_y, sigma_y ** 2], dim=-1)
+       ], dim=-2)

-       biv_gauss_dist = MultivariateNormal(loc=torch.cat((mu_x, mu_y), dim=-1), covariance_matrix=cov)
+       biv_gauss_dist = MultivariateNormal(loc=torch.cat((mu_x, mu_y), dim=-1), covariance_matrix=cov, validate_args=False)
        return biv_gauss_dist

    def get_Laplace_dist(self, pred):
-       return Laplace(pred[:, :, :2], pred[:, :, 2:4])
+       return Laplace(pred[:, :, :2], pred[:, :, 2:4], validate_args=False)

    def nll_pytorch_dist(self, pred, data, mask, rtn_loss=True):
        # biv_gauss_dist = get_BVG_distributions(pred)
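Two things change here: the covariance matrices are now built functionally with `torch.stack` instead of being filled into a zero tensor in place, and both distributions are constructed with `validate_args=False`, which skips the constraint checks in `torch.distributions` (padded or masked entries can fail them, and the checks cost time). A standalone sketch with toy shapes, verifying the two constructions are identical and showing what `validate_args=False` permits:

```python
import torch
from torch.distributions import Laplace, MultivariateNormal

B, T = 2, 3
sigma_x = torch.rand(B, T) + 0.1
sigma_y = torch.rand(B, T) + 0.1
rho = torch.rand(B, T) - 0.5  # |rho| < 1 keeps cov positive-definite

# Old construction: allocate zeros, then fill entries in place.
cov_old = torch.zeros(B, T, 2, 2)
cov_old[:, :, 0, 0] = sigma_x ** 2
cov_old[:, :, 1, 1] = sigma_y ** 2
cov_old[:, :, 0, 1] = rho * sigma_x * sigma_y
cov_old[:, :, 1, 0] = rho * sigma_x * sigma_y

# New construction: the same matrices, built functionally with torch.stack.
cov_new = torch.stack([
    torch.stack([sigma_x ** 2, rho * sigma_x * sigma_y], dim=-1),
    torch.stack([rho * sigma_x * sigma_y, sigma_y ** 2], dim=-1)
], dim=-2)
assert torch.allclose(cov_old, cov_new)

# validate_args=False skips positive-definiteness and support checks.
mu = torch.rand(B, T, 2)
dist = MultivariateNormal(mu, covariance_matrix=cov_new, validate_args=False)
print(dist.log_prob(torch.zeros(B, T, 2)).shape)  # torch.Size([2, 3])

lap = Laplace(torch.zeros(2, 2), torch.tensor([[1.0, 1.0], [0.0, 1.0]]),
              validate_args=False)  # scale == 0 would be rejected otherwise
print(lap.log_prob(torch.zeros(2, 2)))  # non-finite entries where scale == 0
```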
@@ -452,18 +454,27 @@ def nll_loss_multimodes(self, output, data, center_gt_final_valid_idx):
        modes = len(pred)
        nSteps, batch_sz, dim = pred[0].shape

        # Compute posterior probability based on predicted prior and likelihood of predicted trajectory.
-       log_lik = np.zeros((batch_sz, modes))
+       log_lik_list = []
        with torch.no_grad():
            for kk in range(modes):
                nll = self.nll_pytorch_dist(pred[kk].transpose(0, 1), data, mask, rtn_loss=False)
-               log_lik[:, kk] = -nll.cpu().numpy()
+               log_lik_list.append(-nll.unsqueeze(1))  # add a mode dimension to concatenate later
+
+       # Concatenate the per-mode log-likelihoods into a [batch_sz, modes] tensor.
+       log_lik = torch.cat(log_lik_list, dim=1)

-       priors = modes_pred.detach().cpu().numpy()
-       log_posterior_unnorm = log_lik + np.log(priors)
-       log_posterior = log_posterior_unnorm - special.logsumexp(log_posterior_unnorm, axis=-1).reshape((batch_sz, -1))
-       post_pr = np.exp(log_posterior)
-       post_pr = torch.tensor(post_pr).float().to(data.device)
+       priors = modes_pred
+       log_priors = torch.log(priors)
+       log_posterior_unnorm = log_lik + log_priors
+
+       # Normalize with logsumexp, avoiding in-place operations.
+       logsumexp = torch.logsumexp(log_posterior_unnorm, dim=-1, keepdim=True)
+       log_posterior = log_posterior_unnorm - logsumexp
+
+       # Posterior probabilities, kept on the same device as the data.
+       post_pr = torch.exp(log_posterior)
+       post_pr = post_pr.to(data.device)

        # Compute loss.
        loss = 0.0
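The posterior over modes is now computed entirely in torch: Bayes' rule in log space, normalized with `torch.logsumexp`, with no numpy/scipy round trip and no forced CPU sync. A standalone check that the torch path matches the old scipy path:

```python
import numpy as np
import torch
from scipy import special

batch_sz, modes = 4, 3
log_lik = torch.randn(batch_sz, modes)                        # per-mode log-likelihoods
modes_pred = torch.softmax(torch.randn(batch_sz, modes), -1)  # predicted priors

# Torch path (as in the new code): p(k | traj) ∝ p(traj | k) * p(k).
log_post_unnorm = log_lik + torch.log(modes_pred)
log_post = log_post_unnorm - torch.logsumexp(log_post_unnorm, dim=-1, keepdim=True)
post_pr = log_post.exp()

# Old scipy path, for comparison.
ref = np.exp(log_post_unnorm.numpy()
             - special.logsumexp(log_post_unnorm.numpy(), axis=-1, keepdims=True))

assert torch.allclose(post_pr.sum(-1), torch.ones(batch_sz))  # rows sum to 1
assert np.allclose(post_pr.numpy(), ref)                      # identical result
```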
@@ -502,3 +513,5 @@ def l2_loss_fde(self, pred, data, mask):
                             dim=-1) * mask.unsqueeze(0)).mean(dim=2).transpose(0, 1)
        loss, min_inds = (fde_loss + ade_loss).min(dim=1)
        return 100.0 * loss.mean()
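For context, the unchanged tail above implements winner-takes-all selection: each sample is scored only by its best mode's combined ADE + FDE. A toy illustration of that last step (shapes are illustrative):

```python
import torch

B, c = 2, 3                   # batch size, number of modes
ade_loss = torch.rand(B, c)   # average displacement error per mode
fde_loss = torch.rand(B, c)   # final displacement error per mode

loss, min_inds = (fde_loss + ade_loss).min(dim=1)  # [B]; index of best mode
print(min_inds)               # which mode won for each sample
print(100.0 * loss.mean())    # scalar, as returned by l2_loss_fde
```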

