-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcommon.py
766 lines (671 loc) · 30.9 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
"""Core code for adversarial imitation learning, shared between GAIL and AIRL."""
import abc
import dataclasses
import logging
import os # MODIFIED:
from datetime import datetime # MODIFIED:
from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload
import numpy as np
import torch as th
import torch.utils.tensorboard as thboard
import tqdm
from imitation.algorithms import base
from imitation.data import buffer, rollout, types, wrappers
from imitation.rewards import reward_nets, reward_wrapper
from imitation.util import logger, networks, util
from stable_baselines3.common import (
base_class,
distributions,
on_policy_algorithm,
policies,
vec_env,
)
from stable_baselines3.common.evaluation import evaluate_policy # MODIFIED:
from stable_baselines3.sac import policies as sac_policies
from torch.nn import functional as F
from src.rl.utils import save_model # MODIFIED:
# MODIFIED: module-level random generator with a fixed seed (1) so that the
# shuffling of re-balanced expert samples during discriminator training is
# reproducible across runs.
RNG = np.random.default_rng(seed=1)
def compute_train_stats(
    disc_logits_expert_is_high: th.Tensor,
    labels_expert_is_one: th.Tensor,
    disc_loss: th.Tensor,
) -> Mapping[str, float]:
    """Train statistics for GAIL/AIRL discriminator.

    Args:
        disc_logits_expert_is_high: discriminator logits produced by
            `AdversarialTrainer.logits_expert_is_high`. A logit >= 0 counts as an
            "expert" prediction, a logit < 0 as a "generator" prediction.
        labels_expert_is_one: integer labels describing whether each logit was for
            an expert (1) or generator (0) sample.
        disc_loss: final discriminator loss.

    Returns:
        A mapping from statistic names to float values. Proportions are NaN for
        an empty batch, and the per-class accuracies `disc_acc_expert` /
        `disc_acc_gen` are NaN when the batch has no sample of that class.
    """
    nan = float("NaN")
    with th.no_grad():
        # Logits of the discriminator output; >0 for expert samples, <0 for generator.
        bin_is_generated_pred = disc_logits_expert_is_high < 0
        # Binary label, so 1 is for expert, 0 is for generator.
        bin_is_generated_true = labels_expert_is_one == 0
        bin_is_expert_true = th.logical_not(bin_is_generated_true)

        n_labels = float(len(labels_expert_is_one))
        n_generated = float(th.sum(bin_is_generated_true.long()))
        n_expert = n_labels - n_generated
        n_expert_pred = n_labels - float(th.sum(bin_is_generated_pred.long()))

        if n_labels > 0:
            pct_expert = n_expert / n_labels
            pct_expert_pred = n_expert_pred / n_labels
        else:
            pct_expert = nan
            pct_expert_pred = nan

        correct_vec = th.eq(bin_is_generated_pred, bin_is_generated_true)
        acc = th.mean(correct_vec.float())

        n_expert_correct = float(
            th.sum(th.logical_and(bin_is_expert_true, correct_vec))
        )
        n_generated_correct = float(
            th.sum(th.logical_and(bin_is_generated_true, correct_vec))
        )
        # Per-class accuracy is undefined when that class is absent from the
        # batch; report NaN for both classes rather than a misleading 0.
        expert_acc = n_expert_correct / n_expert if n_expert >= 1 else nan
        generated_acc = (
            n_generated_correct / n_generated if n_generated >= 1 else nan
        )

        label_dist = th.distributions.Bernoulli(
            logits=disc_logits_expert_is_high
        )
        entropy = th.mean(label_dist.entropy())

    return {
        "disc_loss": float(th.mean(disc_loss)),
        "disc_acc": float(acc),
        # accuracy on just expert examples
        "disc_acc_expert": float(expert_acc),
        # accuracy on just generated examples
        "disc_acc_gen": float(generated_acc),
        # entropy of the predicted label distribution, averaged equally across
        # both classes (if this drops then disc is very good or has given up)
        "disc_entropy": float(entropy),
        # true number of expert demos and predicted number of expert demos
        "disc_proportion_expert_true": float(pct_expert),
        "disc_proportion_expert_pred": float(pct_expert_pred),
        "n_expert": float(n_expert),
        "n_generated": float(n_generated),
    }
