From 24811ce9b5cb1f63e1bf1fb2ce8ccb916e75e33c Mon Sep 17 00:00:00 2001
From: Alan
Date: Tue, 26 Nov 2024 14:25:22 +0100
Subject: [PATCH] update autobot.py

---
 unitraj/models/autobot/autobot.py | 55 +++++++++++++++++++------------
 1 file changed, 34 insertions(+), 21 deletions(-)

diff --git a/unitraj/models/autobot/autobot.py b/unitraj/models/autobot/autobot.py
index fde7d18..46c9d2b 100644
--- a/unitraj/models/autobot/autobot.py
+++ b/unitraj/models/autobot/autobot.py
@@ -73,7 +73,7 @@ def __init__(self, d_k, map_attr=3, dropout=0.1):
     def get_road_pts_mask(self, roads):
         road_segment_mask = torch.sum(roads[:, :, :, -1], dim=2) == 0
         road_pts_mask = (1.0 - roads[:, :, :, -1]).type(torch.BoolTensor).to(roads.device).view(-1, roads.shape[2])
-        road_pts_mask[:, 0][road_pts_mask.sum(-1) == roads.shape[2]] = False  # Ensures no NaNs due to empty rows.
+        road_pts_mask = road_pts_mask.masked_fill((road_pts_mask.sum(-1) == roads.shape[2]).unsqueeze(-1), False)  # Ensures no NaNs due to empty rows.
         return road_segment_mask, road_pts_mask
 
     def forward(self, roads, agents_emb):
@@ -272,7 +272,7 @@ def temporal_attn_fn(self, agents_emb, agent_masks, layer):
         B = agent_masks.size(0)
         num_agents = agent_masks.size(2)
         temp_masks = agent_masks.permute(0, 2, 1).reshape(-1, T_obs)
-        temp_masks[:, -1][temp_masks.sum(-1) == T_obs] = False  # Ensure that agent's that don't exist don't make NaN.
+        temp_masks = temp_masks.masked_fill((temp_masks.sum(-1) == T_obs).unsqueeze(-1), False)  # Ensure agents that don't exist don't produce NaNs.
         agents_temp_emb = layer(self.pos_encoder(agents_emb.reshape(T_obs, B * (num_agents), -1)),
                                 src_key_padding_mask=temp_masks)
         return agents_temp_emb.view(T_obs, B, num_agents, -1)
@@ -346,9 +346,7 @@ def _forward(self, inputs):
         output['predicted_probability'] = mode_probs  # #[B, c]
         output['predicted_trajectory'] = out_dists.permute(2, 0, 1, 3)  # [c, T, B, 5] to [B, c, T, 5] to be able to parallelize code
-        # output['scene_emb'] = mode_params_emb.transpose(0,1).reshape(B,-1)
-        if len(np.argwhere(np.isnan(out_dists.detach().cpu().numpy()))) > 1:
-            breakpoint()
+
         return output
 
     def forward(self, batch):
@@ -370,8 +370,11 @@ def forward(self, batch):
         model_input['agents_in'] = agents_in
         model_input['roads'] = roads
         output = self._forward(model_input)
-        loss = self.get_loss(batch, output)
+        if self.training:
+            loss = self.get_loss(batch, output)
+        else:
+            loss = 0
         return output, loss
 
@@ -407,17 +407,18 @@ def get_BVG_distributions(self, pred):
         sigma_y = pred[:, :, 3]
         rho = pred[:, :, 4]
 
-        cov = torch.zeros((B, T, 2, 2)).to(pred.device)
-        cov[:, :, 0, 0] = sigma_x ** 2
-        cov[:, :, 1, 1] = sigma_y ** 2
-        cov[:, :, 0, 1] = rho * sigma_x * sigma_y
-        cov[:, :, 1, 0] = rho * sigma_x * sigma_y
+        # Build the [B, T, 2, 2] covariance tensor without in-place assignments
+        cov = torch.stack([
+            torch.stack([sigma_x ** 2, rho * sigma_x * sigma_y], dim=-1),
+            torch.stack([rho * sigma_x * sigma_y, sigma_y ** 2], dim=-1)
+        ], dim=-2)
 
-        biv_gauss_dist = MultivariateNormal(loc=torch.cat((mu_x, mu_y), dim=-1), covariance_matrix=cov)
+        # validate_args=False skips the distribution's argument validation on construction
+        biv_gauss_dist = MultivariateNormal(loc=torch.cat((mu_x, mu_y), dim=-1), covariance_matrix=cov, validate_args=False)
         return biv_gauss_dist
 
     def get_Laplace_dist(self, pred):
-        return Laplace(pred[:, :, :2], pred[:, :, 2:4])
+        return Laplace(pred[:, :, :2], pred[:, :, 2:4], validate_args=False)
 
     def nll_pytorch_dist(self, pred, data, mask, rtn_loss=True):
         # biv_gauss_dist = get_BVG_distributions(pred)
@@ -452,18 +452,27 @@ def nll_loss_multimodes(self, output, data, center_gt_final_valid_idx):
         modes = len(pred)
         nSteps, batch_sz, dim = pred[0].shape
 
-        # compute posterior probability based on predicted prior and likelihood of predicted trajectory.
-        log_lik = np.zeros((batch_sz, modes))
+        log_lik_list = []
         with torch.no_grad():
             for kk in range(modes):
                 nll = self.nll_pytorch_dist(pred[kk].transpose(0, 1), data, mask, rtn_loss=False)
-                log_lik[:, kk] = -nll.cpu().numpy()
+                log_lik_list.append(-nll.unsqueeze(1))  # add a mode dimension so the modes can be concatenated
+
+        # Concatenate the per-mode likelihoods into a [batch_sz, modes] tensor
+        log_lik = torch.cat(log_lik_list, dim=1)
+
+        priors = modes_pred.detach()
+        log_priors = torch.log(priors)
+        log_posterior_unnorm = log_lik + log_priors
 
-        priors = modes_pred.detach().cpu().numpy()
-        log_posterior_unnorm = log_lik + np.log(priors)
-        log_posterior = log_posterior_unnorm - special.logsumexp(log_posterior_unnorm, axis=-1).reshape((batch_sz, -1))
-        post_pr = np.exp(log_posterior)
-        post_pr = torch.tensor(post_pr).float().to(data.device)
+        # Normalize in log space with logsumexp, avoiding in-place operations
+        logsumexp = torch.logsumexp(log_posterior_unnorm, dim=-1, keepdim=True)
+        log_posterior = log_posterior_unnorm - logsumexp
+
+        # Compute the posterior probabilities without in-place operations
+        post_pr = torch.exp(log_posterior)
+        # post_pr is already a tensor on the right device; this .to() is a safeguard
+        post_pr = post_pr.to(data.device)
 
         # Compute loss.
         loss = 0.0
@@ -502,3 +513,5 @@ def l2_loss_fde(self, pred, data, mask):
             dim=-1) * mask.unsqueeze(0)).mean(dim=2).transpose(0, 1)
         loss, min_inds = (fde_loss + ade_loss).min(dim=1)
         return 100.0 * loss.mean()
+
+
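
Note (illustration only, not part of the patch): the two masked_fill hunks above replace chained indexed assignments that mutated the padding masks in place. masked_fill returns a new tensor, and the broadcast row condition un-pads every position of a fully padded row rather than a single one; either way the transformer's src_key_padding_mask never ends up all-True, which is what previously produced NaNs in the attention softmax. A minimal standalone sketch of the pattern from temporal_attn_fn, with made-up shapes:

import torch

T_obs = 4
# True = padded. The last row is entirely padded, which would hand an
# all-True src_key_padding_mask to nn.TransformerEncoderLayer and yield
# NaNs in the attention softmax.
masks = torch.tensor([
    [False, False, True, True],
    [False, True, True, True],
    [True, True, True, True],
])

# In-place pattern removed by the patch: mutates `masks` through a view.
# masks[:, -1][masks.sum(-1) == T_obs] = False

# Out-of-place pattern introduced by the patch: builds a new tensor and
# clears every position of fully padded rows.
fixed = masks.masked_fill((masks.sum(-1) == T_obs).unsqueeze(-1), False)
print(fixed)
# tensor([[False, False,  True,  True],
#         [False,  True,  True,  True],
#         [False, False, False, False]])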
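The nll_loss_multimodes hunk likewise keeps the posterior computation on-device in torch instead of round-tripping through numpy/scipy. A small self-contained check (batch size and mode count are made up) that the two normalizations agree:

import numpy as np
import torch
from scipy import special

batch_sz, modes = 2, 3
log_lik = torch.randn(batch_sz, modes)
priors = torch.rand(batch_sz, modes).softmax(dim=-1)

# Old path: detach to numpy, normalize with scipy's logsumexp.
lp_np = log_lik.numpy() + np.log(priors.numpy())
post_np = np.exp(lp_np - special.logsumexp(lp_np, axis=-1).reshape((batch_sz, -1)))

# New path: stay in torch, normalize in log space with torch.logsumexp.
lp = log_lik + torch.log(priors)
post = torch.exp(lp - torch.logsumexp(lp, dim=-1, keepdim=True))

assert torch.allclose(post, torch.from_numpy(post_np).to(post.dtype), atol=1e-6)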