Merge pull request #180 from edbeeching/update_il_script
Update sb3_imitation.py
Ivan-267 authored Mar 25, 2024
2 parents abac578 + 5819e18 commit ff28cc8
Showing 2 changed files with 94 additions and 49 deletions.
25 changes: 12 additions & 13 deletions docs/IMITATION_LEARNING.md
@@ -218,7 +218,7 @@ Download the [sb3 imitation example](/examples/sb3_imitation.py) and `cd` into t
E.g. on Windows:

````
python sb3_imitation.py --env_path="PATH_TO_EXPORTED_GAME_EXE_FILE_HERE" --il_timesteps=250_000 --demo_files="PATH_TO_THE_RECORDED_demo.json_FILE_HERE" --eval_episode_count=20 --n_parallel=5 --speedup=15
python sb3_imitation.py --env_path="PATH_TO_EXPORTED_GAME_EXE_FILE_HERE" --gail_timesteps=250_000 --demo_files="PATH_TO_THE_RECORDED_demo.json_FILE_HERE" --eval_episode_count=20 --n_parallel=5 --speedup=15
````

Training should begin. As we set a small number of timesteps, the results won't be perfect, but it shouldn't take too
@@ -230,25 +230,24 @@ After the training is done, an evaluation environment should open, and you will
for
20 episodes.

In my case, I got:
```Mean reward after evaluation: 5.906429767608643```
The exact results you get may be different for various reasons, including the possibility that the hyperparameters
and/or other variables may have changed since then.
Note: If the results are worse than expected, consider opening the Python script and adjusting the hyperparameters to better fit your env.
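As a sketch of the kind of adjustment meant here, the `PPO(...)` call in `sb3_imitation.py` (visible further down in this diff) can be edited directly; the values below are placeholders rather than recommendations, and `env` and `policy_kwargs` are the variables the script already defines:
```
from stable_baselines3 import PPO

# Placeholder values -- tune these for your own env.
learner = PPO(
    env=env,                  # the env the script already wraps in VecMonitor
    policy="MlpPolicy",
    learning_rate=0.0002,     # try lowering this if training is unstable
    ent_coef=0.001,           # a higher entropy coefficient encourages more exploration
    n_epochs=10,              # PPO epochs per update
    batch_size=128,
    policy_kwargs=policy_kwargs,
    verbose=1,
)
```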

For comparison, when training just with `--rl_timesteps=250_000` I got a reward of:
```Mean reward after evaluation: 9.194426536560059```

The imitation-learned reward could be improved by tweaking hyperparameters (the parameters provided in the script are
not optimized), recording more high quality demos, doing some RL timesteps after it, etc.
The results could be improved by adjusting the hyperparameters and recording more high quality demos.
This environment was designed and tested with PPO, and it is simple enough that PPO alone can learn it quickly from
the reward function, so imitation learning isn't necessary here.
However, in more complex environments where it might be difficult to define a good dense reward function, learning from
demonstrations and/or combining it with RL learning from sparse rewards could be helpful.

There are a couple of other options to mention:
You can also try pretraining with BC (behavioral cloning) before the GAIL training by adding, e.g.:
```
--bc_epochs=NUM_EPOCHS (e.g. try 50-200)
--gail_timesteps=250_000
```

After imitation learning, you can continue model training with PPO using the environment rewards to further improve the
results. This is done by adding an argument to the script, e.g. `--rl_timesteps=250_000`.
and/or RL training afterwards by also adding:
```
--rl_timesteps=NUM_STEPS
```
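Putting these options together, a single run that pretrains with BC, then trains with GAIL, and finishes with some RL fine-tuning might look like the following (the paths are placeholders and the epoch/timestep counts are only examples):
```
python sb3_imitation.py --env_path="PATH_TO_EXPORTED_GAME_EXE_FILE_HERE" --demo_files="PATH_TO_THE_RECORDED_demo.json_FILE_HERE" --bc_epochs=100 --gail_timesteps=250_000 --rl_timesteps=250_000 --eval_episode_count=20 --n_parallel=5 --speedup=15
```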

You can set the script to export the trained model to onnx by adding e.g. `--onnx_export_path="model.onnx"`. That model
can then be copied to the game folder, and set in the sync node in testing_scene to be used for inference without the
118 changes: 82 additions & 36 deletions examples/sb3_imitation.py
@@ -6,9 +6,10 @@

import imitation.data
import numpy as np
from imitation.algorithms import bc
from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util import logger as imit_logger
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
@@ -27,7 +28,16 @@
"--demo_files",
nargs="+",
type=str,
help="""One or more files with recoded expert demos, with a space in between, e.g. "demo1.json", demo2.json""",
help="""One or more files with recorded expert demos, with a space in between, e.g. --demo_files demo1.json
demo2.json""",
)
parser.add_argument(
"--experiment_name",
default=None,
type=str,
help="The name of the experiment, which will be displayed in tensorboard, logs will be stored in logs/["
"experiment_name], if set. You should use a unique name for every experiment for the tensorboard log to "
"display properly.",
)
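# Note: when --experiment_name is set, logs are written under logs/<experiment_name> (see the logger setup below);
# they can then be viewed with TensorBoard, e.g.: tensorboard --logdir logs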
parser.add_argument("--seed", type=int, default=0, help="seed of the experiment")
parser.add_argument(
@@ -63,10 +73,18 @@
help="How many instances of the environment executable to " "launch - requires --env_path to be set if > 1.",
)
parser.add_argument(
"--il_timesteps",
"--bc_epochs",
default=0,
type=int,
help="[Optional] How many epochs to train for using imitation learning with BC. Policy is trained with BC "
"before GAIL or RL.",
)
parser.add_argument(
"--gail_timesteps",
default=0,
type=int,
help="How many timesteps to train for using imitation learning.",
help="[Optional] How many timesteps to train for using imitation learning with GAIL. If --bc_timesteps are set, "
"GAIL training is done after pre-training the policy with BC.",
)
parser.add_argument(
"--rl_timesteps",
@@ -110,7 +128,6 @@ def close_env():


trajectories = []

for file_path in args.demo_files:
with open(file_path, "r") as file:
data = json.load(file)
@@ -124,7 +141,10 @@ def close_env():
terminal=True,
)
)

print(
f"Loaded trajectories from {file_path}, found {len(data)} recorded trajectories (GDRL plugin records 1 "
f"episode as 1 trajectory)."
)

env = SBGSingleObsEnv(
env_path=args.env_path,
@@ -137,47 +157,73 @@

env = VecMonitor(env)


policy_kwargs = dict(log_std_init=log(1.0))

logger = None
if args.experiment_name:
logger = imit_logger.configure(f"logs/{args.experiment_name}", format_strs=["tensorboard", "stdout"])

# The hyperparams are set for IL tutorial env where BC > GAIL training is used. Feel free to customize for
# your usage.
learner = PPO(
batch_size=128,
env=env,
policy="MlpPolicy",
learning_rate=0.0003,
clip_range=0.2,
n_epochs=20,
batch_size=256,
ent_coef=0.007,
learning_rate=0.0002,
n_steps=64,
ent_coef=0.0001,
target_kl=0.025,
target_kl=0.02,
n_epochs=5,
policy_kwargs=policy_kwargs,
verbose=1,
verbose=2,
tensorboard_log=f"logs/{args.experiment_name}",
# seed=args.seed // Not currently supported as stable_baselines_wrapper.py seed() method is not yet implemented.
)

if args.il_timesteps:
reward_net = BasicRewardNet(
observation_space=env.observation_space,
action_space=env.action_space,
normalize_input_layer=RunningNorm,
)
try:
if args.bc_epochs > 0:
rng = np.random.default_rng(args.seed)
bc_trainer = bc.BC(
observation_space=env.observation_space,
action_space=env.action_space,
demonstrations=trajectories,
rng=rng,
policy=learner.policy,
custom_logger=logger,
)
print("Starting Imitation Learning Training using BC:")
bc_trainer.train(n_epochs=args.bc_epochs)

if args.gail_timesteps > 0:
print("Starting Imitation Learning Training using GAIL:")
reward_net = BasicRewardNet(
observation_space=env.observation_space,
action_space=env.action_space,
)

gail_trainer = GAIL(
demonstrations=trajectories,
demo_batch_size=128,
gen_replay_buffer_capacity=512,
n_disc_updates_per_round=24,
venv=env,
gen_algo=learner,
reward_net=reward_net,
allow_variable_horizon=True,
)
gail_trainer = GAIL(
demonstrations=trajectories,
demo_batch_size=256,
n_disc_updates_per_round=16,
venv=env,
gen_algo=learner,
reward_net=reward_net,
allow_variable_horizon=True,
init_tensorboard=True,
init_tensorboard_graph=True,
custom_logger=logger,
)
gail_trainer.train(args.gail_timesteps)

print("Starting Imitation Learning Training using GAIL:")
gail_trainer.train(args.il_timesteps)
if args.rl_timesteps > 0:
print("Starting RL Training:")
learner.learn(args.rl_timesteps, progress_bar=True)

if args.rl_timesteps:
print("Starting RL Training:")
learner.learn(args.rl_timesteps, progress_bar=True)
except KeyboardInterrupt:
print(
"""Training interrupted by user. Will save if --save_model_path was
used and/or export if --onnx_export_path was used."""
)

close_env()

@@ -192,8 +238,8 @@ def close_env():
)
env = VecMonitor(env)
mean_reward, _ = evaluate_policy(learner, env, n_eval_episodes=args.eval_episode_count)
close_env()
print(f"Mean reward after evaluation: {mean_reward}")

close_env()
handle_onnx_export()
handle_model_save()