class AdversarialTrainer(base.DemonstrationAlgorithm[types.Transitions]):
    """Base class for adversarial imitation learning algorithms like GAIL and AIRL."""

    venv: vec_env.VecEnv
    """The original vectorized environment."""

    venv_train: vec_env.VecEnv
    """Like `self.venv`, but wrapped with train reward unless in debug mode.

    If `debug_use_ground_truth=True` was passed into the initializer then
    `self.venv_train` is `self.venv` wrapped only in the buffering wrapper, i.e.
    it still returns the ground-truth environment reward."""

    _demo_data_loader: Optional[Iterable[types.TransitionMapping]]
    _endless_expert_iterator: Optional[Iterator[types.TransitionMapping]]

    venv_wrapped: vec_env.VecEnvWrapper

    # MODIFIED: the two discrete expert actions whose frequencies are equalised
    # in every expert batch before discriminator training.
    # NOTE(review): project-specific action ids — document what 5 and 17 mean.
    _BALANCED_EXPERT_ACTIONS = (5, 17)

    def __init__(
        self,
        *,
        demonstrations: base.AnyTransitions,
        demo_batch_size: int,
        venv: vec_env.VecEnv,
        gen_algo: base_class.BaseAlgorithm,
        reward_net: reward_nets.RewardNet,
        demo_minibatch_size: Optional[int] = None,
        n_disc_updates_per_round: int = 2,
        log_dir: types.AnyPath = "output/",
        disc_opt_cls: Type[th.optim.Optimizer] = th.optim.Adam,
        disc_opt_kwargs: Optional[Mapping] = None,
        gen_train_timesteps: Optional[int] = None,
        gen_replay_buffer_capacity: Optional[int] = None,
        custom_logger: Optional[logger.HierarchicalLogger] = None,
        init_tensorboard: bool = False,
        init_tensorboard_graph: bool = False,
        debug_use_ground_truth: bool = False,
        allow_variable_horizon: bool = False,
    ):
        """Builds AdversarialTrainer.

        Args:
            demonstrations: Demonstrations from an expert (optional). Transitions
                expressed directly as a `types.TransitionsMinimal` object, a sequence
                of trajectories, or an iterable of transition batches (mappings from
                keywords to arrays containing observations, etc).
            demo_batch_size: The number of samples in each batch of expert data. The
                discriminator batch size is twice this number because each discriminator
                batch contains a generator sample for every expert sample.
            venv: The vectorized environment to train in.
            gen_algo: The generator RL algorithm that is trained to maximize
                discriminator confusion. Environment and logger will be set to
                `venv` and `custom_logger`.
            reward_net: a Torch module that takes an observation, action and
                next observation tensors as input and computes a reward signal.
            demo_minibatch_size: size of the minibatches each discriminator batch
                is split into. `train_disc` takes an optimizer step after every
                minibatch (MODIFIED: the upstream implementation accumulated
                gradients over the whole batch before a single step). Smaller
                minibatches reduce GPU memory usage. Must be a factor of
                `demo_batch_size`. Optional, defaults to `demo_batch_size`.
            n_disc_updates_per_round: The number of discriminator updates after each
                round of generator updates in AdversarialTrainer.learn().
            log_dir: Directory to store TensorBoard logs, plots, etc. in.
            disc_opt_cls: The optimizer for discriminator training.
            disc_opt_kwargs: Parameters for discriminator training.
            gen_train_timesteps: The number of steps to train the generator policy for
                each iteration. If None, then defaults to the number of environments,
                multiplied by `n_steps` for on-policy algorithms.
            gen_replay_buffer_capacity: The capacity of the
                generator replay buffer (the number of obs-action-obs samples from
                the generator that can be stored). By default this is equal to
                `gen_train_timesteps`, meaning that we sample only from the most
                recent batch of generator samples.
            custom_logger: Where to log to; if None (default), creates a new logger.
            init_tensorboard: If True, makes various discriminator
                TensorBoard summaries.
            init_tensorboard_graph: If both this and `init_tensorboard` are True,
                then write a Tensorboard graph summary to disk.
            debug_use_ground_truth: If True, use the ground truth reward for
                `self.venv_train`.
                This disables the reward wrapping that would normally replace
                the environment reward with the learned reward. This is useful for
                sanity checking that the policy training is functional.
            allow_variable_horizon: If False (default), algorithm will raise an
                exception if it detects trajectories of different length during
                training. If True, overrides this safety check. WARNING: variable
                horizon episodes leak information about the reward via termination
                condition, and can seriously confound evaluation. Read
                https://imitation.readthedocs.io/en/latest/guide/variable_horizon.html
                before overriding this.

        Raises:
            ValueError: if the batch size is not a multiple of the minibatch size.
        """
        # Batch sizes must be set before `super().__init__`, which calls
        # `set_demonstrations` and therefore reads `self.demo_batch_size`.
        self.demo_batch_size = demo_batch_size
        self.demo_minibatch_size = demo_minibatch_size or demo_batch_size
        if self.demo_batch_size % self.demo_minibatch_size != 0:
            raise ValueError("Batch size must be a multiple of minibatch size.")
        self._demo_data_loader = None
        self._endless_expert_iterator = None
        super().__init__(
            demonstrations=demonstrations,
            custom_logger=custom_logger,
            allow_variable_horizon=allow_variable_horizon,
        )

        self._global_step = 0
        self._disc_step = 0
        self.n_disc_updates_per_round = n_disc_updates_per_round

        self.debug_use_ground_truth = debug_use_ground_truth
        self.venv = venv
        self.gen_algo = gen_algo
        self._reward_net = reward_net.to(gen_algo.device)
        self._log_dir = util.parse_path(log_dir)

        # Create graph for optimising/recording stats on discriminator
        self._disc_opt_cls = disc_opt_cls
        self._disc_opt_kwargs = disc_opt_kwargs or {}
        self._init_tensorboard = init_tensorboard
        self._init_tensorboard_graph = init_tensorboard_graph
        self._disc_opt = self._disc_opt_cls(
            self._reward_net.parameters(),
            **self._disc_opt_kwargs,
        )

        if self._init_tensorboard:
            logging.info(f"building summary directory at {self._log_dir}")
            summary_dir = self._log_dir / "summary"
            summary_dir.mkdir(parents=True, exist_ok=True)
            self._summary_writer = thboard.SummaryWriter(str(summary_dir))

        self.venv_buffering = wrappers.BufferingWrapper(self.venv)

        if debug_use_ground_truth:
            # Would use an identity reward fn here, but RewardFns can't see rewards.
            self.venv_wrapped = self.venv_buffering
            self.gen_callback = None
        else:
            self.venv_wrapped = reward_wrapper.RewardVecEnvWrapper(
                self.venv_buffering,
                reward_fn=self.reward_train.predict_processed,
            )
            self.gen_callback = self.venv_wrapped.make_log_callback()
        self.venv_train = self.venv_wrapped

        self.gen_algo.set_env(self.venv_train)
        self.gen_algo.set_logger(self.logger)

        if gen_train_timesteps is None:
            gen_algo_env = self.gen_algo.get_env()
            assert gen_algo_env is not None
            self.gen_train_timesteps = gen_algo_env.num_envs
            if isinstance(self.gen_algo, on_policy_algorithm.OnPolicyAlgorithm):
                self.gen_train_timesteps *= self.gen_algo.n_steps
        else:
            self.gen_train_timesteps = gen_train_timesteps

        if gen_replay_buffer_capacity is None:
            gen_replay_buffer_capacity = self.gen_train_timesteps
        self._gen_replay_buffer = buffer.ReplayBuffer(
            gen_replay_buffer_capacity,
            self.venv,
        )

        # MODIFIED: Model saving parameters. `ts_now` prefixes every checkpoint
        # name; `highest_reward` tracks the best evaluation score seen so far.
        self.ts_now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.save_path = os.path.join(os.getcwd(), "data", "models")
        self.highest_reward = -np.inf

    @property
    def policy(self) -> policies.BasePolicy:
        """The generator algorithm's policy (asserted to be set)."""
        policy = self.gen_algo.policy
        assert policy is not None
        return policy

    @abc.abstractmethod
    def logits_expert_is_high(
        self,
        state: th.Tensor,
        action: th.Tensor,
        next_state: th.Tensor,
        done: th.Tensor,
        log_policy_act_prob: Optional[th.Tensor] = None,
    ) -> th.Tensor:
        """Compute the discriminator's logits for each state-action sample.

        A high value corresponds to predicting expert, and a low value corresponds to
        predicting generator.

        Args:
            state: state at time t, of shape `(batch_size,) + state_shape`.
            action: action taken at time t, of shape `(batch_size,) + action_shape`.
            next_state: state at time t+1, of shape `(batch_size,) + state_shape`.
            done: binary episode completion flag after action at time t,
                of shape `(batch_size,)`.
            log_policy_act_prob: log probability of generator policy taking
                `action` at time t.

        Returns:
            Discriminator logits of shape `(batch_size,)`. A high output indicates an
            expert-like transition.
        """  # noqa: DAR202

    @property
    @abc.abstractmethod
    def reward_train(self) -> reward_nets.RewardNet:
        """Reward used to train generator policy."""

    @property
    @abc.abstractmethod
    def reward_test(self) -> reward_nets.RewardNet:
        """Reward used to train policy at "test" time after adversarial training."""

    def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None:
        """Build the expert data loader and an endless iterator over its batches."""
        self._demo_data_loader = base.make_data_loader(
            demonstrations,
            self.demo_batch_size,
        )
        self._endless_expert_iterator = util.endless_iter(self._demo_data_loader)

    def _next_expert_batch(self) -> Mapping:
        """Return the next batch of expert transitions."""
        assert self._endless_expert_iterator is not None
        return next(self._endless_expert_iterator)

    def train_disc(
        self,
        *,
        expert_samples: Optional[Mapping] = None,
        gen_samples: Optional[Mapping] = None,
    ) -> Mapping[str, float]:
        """Perform one discriminator training pass, optionally using provided samples.

        The `self.demo_batch_size` expert and generator samples are split into
        minibatches of `self.demo_minibatch_size`, and the optimizer takes a step
        after every minibatch (MODIFIED from upstream, which accumulated the
        gradients of all minibatches into a single step).

        Args:
            expert_samples: Transition samples from the expert in dictionary form.
                If provided, must contain keys corresponding to every field of the
                `Transitions` dataclass except "infos". All corresponding values can be
                either NumPy arrays or Tensors. Extra keys are ignored. Must contain
                `self.demo_batch_size` samples. If this argument is not provided, then
                `self.demo_batch_size` expert samples from `self.demo_data_loader` are
                used by default.
            gen_samples: Transition samples from the generator policy in same dictionary
                form as `expert_samples`. If provided, must contain exactly
                `self.demo_batch_size` samples. If not provided, then take
                `len(expert_samples)` samples from the generator replay buffer.

        Returns:
            Statistics for discriminator (e.g. loss, accuracy), computed on the
            last minibatch of the pass.
        """
        with self.logger.accumulate_means("disc"):
            # optionally write TB summaries for collected ops
            write_summaries = self._init_tensorboard and self._global_step % 20 == 0

            # compute loss
            self._disc_opt.zero_grad()

            batch_iter = self._make_disc_train_batches(
                gen_samples=gen_samples,
                expert_samples=expert_samples,
            )
            for batch in batch_iter:
                disc_logits = self.logits_expert_is_high(
                    batch["state"],
                    batch["action"],
                    batch["next_state"],
                    batch["done"],
                    batch["log_policy_act_prob"],
                )
                loss = F.binary_cross_entropy_with_logits(
                    disc_logits,
                    batch["labels_expert_is_one"].float(),
                )

                # Scale the minibatch loss by its share of the whole batch. This
                # renormalisation was designed for gradient accumulation; with a
                # step per minibatch it only rescales each step's gradient.
                # NOTE(review): nearly a no-op for Adam, but lowers the effective
                # learning rate for SGD-like optimizers — confirm it is intended.
                assert len(batch["state"]) == 2 * self.demo_minibatch_size
                loss *= self.demo_minibatch_size / self.demo_batch_size
                loss.backward()

                # MODIFIED: update the discriminator after every minibatch.
                self._disc_opt.step()
                self._disc_opt.zero_grad()
            self._disc_step += 1

            # compute/write stats and TensorBoard data (last minibatch only)
            with th.no_grad():
                train_stats = compute_train_stats(
                    disc_logits,
                    batch["labels_expert_is_one"],
                    loss,
                )
            self.logger.record("global_step", self._global_step)
            for k, v in train_stats.items():
                self.logger.record(k, v)
            self.logger.dump(self._disc_step)
            if write_summaries:
                # Pass the step explicitly so histograms from different rounds
                # do not all land on the same TensorBoard step.
                self._summary_writer.add_histogram(
                    "disc_logits",
                    disc_logits.detach(),
                    global_step=self._global_step,
                )

        return train_stats

    def train_gen(
        self,
        total_timesteps: Optional[int] = None,
        learn_kwargs: Optional[Mapping] = None,
    ) -> None:
        """Trains the generator to maximize the discriminator loss.

        After training, the trajectories recorded by `self.venv_buffering` are
        flattened and stored in the generator replay buffer (used in
        discriminator training).

        Args:
            total_timesteps: The number of transitions to sample from
                `self.venv_train` during training. By default,
                `self.gen_train_timesteps`.
            learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
                method.
        """
        if total_timesteps is None:
            total_timesteps = self.gen_train_timesteps
        if learn_kwargs is None:
            learn_kwargs = {}

        with self.logger.accumulate_means("gen"):
            self.gen_algo.learn(
                total_timesteps=total_timesteps,
                reset_num_timesteps=False,
                callback=self.gen_callback,
                **learn_kwargs,
            )
            self._global_step += 1

        gen_trajs, ep_lens = self.venv_buffering.pop_trajectories()
        self._check_fixed_horizon(ep_lens)
        gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs)
        self._gen_replay_buffer.store(gen_samples)

    def train(
        self,
        total_timesteps: int,
        callback: Optional[Callable[[int], None]] = None,
    ) -> None:
        """Alternates between training the generator and discriminator.

        Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
        `self.n_disc_updates_per_round` calls to `train_disc`, a call to
        `callback(round)`, and finally (MODIFIED) an evaluation of the generator
        policy that may save checkpoints (see `_evaluate_and_save`).
        Training ends once an additional "round" would cause the number of transitions
        sampled from the environment to exceed `total_timesteps`.

        Args:
            total_timesteps: An upper bound on the number of transitions to sample
                from the environment during training.
            callback: A function called at the end of every round which takes in a
                single argument, the round number. Round numbers are in
                `range(total_timesteps // self.gen_train_timesteps)`.
        """
        n_rounds = total_timesteps // self.gen_train_timesteps
        assert n_rounds >= 1, (
            "No updates (need at least "
            f"{self.gen_train_timesteps} timesteps, have only "
            f"total_timesteps={total_timesteps})!"
        )
        for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
            # MODIFIED: reward scaling is on only while the generator trains;
            # `finally` guarantees it is switched off again.
            self._reward_net.scale = True
            try:
                self.train_gen(self.gen_train_timesteps)
            finally:
                self._reward_net.scale = False
            for _ in range(self.n_disc_updates_per_round):
                with networks.training(self.reward_train):
                    # switch to training mode (affects dropout, normalization)
                    self.train_disc()
            if callback:
                callback(r)
            self.logger.dump(self._global_step)
            self._evaluate_and_save(r)

    def _evaluate_and_save(
        self,
        round_num: int,
        n_episodes: int = 3,
        save_every: int = 10,
    ) -> None:
        """Evaluate the generator policy and save checkpoints (MODIFIED addition).

        Saves a checkpoint named `<ts_now>_<round_num>` every `save_every` rounds,
        and a `<ts_now>_best` checkpoint whenever the penalised evaluation score
        reaches a new high.

        Args:
            round_num: Index of the training round that just finished.
            n_episodes: Number of evaluation episodes.
            save_every: Interval, in rounds, between periodic checkpoints.
        """
        # NOTE(review): this evaluates on `self.venv`, the same environment the
        # generator trains on underneath the buffering/reward wrappers. Resetting
        # and stepping it here may leave `gen_algo`'s last observation out of sync
        # with the env for the next `train_gen` — consider a separate eval env.
        episode_rewards, _ = evaluate_policy(
            self.gen_algo,
            self.venv,
            n_episodes,
            return_episode_rewards=True,
        )
        reward_mean = np.mean(episode_rewards)
        reward_std = np.std(episode_rewards)

        # Read the logger stats every round so the "best" checkpoint below stores
        # current values (previously only refreshed on periodic-save rounds).
        stats = self.logger._logger.stats

        if round_num % save_every == 0:
            save_model(
                self.gen_algo,
                self._reward_net,
                stats,
                self.save_path,
                f"{self.ts_now}_{round_num}",
            )
        print(f"Reward mean: {reward_mean}")

        # Pessimistic score: mean reward penalised by its spread.
        # NOTE(review): a conventional lower confidence bound would use
        # z * std / sqrt(n) (z ~= 1.96 for 95%); kept as-is to preserve behavior.
        score = reward_mean - (0.95 * reward_std / n_episodes)
        if score >= self.highest_reward:
            self.highest_reward = score
            print(f" New highest reward: {reward_mean}. Saving the model.")
            save_model(
                self.gen_algo,
                self._reward_net,
                stats,
                self.save_path,
                f"{self.ts_now}_best",
            )

    @overload
    def _torchify_array(self, ndarray: np.ndarray) -> th.Tensor: ...

    @overload
    def _torchify_array(self, ndarray: None) -> None: ...

    def _torchify_array(self, ndarray: Optional[np.ndarray]) -> Optional[th.Tensor]:
        """Move `ndarray` onto the reward network's device; pass `None` through."""
        if ndarray is not None:
            return th.as_tensor(ndarray, device=self.reward_train.device)
        return None

    def _get_log_policy_act_prob(
        self,
        obs_th: th.Tensor,
        acts_th: th.Tensor,
    ) -> Optional[th.Tensor]:
        """Evaluates the given actions on the given observations.

        Args:
            obs_th: A batch of observations.
            acts_th: A batch of actions.

        Returns:
            A batch of log policy action probabilities, or None when the policy
            is neither an `ActorCriticPolicy` nor a `SACPolicy`.
        """
        if isinstance(self.policy, policies.ActorCriticPolicy):
            # policies.ActorCriticPolicy has a concrete implementation of
            # evaluate_actions to generate log_policy_act_prob given obs and actions.
            _, log_policy_act_prob_th, _ = self.policy.evaluate_actions(
                obs_th,
                acts_th,
            )
        elif isinstance(self.policy, sac_policies.SACPolicy):
            gen_algo_actor = self.policy.actor
            assert gen_algo_actor is not None
            # generate log_policy_act_prob from SAC actor.
            mean_actions, log_std, _ = gen_algo_actor.get_action_dist_params(obs_th)
            assert isinstance(
                gen_algo_actor.action_dist,
                distributions.SquashedDiagGaussianDistribution,
            )  # Note: this is just a hint to mypy
            distribution = gen_algo_actor.action_dist.proba_distribution(
                mean_actions,
                log_std,
            )
            # SAC applies a squashing function to bound the actions to a finite range
            # `acts_th` need to be scaled accordingly before computing log prob.
            # Scale actions only if the policy squashes outputs.
            assert self.policy.squash_output
            scaled_acts = self.policy.scale_action(acts_th.numpy(force=True))
            scaled_acts_th = th.as_tensor(scaled_acts, device=mean_actions.device)
            log_policy_act_prob_th = distribution.log_prob(scaled_acts_th)
        else:
            return None
        return log_policy_act_prob_th

    def _balance_expert_samples(self, expert_samples: dict) -> None:
        """Resample expert transitions in place so two actions are equally frequent.

        MODIFIED addition. When the batch contains more than one distinct action
        and both actions in `_BALANCED_EXPERT_ACTIONS`, the sample indices of the
        first action are cycled to fill `n // 2` slots and those of the second
        action to fill the remaining `n - n // 2` slots. Samples with any other
        action are dropped. The combined indices are shuffled with the module-level
        `RNG`, so the batch size `n` is preserved (including odd `n`). Batches
        missing either action are left unchanged.

        Args:
            expert_samples: Mutable mapping of expert transition fields; "obs",
                "acts", "next_obs" and "dones" must be indexable NumPy arrays.
                An optional "infos" sequence is reordered to stay aligned.
        """
        acts = expert_samples["acts"]
        if len(np.unique(acts)) <= 1:
            return
        first_action, second_action = self._BALANCED_EXPERT_ACTIONS
        first_ids = np.where(acts == first_action)[0]
        second_ids = np.where(acts == second_action)[0]
        if len(first_ids) == 0 or len(second_ids) == 0:
            # Nothing to balance against (this used to divide by zero).
            return

        n_total = len(acts)
        n_first = n_total // 2
        # `np.resize` repeats the indices cyclically up to the requested length.
        joint_ids = RNG.permutation(
            np.concatenate(
                [
                    np.resize(first_ids, n_first),
                    np.resize(second_ids, n_total - n_first),
                ]
            )
        )
        for key in ("acts", "obs", "next_obs", "dones"):
            expert_samples[key] = expert_samples[key][joint_ids]
        infos = expert_samples.get("infos")
        if infos is not None:
            expert_samples["infos"] = [infos[i] for i in joint_ids]

    def _make_disc_train_batches(
        self,
        *,
        gen_samples: Optional[Mapping] = None,
        expert_samples: Optional[Mapping] = None,
    ) -> Iterator[Mapping[str, th.Tensor]]:
        """Build and return training minibatches for the next discriminator update.

        Args:
            gen_samples: Same as in `train_disc`.
            expert_samples: Same as in `train_disc`.

        Yields:
            The training minibatch: state, action, next state, dones, labels
            and policy log-probabilities.

        Raises:
            RuntimeError: Empty generator replay buffer.
            ValueError: `gen_samples` or `expert_samples` batch size is
                different from `self.demo_batch_size`.
        """
        batch_size = self.demo_batch_size

        if expert_samples is None:
            expert_samples = self._next_expert_batch()

        if gen_samples is None:
            if self._gen_replay_buffer.size() == 0:
                raise RuntimeError(
                    "No generator samples for training. " "Call `train_gen()` first.",
                )
            gen_samples_dataclass = self._gen_replay_buffer.sample(batch_size)
            gen_samples = types.dataclass_quick_asdict(gen_samples_dataclass)

        if not (len(gen_samples["obs"]) == len(expert_samples["obs"]) == batch_size):
            raise ValueError(
                "Need to have exactly `demo_batch_size` number of expert and "
                "generator samples, each. "
                f"(n_gen={len(gen_samples['obs'])} "
                f"n_expert={len(expert_samples['obs'])} "
                f"demo_batch_size={batch_size})",
            )

        # Guarantee that Mapping arguments are in mutable form.
        expert_samples = dict(expert_samples)
        gen_samples = dict(gen_samples)

        # Convert applicable Tensor values to NumPy.
        for field in dataclasses.fields(types.Transitions):
            k = field.name
            if k == "infos":
                continue
            for d in [gen_samples, expert_samples]:
                if isinstance(d[k], th.Tensor):
                    d[k] = d[k].detach().cpu().numpy()
        assert isinstance(gen_samples["obs"], np.ndarray)
        assert isinstance(expert_samples["obs"], np.ndarray)

        # MODIFIED: balance the expert actions before building minibatches.
        self._balance_expert_samples(expert_samples)

        # Check dimensions.
        assert batch_size == len(expert_samples["acts"])
        assert batch_size == len(expert_samples["next_obs"])
        assert batch_size == len(gen_samples["acts"])
        assert batch_size == len(gen_samples["next_obs"])

        for start in range(0, batch_size, self.demo_minibatch_size):
            end = start + self.demo_minibatch_size
            # take minibatch slice (this creates views so no memory issues)
            expert_batch = {k: v[start:end] for k, v in expert_samples.items()}
            gen_batch = {k: v[start:end] for k, v in gen_samples.items()}

            # Concatenate rollouts, and label each row as expert or generator.
            obs = np.concatenate([expert_batch["obs"], gen_batch["obs"]])
            acts = np.concatenate([expert_batch["acts"], gen_batch["acts"]])
            next_obs = np.concatenate(
                [expert_batch["next_obs"], gen_batch["next_obs"]]
            )
            dones = np.concatenate([expert_batch["dones"], gen_batch["dones"]])
            # notice that the labels use the convention that expert samples are
            # labelled with 1 and generator samples with 0.
            labels_expert_is_one = np.concatenate(
                [
                    np.ones(self.demo_minibatch_size, dtype=int),
                    np.zeros(self.demo_minibatch_size, dtype=int),
                ],
            )

            # Calculate generator-policy log probabilities.
            with th.no_grad():
                obs_th = th.as_tensor(obs, device=self.gen_algo.device)
                acts_th = th.as_tensor(acts, device=self.gen_algo.device)
                log_policy_act_prob = self._get_log_policy_act_prob(obs_th, acts_th)
                if log_policy_act_prob is not None:
                    assert len(log_policy_act_prob) == 2 * self.demo_minibatch_size
                    log_policy_act_prob = log_policy_act_prob.reshape(
                        (2 * self.demo_minibatch_size,),
                    )
                del obs_th, acts_th  # unneeded

            obs_th, acts_th, next_obs_th, dones_th = self.reward_train.preprocess(
                obs,
                acts,
                next_obs,
                dones,
            )
            batch_dict = {
                "state": obs_th,
                "action": acts_th,
                "next_state": next_obs_th,
                "done": dones_th,
                "labels_expert_is_one": self._torchify_array(labels_expert_is_one),
                "log_policy_act_prob": log_policy_act_prob,
            }

            yield batch_dict