Commit 2c616dc

[E] RL agent refactoring I (#410)
* moved experience_buffer to q_learning
* renamed callback and moved to superfolder
* refactored training to work with callback
* deleted interim networks and deleted mocking
* adjusted tqdm output
* updated docstrings and renamed Callback to RecommerceCallback
* renamed mean_reward to mean_return according to rl terminology
* fixed assert
* changed naming
* included small comments
1 parent 091eee6 commit 2c616dc

12 files changed (+100, -201 lines)

recommerce/configuration/hyperparameter_config.py

Lines changed: 1 addition & 1 deletion
@@ -275,7 +275,7 @@ def _set_sim_market_variables(self, config: dict) -> None:
         self.production_price = config['production_price']
         self.storage_cost_per_product = config['storage_cost_per_product']
 
-        self.mean_reward_bound = self.episode_length * self.max_price * self.number_of_customers
+        self.mean_return_bound = self.episode_length * self.max_price * self.number_of_customers
 
 
 class HyperparameterConfigLoader():
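The renamed `mean_return_bound` is simply the product of three values that already exist in the config, so only the name changes, not the arithmetic. A minimal sketch with invented numbers (not the project's defaults):

# Invented config values, only to show how the bound is computed.
episode_length = 50
max_price = 10
number_of_customers = 20

mean_return_bound = episode_length * max_price * number_of_customers
print(mean_return_bound)  # 10000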

recommerce/rl/actorcritic/actorcritic_agent.py

Lines changed: 5 additions & 20 deletions
@@ -1,4 +1,3 @@
-import os
 from abc import ABC, abstractmethod
 
 import numpy as np
@@ -71,26 +70,16 @@ def policy(self, observation, verbose=False, raw_action=False) -> None: # pragm
         """
         raise NotImplementedError('This method is abstract. Use a subclass')
 
-    def sync_to_best_interim(self):
-        self.best_interim_actor_net.load_state_dict(self.actor_net.state_dict())
-        self.best_interim_critic_net.load_state_dict(self.critic_net.state_dict())
-
-    def save(self, model_path, model_name) -> None:
+    def save(self, model_path: str) -> None:
         """
         Save a trained model to the specified folder within 'trainedModels'.
-        For each model an actor and a critic net will be saved.
-        This method is copied from our Q-Learning Agent
+        For each model only the actor net will be saved.
 
         Args:
-            model_path (str): The path to the folder within 'trainedModels' where the model should be saved.
-            model_name (str): The name of the .dat file of this specific model.
+            model_path (str): The path including the name where the model should be saved.
         """
-        assert model_name.endswith('.dat'), f'the modelname must end in ".dat": {model_name}'
-        assert os.path.exists(model_path), f'the specified path does not exist: {model_path}'
-        actor_path = os.path.join(model_path, f'actor_parameters{model_name}')
-        torch.save(self.best_interim_actor_net.state_dict(), actor_path)
-        torch.save(self.best_interim_critic_net.state_dict(), os.path.join(model_path, 'critic_parameters' + model_name))
-        return actor_path
+        assert model_path.endswith('.dat'), f'the modelname must end in ".dat": {model_path}'
+        torch.save(self.actor_net.state_dict(), model_path)
 
     def train_batch(self, states, actions, rewards, next_states, regularization=False):
         """
@@ -176,11 +165,9 @@ class DiscreteActorCriticAgent(ActorCriticAgent):
     def initialize_models_and_optimizer(self, n_observations, n_actions, network_architecture):
         self.actor_net = network_architecture(n_observations, n_actions).to(self.device)
         self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(), lr=0.0000025)
-        self.best_interim_actor_net = network_architecture(n_observations, n_actions).to(self.device)
         self.critic_net = network_architecture(n_observations, 1).to(self.device)
         self.critic_optimizer = torch.optim.Adam(self.critic_net.parameters(), lr=0.00025)
         self.critic_tgt_net = network_architecture(n_observations, 1).to(self.device)
-        self.best_interim_critic_net = self.critic_tgt_net = network_architecture(n_observations, 1).to(self.device)
 
     def policy(self, observation, verbose=False, raw_action=False):
         observation = torch.Tensor(np.array(observation)).to(self.device)
@@ -235,11 +222,9 @@ def initialize_models_and_optimizer(self, n_observations, n_actions, network_arc
         self.n_actions = n_actions
         self.actor_net = network_architecture(n_observations, self.n_actions).to(self.device)
         self.actor_optimizer = torch.optim.Adam(self.actor_net.parameters(), lr=0.0002)
-        self.best_interim_actor_net = network_architecture(n_observations, self.n_actions).to(self.device)
         self.critic_net = network_architecture(n_observations, 1).to(self.device)
         self.critic_optimizer = torch.optim.Adam(self.critic_net.parameters(), lr=0.002)
         self.critic_tgt_net = network_architecture(n_observations, 1).to(self.device)
-        self.best_interim_critic_net = network_architecture(n_observations, 1).to(self.device)
 
     @abstractmethod
     def transform_network_output(self, number_outputs, network_result):
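The save() change above turns a folder-plus-filename API into a single full path ending in '.dat', and only the live actor net is written now that the interim copies are gone. A minimal stand-alone sketch of that contract, using a plain torch module in place of the repo's actor net and a hypothetical file name:

import torch

net = torch.nn.Linear(4, 2)           # stand-in for the actor net
model_path = 'actor_parameters.dat'   # hypothetical path; only the '.dat' suffix is enforced

assert model_path.endswith('.dat'), f'the modelname must end in ".dat": {model_path}'
torch.save(net.state_dict(), model_path)   # mirrors torch.save(self.actor_net.state_dict(), model_path)

# Reloading works the usual state_dict way.
restored = torch.nn.Linear(4, 2)
restored.load_state_dict(torch.load(model_path))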

recommerce/rl/actorcritic/actorcritic_training.py

Lines changed: 15 additions & 14 deletions
@@ -2,7 +2,6 @@
 
 import numpy as np
 import torch
-from tqdm.auto import trange
 
 import recommerce.configuration.utils as ut
 import recommerce.rl.actorcritic.actorcritic_agent as actorcritic_agent
@@ -12,7 +11,7 @@
 
 class ActorCriticTrainer(RLTrainer):
     def trainer_agent_fit(self) -> bool:
-        return isinstance(self.RL_agent, actorcritic_agent.ActorCriticAgent)
+        return issubclass(self.agent_class, actorcritic_agent.ActorCriticAgent)
 
     def choose_random_envs(self, total_envs) -> set:
         """
@@ -41,6 +40,7 @@ def train_agent(self, number_of_training_steps=200, verbose=False, total_envs=12
             verbose (bool, optional): Should additional information about agent steps be written to the tensorboard? Defaults to False.
             total_envs (int, optional): The number of environments you use in parallel to fulfill the iid assumption. Defaults to 128.
         """
+        self.initialize_callback(number_of_training_steps * config.batch_size)
 
         all_dicts = []
         if verbose:
@@ -50,25 +50,28 @@ def train_agent(self, number_of_training_steps=200, verbose=False, total_envs=12
             all_policy_losses = []
 
         finished_episodes = 0
+        mean_return = -np.inf
+        self.callback.num_timesteps = 0
         environments = [self.marketplace_class() for _ in range(total_envs)]
         info_accumulators = [None for _ in range(total_envs)]
 
-        for step_number in trange(number_of_training_steps, unit=' frames', leave=False):
+        for step_number in range(number_of_training_steps):
            chosen_envs = self.choose_random_envs(total_envs)
 
            states = []
            actions = []
            rewards = []
            states_dash = []
            for env in chosen_envs:
+                self.callback.num_timesteps += 1
                state = environments[env]._observation()
                if not verbose:
-                    action = self.RL_agent.policy(state, verbose=False, raw_action=True)
+                    action = self.callback.model.policy(state, verbose=False, raw_action=True)
                else:
-                    action, net_output, v_estimate = self.RL_agent.policy(state, verbose=True, raw_action=True)
+                    action, net_output, v_estimate = self.callback.model.policy(state, verbose=True, raw_action=True)
                    all_network_outputs.append(net_output.reshape(-1))
                    all_v_estimates.append(v_estimate)
-                next_state, reward, is_done, info = environments[env].step(self.RL_agent.agent_output_to_market_form(action))
+                next_state, reward, is_done, info = environments[env].step(self.callback.model.agent_output_to_market_form(action))
 
                states.append(state)
                actions.append(action)
@@ -92,16 +95,16 @@ def train_agent(self, number_of_training_steps=200, verbose=False, total_envs=12
                        averaged_info[f'verbose/min/information_{str(action_num)}'] = np.min(myactions[:, action_num])
                        averaged_info[f'verbose/max/information_{str(action_num)}'] = np.max(myactions[:, action_num])
 
-                    ut.write_dict_to_tensorboard(self.writer, averaged_info, finished_episodes, is_cumulative=True)
+                    ut.write_dict_to_tensorboard(self.callback.writer, averaged_info, finished_episodes, is_cumulative=True)
 
                    environments[env].reset()
                    info_accumulators[env] = None
 
-                    self.consider_print_info(step_number, finished_episodes, averaged_info)
-                    self.consider_update_best_model(averaged_info)
-                    self.consider_save_model(finished_episodes)
+                    mean_return = averaged_info['profits/all']['vendor_0']
 
-            policy_loss, valueloss = self.RL_agent.train_batch(
+            self.callback._on_step(finished_episodes, mean_return)
+
+            policy_loss, valueloss = self.callback.model.train_batch(
                torch.Tensor(np.array(states)),
                torch.from_numpy(np.array(actions, dtype=np.int64)),
                torch.Tensor(np.array(rewards)),
@@ -112,6 +115,4 @@ def train_agent(self, number_of_training_steps=200, verbose=False, total_envs=12
 
            self.consider_sync_tgt_net(step_number)
 
-        self.consider_save_model(finished_episodes, force=True)
-        self.analyze_trained_agents()
-        self._end_of_training()
+        self.callback._on_training_end()
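With this refactoring, the actor-critic training loop delegates progress tracking, printing, checkpointing, and post-training analysis to the callback instead of the trainer's own consider_* helpers. A condensed, hypothetical sketch of that control flow follows; DummyCallback only mirrors the calls made above and is not the real RecommerceCallback:

class DummyCallback:
    def __init__(self):
        self.num_timesteps = 0

    def _on_step(self, finished_episodes, mean_return):
        # The real callback prints progress, snapshots the best model and returns True on success.
        print(f'{self.num_timesteps} steps: {finished_episodes} episodes, mean return {mean_return:.1f}')
        return True

    def _on_training_end(self):
        # The real callback saves the final parameters and starts the monitoring run.
        print('training finished')


callback = DummyCallback()
finished_episodes, mean_return = 0, float('-inf')
for step_number in range(3):          # stands in for number_of_training_steps
    callback.num_timesteps += 1       # incremented once per environment step in the real loop
    finished_episodes += 1            # the real loop advances this whenever an episode ends
    mean_return = 42.0                # the real loop reads averaged_info['profits/all']['vendor_0']
    callback._on_step(finished_episodes, mean_return)
callback._on_training_end()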

recommerce/rl/stable_baselines/stable_baselines_callback.py renamed to recommerce/rl/callback.py

Lines changed: 41 additions & 20 deletions
@@ -22,26 +22,30 @@
 warnings.filterwarnings('ignore')
 
 
-class PerStepCheck(BaseCallback):
+class RecommerceCallback(BaseCallback):
     """
     Callback for saving a model (the check is done every `check_freq` steps)
     based on the training reward (in practice, we recommend using `EvalCallback`).
     """
-    def __init__(self, agent_class, marketplace_class, log_dir_prepend='', training_steps=10000, iteration_length=500):
+    def __init__(self, agent_class, marketplace_class, log_dir_prepend='', training_steps=10000,
+            iteration_length=500, file_ending='zip', signature='training'):
        assert issubclass(agent_class, ReinforcementLearningAgent)
        assert issubclass(marketplace_class, SimMarket)
        assert isinstance(log_dir_prepend, str), \
            f'log_dir_prepend should be a string, but {log_dir_prepend} is {type(log_dir_prepend)}'
        assert isinstance(training_steps, int) and training_steps > 0
        assert isinstance(iteration_length, int) and iteration_length > 0
-        super(PerStepCheck, self).__init__(True)
+        super(RecommerceCallback, self).__init__(True)
        self.best_mean_interim_reward = None
        self.best_mean_overall_reward = None
        self.marketplace_class = marketplace_class
        self.agent_class = agent_class
        self.iteration_length = iteration_length
+        self.file_ending = file_ending
+        self.signature = signature
        self.tqdm_instance = trange(training_steps)
        self.saved_parameter_paths = []
+        self.last_finished_episode = 0
        signal.signal(signal.SIGINT, self._signal_handler)
 
        self.initialize_io_related(log_dir_prepend)
@@ -63,34 +67,51 @@ def initialize_io_related(self, log_dir_prepend) -> None:
        """
        ut.ensure_results_folders_exist()
        self.curr_time = time.strftime('%b%d_%H-%M-%S')
-        self.signature = 'Stable_Baselines_Training'
        self.writer = SummaryWriter(log_dir=os.path.join(PathManager.results_path, 'runs', f'{log_dir_prepend}training_{self.curr_time}'))
        path_name = f'{self.signature}_{self.curr_time}'
        self.save_path = os.path.join(PathManager.results_path, 'trainedModels', log_dir_prepend + path_name)
        os.makedirs(os.path.abspath(self.save_path), exist_ok=True)
-        self.tmp_parameters = os.path.join(self.save_path, 'tmp_model.zip')
+        self.tmp_parameters = os.path.join(self.save_path, f'tmp_model.{self.file_ending}')
 
-    def _on_step(self) -> bool:
+    def _on_step(self, finished_episodes: int = None, mean_return: float = None) -> bool:
        """
-        This method is called at every step by the stable baselines agents.
+        This method is called during training, after a step in the environment has been taken.
+        If you don't provide finished_episodes and mean_return, the callback derives them from the number of timesteps.
+        Note that you must provide finished_episodes if and only if you provide mean_return.
+
+        Args:
+            finished_episodes (int, optional): The episodes that are already finished. Defaults to None.
+            mean_return (float, optional): The mean return received over the last episodes. Defaults to None.
+
+        Returns:
+            bool: True should be returned. False will be interpreted as an error.
        """
+        assert (finished_episodes is None) == (mean_return is None), 'finished_episodes must be exactly None if mean_return is None'
        self.tqdm_instance.update()
-        if (self.num_timesteps - 1) % config.episode_length != 0 or self.num_timesteps <= config.episode_length:
+        if finished_episodes is None:
+            finished_episodes = self.num_timesteps // config.episode_length
+            x, y = ts2xy(load_results(self.save_path), 'timesteps')
+            if len(x) <= 0:
+                return True
+            assert len(x) == len(y)
+            mean_return = np.mean(y[-100:])
+        assert isinstance(finished_episodes, int)
+        assert isinstance(mean_return, float)
+
+        assert finished_episodes >= self.last_finished_episode
+        if finished_episodes == self.last_finished_episode or finished_episodes < 5:
            return True
-        self.tqdm_instance.refresh()
-        finished_episodes = self.num_timesteps // config.episode_length
-        x, y = ts2xy(load_results(self.save_path), 'timesteps')
-        assert len(x) > 0 and len(x) == len(y)
-        mean_reward = np.mean(y[-100:])
+        else:
+            self.last_finished_episode = finished_episodes
 
        # consider print info
        if (finished_episodes) % 10 == 0:
-            tqdm.write(f'{self.num_timesteps}: {finished_episodes} episodes trained, mean return {mean_reward:.3f}')
+            tqdm.write(f'{self.num_timesteps}: {finished_episodes} episodes trained, mean return {mean_return:.3f}')
 
        # consider update best model
-        if self.best_mean_interim_reward is None or mean_reward > self.best_mean_interim_reward + 15:
+        if self.best_mean_interim_reward is None or mean_return > self.best_mean_interim_reward + 15:
            self.model.save(self.tmp_parameters)
-            self.best_mean_interim_reward = mean_reward
+            self.best_mean_interim_reward = mean_return
        if self.best_mean_overall_reward is None or self.best_mean_interim_reward > self.best_mean_overall_reward:
            if self.best_mean_overall_reward is not None:
                tqdm.write(f'Best overall reward updated {self.best_mean_overall_reward:.3f} -> {self.best_mean_interim_reward:.3f}')
@@ -105,23 +126,23 @@ def _on_step(self) -> bool:
    def _on_training_end(self) -> None:
        self.tqdm_instance.close()
        if self.best_mean_interim_reward is not None:
-            finished_episodes = self.num_timesteps // config.episode_length
-            self.save_parameters(finished_episodes)
+            self.save_parameters(self.last_finished_episode)
 
        # analyze trained agents
        if len(self.saved_parameter_paths) == 0:
            print('No agents saved! Nothing to monitor.')
            return
        monitor = Monitor()
        agent_list = [(self.agent_class, [parameter_path]) for parameter_path in self.saved_parameter_paths]
-        monitor.configurator.setup_monitoring(False, 250, 250, self.marketplace_class, agent_list, support_continuous_action_space=True)
+        monitor.configurator.setup_monitoring(False, 250, 250, self.marketplace_class, agent_list,
+            support_continuous_action_space=hasattr(self.model, 'env'))
        rewards = monitor.run_marketplace()
        episode_numbers = [int(parameter_path[-9:][:5]) for parameter_path in self.saved_parameter_paths]
        Evaluator(monitor.configurator).evaluate_session(rewards, episode_numbers)
 
    def save_parameters(self, finished_episodes: int):
        assert isinstance(finished_episodes, int)
-        path_to_parameters = os.path.join(self.save_path, f'{self.signature}_{finished_episodes:05d}.zip')
+        path_to_parameters = os.path.join(self.save_path, f'{self.signature}_{finished_episodes:05d}.{self.file_ending}')
        os.rename(self.tmp_parameters, path_to_parameters)
        self.saved_parameter_paths.append(path_to_parameters)
        tqdm.write(f'I write the interim model after {finished_episodes} episodes to the disk.')
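After the rename, RecommerceCallback._on_step works in two modes: the stable-baselines path calls it without arguments and the episode count and mean return are derived from num_timesteps and the monitor logs, while the hand-written trainers pass both values explicitly. A small illustrative function (not the repo's class) showing only that argument-resolution rule:

def resolve_progress(finished_episodes=None, mean_return=None, num_timesteps=0, episode_length=50):
    # Either both values are supplied (custom trainers) or neither (stable-baselines path).
    assert (finished_episodes is None) == (mean_return is None)
    if finished_episodes is None:
        finished_episodes = num_timesteps // episode_length
        mean_return = 0.0  # the real code averages the last 100 monitored returns via ts2xy(load_results(...))
    return finished_episodes, mean_return

print(resolve_progress(num_timesteps=500))                       # derived: (10, 0.0)
print(resolve_progress(finished_episodes=7, mean_return=123.4))  # passed through: (7, 123.4)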

recommerce/rl/q_learning/q_learning_agent.py

Lines changed: 5 additions & 14 deletions
@@ -1,5 +1,4 @@
 import collections
-import os
 import random
 from abc import ABC, abstractmethod
 
@@ -11,7 +10,7 @@
 from recommerce.market.circular.circular_vendors import CircularAgent
 from recommerce.market.linear.linear_vendors import LinearAgent
 from recommerce.market.sim_market import SimMarket
-from recommerce.rl.experience_buffer import ExperienceBuffer
+from recommerce.rl.q_learning.experience_buffer import ExperienceBuffer
 from recommerce.rl.reinforcement_learning_agent import ReinforcementLearningAgent
 
 
@@ -48,7 +47,6 @@ def __init__(
        self.name = name
        print(f'I initiate a QLearningAgent using {self.device} device')
        self.net = network_architecture(n_observations, n_actions).to(self.device)
-        self.best_interim_net = network_architecture(n_observations, n_actions)
        if load_path:
            self.net.load_state_dict(torch.load(load_path, map_location=self.device))
        if optim:
@@ -120,22 +118,15 @@ def calc_loss(self, batch, device='cpu'):
        expected_state_action_values = next_state_values * config.gamma + rewards_v
        return torch.nn.MSELoss()(state_action_values, expected_state_action_values), state_action_values.mean()
 
-    def sync_to_best_interim(self):
-        self.best_interim_net.load_state_dict(self.net.state_dict())
-
-    def save(self, model_path: str, model_name: str) -> None:
+    def save(self, model_path: str) -> None:
        """
        Save a trained model to the specified folder within 'trainedModels'.
 
        Args:
-            model_path (str): The path to the folder within 'trainedModels' where the model should be saved.
-            model_name (str): The name of the .dat file of this specific model.
+            model_path (str): The path including the name where the model should be saved.
        """
-        assert model_name.endswith('.dat'), f'the modelname must end in ".dat": {model_name}'
-        assert os.path.exists(model_path), f'the specified path does not exist: {model_path}'
-        parameters_path = os.path.join(model_path, model_name)
-        torch.save(self.best_interim_net.state_dict(), parameters_path)
-        return parameters_path
+        assert model_path.endswith('.dat'), f'the modelname must end in ".dat": {model_path}'
+        torch.save(self.net.state_dict(), model_path)
 
 
 class QLearningLEAgent(QLearningAgent, LinearAgent):
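Deleting best_interim_net and sync_to_best_interim from QLearningAgent works because the callback now decides when the current weights are worth persisting: on a sufficient improvement it calls model.save(tmp_parameters) and later renames that file in save_parameters. A toy sketch of that decision rule, where the '+ 15' threshold comes from the callback code above and the returns are invented:

best_mean_interim_return = None
for mean_return in [10.0, 12.0, 40.0, 38.0, 90.0]:   # invented training progression
    if best_mean_interim_return is None or mean_return > best_mean_interim_return + 15:
        best_mean_interim_return = mean_return
        print(f'snapshot current weights at mean return {mean_return}')  # model.save(tmp_parameters) in the real code
# snapshots happen at 10.0, 40.0 and 90.0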
