This section describes the configuration required before training.
- Install the Surgical Robotics Challenge (SRC) environment, along with the AMBF and ROS prerequisites listed in the link. It provides a simulation environment for a suturing phantom combined with the da Vinci surgical system.
git clone https://github.com/surgical-robotics-ai/surgical_robotics_challenge
- Install Gymnasium: Gymnasium is a maintained fork of OpenAI Gym. It provides a standard API for communication between simulated environments and learning algorithms.
pip install gymnasium
- Configure PyTorch and CUDA (if you have an NVIDIA GPU) according to your hardware.
- Install Stable Baselines3 (SB3): SB3 is an open-source Python library providing implementations of state-of-the-art RL algorithms. In this project, it interacts with the Gymnasium environment and offers an interface for training, evaluating, and testing RL models.
pip install stable-baselines3
- Run the code below. This script creates a built-in RL environment from Gymnasium and trains an agent on it using the PPO algorithm from Stable Baselines3. If everything is set up correctly, the script should run without errors.
import gymnasium
import stable_baselines3
env = gymnasium.make('CartPole-v1')
model = stable_baselines3.PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=10000)
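If the test script fails, it can help to first confirm that the packages from the steps above are importable. The following is a minimal sketch, not part of this project: `check_packages` is a hypothetical helper that relies only on the standard library, and CUDA is queried only when PyTorch is actually installed.

```python
import importlib.util


def check_packages(names):
    """Return {module_name: True/False} depending on whether each module can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    status = check_packages(["gymnasium", "stable_baselines3", "torch"])
    for name, found in status.items():
        print(f"{name}: {'found' if found else 'MISSING'}")

    # Only import PyTorch if it is installed, then report CUDA availability.
    if status["torch"]:
        import torch
        print("CUDA available:", torch.cuda.is_available())
```

A `MISSING` entry points at the installation step to repeat; `CUDA available: False` means training will fall back to the CPU.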
This section introduces the basic procedure for training a model with a defined Gymnasium environment.
Make sure ROS and SRC are running before moving on to the following steps. You can run the following commands, or refer to this link for details.
roscore
~/ambf/bin/lin-x86_64/ambf_simulator --launch_file ~/ambf/surgical_robotics_challenge/launch.yaml -l 0,1,3,4,13,14 -p 200 -t 1 --override_max_comm_freq 120
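Before creating the environment, one way to confirm that `roscore` is up is to test whether the ROS master port accepts connections. The sketch below is an assumption-laden convenience, not part of this project: `is_port_open` and `ros_master_address` are hypothetical helpers, the fallback address is the ROS 1 default `http://localhost:11311`, and the check covers only the ROS master, not the AMBF simulator.

```python
import os
import socket
from urllib.parse import urlparse


def is_port_open(host, port, timeout=1.0):
    """Return True if a TCP connection to (host, port) succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


def ros_master_address():
    """Read host and port from ROS_MASTER_URI, falling back to the ROS 1 default."""
    uri = urlparse(os.environ.get("ROS_MASTER_URI", "http://localhost:11311"))
    return uri.hostname or "localhost", uri.port or 11311


if __name__ == "__main__":
    host, port = ros_master_address()
    if is_port_open(host, port):
        print(f"ROS master reachable at {host}:{port}")
    else:
        print(f"No ROS master at {host}:{port}; start roscore first")
```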
import gymnasium as gym
from stable_baselines3.common.evaluation import evaluate_policy
from Approach_env import SRC_approach
import numpy as np
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env
from RL_algo.PPO import PPO
from stable_baselines3.common.utils import set_random_seed
gym.envs.register(id="TD3_HER_BC", entry_point=SRC_approach)
env = gym.make("TD3_HER_BC", render_mode="human", reward_type="sparse")
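The `reward_type` argument selects how the environment scores each step. The actual reward used by `SRC_approach` is defined in `Approach_env` and is not reproduced here; the sketch below only illustrates a common convention, under the assumption that a sparse reward signals success or failure and a dense reward tracks distance to the goal. The function name, the 0 / -1 values, and the threshold are all hypothetical.

```python
import math


def compute_reward(achieved, goal, reward_type="sparse", threshold=0.01):
    """Illustrative reward: sparse gives 0 on success and -1 otherwise; dense gives negative distance."""
    distance = math.dist(achieved, goal)
    if reward_type == "sparse":
        # Assumed convention: success means being within `threshold` of the goal.
        return 0.0 if distance <= threshold else -1.0
    if reward_type == "dense":
        return -distance
    raise ValueError(f"unknown reward_type: {reward_type!r}")
```

Sparse rewards are harder to learn from, which is one reason they are often paired with techniques such as hindsight experience replay.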
Here is an example of a model trained with the Proximal Policy Optimization (PPO) algorithm, using default hyperparameters.
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="./First_version/")
checkpoint_callback = CheckpointCallback(save_freq=10000, save_path="./First_version/Model_temp", name_prefix="SRC")
model.learn(total_timesteps=1000000, progress_bar=True, callback=checkpoint_callback)
model.save("SRC")
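With `name_prefix='SRC'`, the checkpoint callback writes files named like `SRC_10000_steps.zip`. When resuming or evaluating, it is often convenient to pick the most recent one automatically. The sketch below is not part of this project: `latest_checkpoint` is a hypothetical helper that assumes this `<prefix>_<steps>_steps.zip` naming and compares step counts numerically, since sorting file names alphabetically would place `SRC_30000` after `SRC_200000`.

```python
import re
from pathlib import Path


def latest_checkpoint(directory, prefix="SRC"):
    """Return the checkpoint path with the highest step count, or None if none exist."""
    pattern = re.compile(rf"^{re.escape(prefix)}_(\d+)_steps\.zip$")
    best_steps, best_path = -1, None
    for path in Path(directory).glob(f"{prefix}_*_steps.zip"):
        match = pattern.match(path.name)
        # Compare step counts as integers rather than as strings.
        if match and int(match.group(1)) > best_steps:
            best_steps, best_path = int(match.group(1)), path
    return best_path
```

For example, `latest_checkpoint("./First_version/Model_temp")` would return the newest checkpoint saved by the callback configured above.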
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="./First_version/")
model_path = "./Model/SRC_10000_steps.zip"
model = PPO.load(model_path)
model.set_env(env=env)
obs, info = env.reset()
print(obs)
for i in range(10000):
    action, _state = model.predict(obs, deterministic=True)
    print(action)
    obs, reward, terminated, truncated, info = env.step(action)
    print(info)
    env.render()
    if terminated or truncated:
        obs, info = env.reset()
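The loop above prints every step, which is useful for debugging but does not summarize performance. A rough sketch of a per-episode success rate is given below. It is not part of this project: `rollout_success_rate` is a hypothetical helper, it expects the Gymnasium `reset`/`step` return signature, and it assumes the environment reports success under an `"is_success"` key in `info`, which may not match what `SRC_approach` provides.

```python
def rollout_success_rate(env, predict, n_episodes=10, max_steps=1000):
    """Run n_episodes and return the fraction whose final info reports success."""
    successes = 0
    for _ in range(n_episodes):
        obs, info = env.reset()
        for _ in range(max_steps):
            action = predict(obs)
            obs, reward, terminated, truncated, info = env.step(action)
            if terminated or truncated:
                break
        # "is_success" is an assumed info key; adapt it to what the environment reports.
        successes += bool(info.get("is_success", False))
    return successes / n_episodes
```

With a loaded model, the policy can be passed as `lambda obs: model.predict(obs, deterministic=True)[0]`.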
The following video demonstrates the complete suturing procedure performed by our trained policy.
Demo_part.mp4
If you find our work useful, please cite it as:
@misc{wu2024surgicaifinegrainedplatformdata,
title={SurgicAI: A Fine-grained Platform for Data Collection and Benchmarking in Surgical Policy Learning},
author={Jin Wu and Haoying Zhou and Peter Kazanzides and Adnan Munawar and Anqi Liu},
year={2024},
eprint={2406.13865},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2406.13865},
}