
[ICLR 2025] Code for "Neuralized Markov Random Field for Interaction-Aware Stochastic Human Trajectory Prediction"


AdaCompNUS/NMRF_TrajectoryPrediction


Neuralized Markov Random Field for Interaction-Aware Stochastic Human Trajectory Prediction

PyTorch implementation of the ICLR'25 paper Neuralized Markov Random Field for Interaction-Aware Stochastic Human Trajectory Prediction.

1. Overview

Abstract. Interactive human motions and the continuously changing nature of intentions pose significant challenges for human trajectory prediction. In this paper, we present a neuralized Markov random field (MRF)-based motion evolution method for probabilistic interaction-aware human trajectory prediction. We use an MRF to model each agent's motion and the resulting crowd interactions over time, which makes the method robust to noisy observations and enables group reasoning. We approximate the modeled distribution using two conditional variational autoencoders (CVAEs) for efficient learning and inference. Our proposed method achieves state-of-the-art performance on ADE/FDE metrics across two dataset categories: the overhead datasets ETH/UCY, SDD, and NBA, and the ego-centric JRDB. Furthermore, our approach allows for real-time stochastic inference in bustling environments, making it well-suited for a 30 FPS video setting.
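For orientation only: a CVAE with condition x (for example, the observed history), target y (for example, the future trajectory), and latent variable z is typically trained by maximizing the standard evidence lower bound below, with a recognition network q_φ, a conditional prior p_θ(z | x), and a decoder p_θ(y | x, z). This is the generic textbook bound, not the exact objective of the two CVAEs in this work, which may add further terms.

```latex
\log p_\theta(y \mid x) \;\ge\;
\mathbb{E}_{q_\phi(z \mid x, y)}\!\left[\log p_\theta(y \mid x, z)\right]
- D_{\mathrm{KL}}\!\left(q_\phi(z \mid x, y) \,\|\, p_\theta(z \mid x)\right)
```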

2. PyTorch Implementation

Training logs are provided in the ./logs folder and the processed datasets in the ./processed_datasets folder (except NBA, which is distributed separately; see 2.3). Please download the pretrained models from Google Drive.

2.1 Environment Setup

We train and evaluate our model on Ubuntu 20.04 with a single Quadro RTX 8000 GPU, using a conda environment with python=3.7.16 and torch=1.8.0.

The code has also been tested with torch=1.11.0 and torch=2.5.1 and works with both.
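If you are unsure whether your environment matches, the small hypothetical helper below (not part of the repository) compares an installed torch version string with the three versions reported above. Local build tags such as `+cu111` are ignored. A version outside this set is simply untested here, not necessarily incompatible.

```python
# Hypothetical helper, not part of the repository: checks whether a torch
# version string matches one of the versions this README reports as tested.
TESTED_TORCH_VERSIONS = {"1.8.0", "1.11.0", "2.5.1"}


def is_tested_torch_version(version: str) -> bool:
    """Return True if `version` (e.g. torch.__version__) is a tested release.

    The local build tag after '+' is dropped, so "1.8.0+cu111" counts as 1.8.0.
    """
    public = version.split("+", 1)[0]
    return public in TESTED_TORCH_VERSIONS


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("torch is not installed")
    else:
        status = "tested" if is_tested_torch_version(torch.__version__) else "untested"
        print(f"torch {torch.__version__}: {status} version")
```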

2.2 Datasets

ETH-UCY. The raw data is taken from SR-LSTM and STAR and is included under ./preprocess/raw_data/. The train/validation/test splits follow the original Social-GAN paper.
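Social-GAN evaluates ETH-UCY with a leave-one-scene-out protocol: a model is trained on four of the five scenes and tested on the held-out one, which is consistent with the test commands in 2.4 using a separate checkpoint per scene. The sketch below only enumerates those splits; the scene names are taken from this README's `--dataset` values, and how the repository maps them to files under ./preprocess/raw_data/ is not assumed. The validation split is left out.

```python
# The five ETH-UCY scenes, named as in the --dataset values of this README.
SCENES = ["eth", "hotel", "univ", "zara1", "zara2"]


def leave_one_out_split(test_scene, scenes=SCENES):
    """Return (train_scenes, test_scenes) for leave-one-scene-out evaluation."""
    if test_scene not in scenes:
        raise ValueError(f"unknown scene: {test_scene}")
    train = [s for s in scenes if s != test_scene]
    return train, [test_scene]


if __name__ == "__main__":
    for scene in SCENES:
        train, test = leave_one_out_split(scene)
        print(f"--dataset {scene}: train on {train}, test on {test}")
```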

SDD. The dataset is taken from SocialVAE, which provides SDD in meters. The raw data is included under ./preprocess/raw_data/sdd/.

NBA. We use the same data and train-test splits as the previous state-of-the-art model, LED.

JRDB. Please register and download the dataset from the official website. It provides images, point clouds, ROS bags, and other modalities. An official toolkit is available for preprocessing the raw data for different tasks; we mainly use its tracking_eval part to obtain trajectory sequences.

Since registration and login are required to download JRDB, we do not include any raw data here. A preprocessed version for the stochastic prediction task (N=20) is provided in ./processed_datasets. If this also raises copyright concerns, please contact us, and we will remove it.

2.3 Training

All processed datasets are included in the ./processed_datasets folder except NBA, which is too large; it is available at the same Google Drive link as the pretrained models. Download nba.zip, place it in the ./processed_datasets folder, and unzip it.
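If `unzip` is not available, the same step can be done from Python with the standard `zipfile` module, as in the sketch below. It assumes nba.zip has already been downloaded into ./processed_datasets and that it is run from the repository root; the internal layout of the archive (for example, whether it contains a top-level nba/ folder) is not described in this README, so check the printed member names.

```python
import zipfile
from pathlib import Path


def extract_archive(zip_path, dest_dir="./processed_datasets"):
    """Extract `zip_path` into `dest_dir` and return the archive's member names."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)  # create the target folder if needed
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
        return archive.namelist()


if __name__ == "__main__":
    archive_path = Path("./processed_datasets/nba.zip")
    if archive_path.exists():
        print(extract_archive(archive_path)[:5])  # show the first few extracted entries
    else:
        print(f"{archive_path} not found; download nba.zip first")
```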

If you're interested, feel free to replicate the data preprocessing yourself:

cd preprocess
python prepare_ethucy.py && cd ..
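As a rough illustration of what this kind of preprocessing involves, the sketch below parses raw ETH-UCY rows and cuts each pedestrian's track into fixed-length windows. It is not the contents of prepare_ethucy.py. It assumes the tab-separated `frame_id ped_id x y` layout common to Social-GAN-style ETH-UCY files, 8 observed plus 12 future frames (20 in total), and annotated frames spaced 10 frame ids apart; none of these values is stated in this README. It is also simplified to one pedestrian at a time, whereas Social-STGCNN-style pipelines typically build windows over whole scenes with all co-present pedestrians.

```python
from collections import defaultdict


def parse_ethucy_lines(lines, delim="\t"):
    """Parse 'frame_id ped_id x y' rows into {ped_id: [(frame, x, y), ...]} sorted by frame."""
    tracks = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        frame, ped, x, y = (float(v) for v in line.split(delim))
        tracks[int(ped)].append((int(frame), x, y))
    return {ped: sorted(rows) for ped, rows in tracks.items()}


def full_windows(track, seq_len=20, frame_step=10):
    """Yield every run of seq_len rows whose frame ids advance by exactly frame_step.

    Assumed defaults: 8 observed + 12 future = 20 frames, 10 frame ids apart.
    """
    for start in range(len(track) - seq_len + 1):
        window = track[start:start + seq_len]
        frames = [f for f, _, _ in window]
        if all(b - a == frame_step for a, b in zip(frames, frames[1:])):
            yield window
```

Under the 8/12 assumption, each window splits into `window[:8]` (observed) and `window[8:]` (future); windows that span a gap in the annotations are skipped rather than interpolated.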

As described in the paper, training consists of two stages: CVAE training and sampler training. Use the following command to start the first stage; the log name can be set freely with the --log argument:

python main.py --dataset nba --train --log default  # 'default' can be replaced by any other name you like

Then select the best checkpoint based on the validation results and start the second stage. For example, if epoch 150 is the best, run:

python main.py --dataset nba --train --log default --use_sampler --epoch 150
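The two stages above can be scripted roughly as follows. The sketch builds the commands using only the flags shown in this README and picks the epoch with the lowest validation ADE. The `val_ade` values are placeholders, not real results: the format of the validation output in ./logs is not documented here, so the dictionary has to be filled in by hand. Commands are only printed unless `dry_run=False`.

```python
import subprocess


def stage1_cmd(dataset, log="default"):
    """First stage: CVAE training."""
    return ["python", "main.py", "--dataset", dataset, "--train", "--log", log]


def stage2_cmd(dataset, epoch, log="default"):
    """Second stage: sampler training on top of the chosen CVAE checkpoint."""
    return ["python", "main.py", "--dataset", dataset, "--train", "--log", log,
            "--use_sampler", "--epoch", str(epoch)]


def best_epoch(val_ade_by_epoch):
    """Epoch with the lowest validation ADE; ties go to the earliest epoch."""
    return min(sorted(val_ade_by_epoch), key=lambda e: val_ade_by_epoch[e])


def launch(cmd, dry_run=True):
    """Print the command, and run it (raising on failure) only if dry_run is False."""
    print(" ".join(cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    launch(stage1_cmd("nba"))
    val_ade = {140: 0.95, 150: 0.93, 160: 0.94}  # placeholder values, not real results
    launch(stage2_cmd("nba", best_epoch(val_ade)))
```

With the placeholder values, the second printed line is the stage-two command from the example above.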

Training logs are saved in the ./logs folder, and checkpoints are saved automatically in the ./results folder, which the code creates.

2.4 Testing

Simply run the command:

python main.py --dataset XXX --log XXX --use_sampler --epoch XXX

Note that the --train argument is omitted when testing.

To test our pretrained models, download results.zip, unzip it in the repository root, and run:

# For ETH-UCY dataset
python main.py --dataset eth --log default --use_sampler --epoch 47
python main.py --dataset hotel --log default --use_sampler --epoch 44
python main.py --dataset univ --log default --use_sampler --epoch 15
python main.py --dataset zara1 --log default --use_sampler --epoch 47
python main.py --dataset zara2 --log default --use_sampler --epoch 55

# For SDD dataset
python main.py --dataset sdd --log default --use_sampler --epoch 80

# For NBA dataset
python main.py --dataset nba --log default --use_sampler --epoch 14

# For JRDB dataset
python main.py --dataset jrdb --log default --use_sampler --epoch 49
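To evaluate all pretrained models in one go, the commands above can be generated from a dataset-to-epoch table, as in this sketch. The epochs are copied from the commands above; the script assumes it is run from the repository root after results.zip has been unzipped there. By default it only prints the commands; pass `dry_run=False` to run them one after another.

```python
import subprocess

# Dataset -> checkpoint epoch of the released pretrained models (from the commands above).
PRETRAINED_EPOCHS = {
    "eth": 47, "hotel": 44, "univ": 15, "zara1": 47, "zara2": 55,
    "sdd": 80, "nba": 14, "jrdb": 49,
}


def eval_cmd(dataset, epoch, log="default"):
    """Test command: like the second training stage, but without --train."""
    return ["python", "main.py", "--dataset", dataset, "--log", log,
            "--use_sampler", "--epoch", str(epoch)]


def evaluate_all(dry_run=True):
    """Print (and optionally run) the evaluation command for every pretrained model."""
    cmds = [eval_cmd(d, e) for d, e in PRETRAINED_EPOCHS.items()]
    for cmd in cmds:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
    return cmds


if __name__ == "__main__":
    evaluate_all(dry_run=True)
```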

With the pretrained models, you should obtain:

# For ETH-UCY and SDD datasets
Dataset   ADE      FDE
ETH       0.2531   0.3746
HOTEL     0.1125   0.1683
UNIV      0.2763   0.4872
ZARA1     0.1806   0.3040
ZARA2     0.1405   0.2475
SDD       7.0960   11.1058

# For NBA dataset
--ADE(1s): 0.1646       --FDE(1s): 0.2404
--ADE(2s): 0.3376       --FDE(2s): 0.5006
--ADE(3s): 0.5338       --FDE(3s): 0.7433
--ADE(4s): 0.7500       --FDE(4s): 0.9654

# For JRDB dataset
--ADE(1.2s): 0.0430     --FDE(1.2s): 0.0539
--ADE(2.4s): 0.0769     --FDE(2.4s): 0.1099
--ADE(3.6s): 0.1134     --FDE(3.6s): 0.1672
--ADE(4.8s): 0.1516     --FDE(4.8s): 0.2256
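For stochastic prediction, ADE and FDE are usually reported best-of-N: among N sampled futures (N=20 for the JRDB task mentioned in 2.2), the lowest error is kept. The pure-Python sketch below computes per-horizon best-of-N ADE/FDE, taking the ADE minimum and the FDE minimum separately, which is a common convention; whether this repository does exactly the same is not stated here. The number of frames per horizon depends on each dataset's sampling rate (for example, a 1 s horizon at 5 frames per second would be 5 steps), so `steps_per_horizon` is a hypothetical input. `math.hypot` is used instead of `math.dist` so the code also runs on the Python 3.7 environment listed above.

```python
import math


def _step_errors(traj, gt, horizon):
    """Euclidean error at each of the first `horizon` future steps."""
    return [math.hypot(p[0] - q[0], p[1] - q[1]) for p, q in zip(traj[:horizon], gt[:horizon])]


def min_ade_fde(samples, gt, horizon):
    """Best-of-N ADE and FDE of one agent over its first `horizon` steps."""
    if not 0 < horizon <= len(gt):
        raise ValueError("horizon must be in 1..len(gt)")
    per_sample = [_step_errors(s, gt, horizon) for s in samples]
    min_ade = min(sum(d) / horizon for d in per_sample)
    min_fde = min(d[-1] for d in per_sample)  # may come from a different sample than min_ade
    return min_ade, min_fde


def per_horizon_metrics(all_samples, all_gt, steps_per_horizon):
    """Average best-of-N (ADE, FDE) over agents, keyed by horizon length in steps."""
    results = {}
    for steps in steps_per_horizon:
        pairs = [min_ade_fde(s, g, steps) for s, g in zip(all_samples, all_gt)]
        results[steps] = (sum(a for a, _ in pairs) / len(pairs),
                          sum(f for _, f in pairs) / len(pairs))
    return results
```

For example, with ground truth `[(1, 0), (2, 0)]` and two samples `[(1, 0), (2, 1)]` and `[(1, 1), (2, 0)]`, both samples have an ADE of 0.5 over two steps, but only the second ends on the target, so the best-of-2 result is an ADE of 0.5 and an FDE of 0.0.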

Citation

If you find this repo useful, please consider citing our paper as:

@inproceedings{
fang2025neuralized,
title={Neuralized Markov Random Field for Interaction-Aware Stochastic Human Trajectory Prediction},
author={Zilin Fang and David Hsu and Gim Hee Lee},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025}
}

Reference

The code base borrows heavily from LED; the discrepancy loss is adapted from NPSN, and the ETH-UCY and SDD data preprocessing is adapted from Social-STGCNN.
