	warnings.filterwarnings('ignore')


- class PerStepCheck(BaseCallback):
+ class RecommerceCallback(BaseCallback):
	"""
	Callback for saving a model (the check is done every `check_freq` steps)
	based on the training reward (in practice, we recommend using `EvalCallback`).
	"""
- 	def __init__(self, agent_class, marketplace_class, log_dir_prepend='', training_steps=10000, iteration_length=500):
+ 	def __init__(self, agent_class, marketplace_class, log_dir_prepend='', training_steps=10000,
+ 		iteration_length=500, file_ending='zip', signature='training'):
		assert issubclass(agent_class, ReinforcementLearningAgent)
		assert issubclass(marketplace_class, SimMarket)
		assert isinstance(log_dir_prepend, str), \
			f'log_dir_prepend should be a string, but {log_dir_prepend} is {type(log_dir_prepend)}'
		assert isinstance(training_steps, int) and training_steps > 0
		assert isinstance(iteration_length, int) and iteration_length > 0
- 		super(PerStepCheck, self).__init__(True)
+ 		super(RecommerceCallback, self).__init__(True)
		self.best_mean_interim_reward = None
		self.best_mean_overall_reward = None
		self.marketplace_class = marketplace_class
		self.agent_class = agent_class
		self.iteration_length = iteration_length
+ 		self.file_ending = file_ending
+ 		self.signature = signature
		self.tqdm_instance = trange(training_steps)
		self.saved_parameter_paths = []
+ 		self.last_finished_episode = 0
		signal.signal(signal.SIGINT, self._signal_handler)

		self.initialize_io_related(log_dir_prepend)
@@ -63,34 +67,51 @@ def initialize_io_related(self, log_dir_prepend) -> None:
		"""
		ut.ensure_results_folders_exist()
		self.curr_time = time.strftime('%b%d_%H-%M-%S')
- 		self.signature = 'Stable_Baselines_Training'
		self.writer = SummaryWriter(log_dir=os.path.join(PathManager.results_path, 'runs', f'{log_dir_prepend}training_{self.curr_time}'))
		path_name = f'{self.signature}_{self.curr_time}'
		self.save_path = os.path.join(PathManager.results_path, 'trainedModels', log_dir_prepend + path_name)
		os.makedirs(os.path.abspath(self.save_path), exist_ok=True)
- 		self.tmp_parameters = os.path.join(self.save_path, 'tmp_model.zip')
+ 		self.tmp_parameters = os.path.join(self.save_path, f'tmp_model.{self.file_ending}')

- 	def _on_step(self) -> bool:
+ 	def _on_step(self, finished_episodes: int = None, mean_return: float = None) -> bool:
		"""
- 		This method is called at every step by the stable baselines agents.
+ 		This method is called during training, after each step in the environment.
+ 		If finished_episodes and mean_return are not provided, the callback derives them from the number of timesteps.
+ 		Note that you must provide finished_episodes if and only if you provide mean_return.
+
+ 		Args:
+ 			finished_episodes (int, optional): The number of episodes that are already finished. Defaults to None.
+ 			mean_return (float, optional): The received return over the last episodes. Defaults to None.
+
+ 		Returns:
+ 			bool: True should be returned. False will be interpreted as an error.
		"""
+ 		assert (finished_episodes is None) == (mean_return is None), 'finished_episodes must be None if and only if mean_return is None'
		self.tqdm_instance.update()
- 		if (self.num_timesteps - 1) % config.episode_length != 0 or self.num_timesteps <= config.episode_length:
+ 		if finished_episodes is None:
+ 			finished_episodes = self.num_timesteps // config.episode_length
+ 			x, y = ts2xy(load_results(self.save_path), 'timesteps')
+ 			if len(x) <= 0:
+ 				return True
+ 			assert len(x) == len(y)
+ 			mean_return = np.mean(y[-100:])
+ 		assert isinstance(finished_episodes, int)
+ 		assert isinstance(mean_return, float)
+
+ 		assert finished_episodes >= self.last_finished_episode
+ 		if finished_episodes == self.last_finished_episode or finished_episodes < 5:
			return True
- 		self.tqdm_instance.refresh()
- 		finished_episodes = self.num_timesteps // config.episode_length
- 		x, y = ts2xy(load_results(self.save_path), 'timesteps')
- 		assert len(x) > 0 and len(x) == len(y)
- 		mean_reward = np.mean(y[-100:])
+ 		else:
+ 			self.last_finished_episode = finished_episodes

		# consider print info
		if (finished_episodes) % 10 == 0:
- 			tqdm.write(f'{self.num_timesteps}: {finished_episodes} episodes trained, mean return {mean_reward:.3f}')
+ 			tqdm.write(f'{self.num_timesteps}: {finished_episodes} episodes trained, mean return {mean_return:.3f}')

		# consider update best model
- 		if self.best_mean_interim_reward is None or mean_reward > self.best_mean_interim_reward + 15:
+ 		if self.best_mean_interim_reward is None or mean_return > self.best_mean_interim_reward + 15:
			self.model.save(self.tmp_parameters)
- 			self.best_mean_interim_reward = mean_reward
+ 			self.best_mean_interim_reward = mean_return
			if self.best_mean_overall_reward is None or self.best_mean_interim_reward > self.best_mean_overall_reward:
				if self.best_mean_overall_reward is not None:
					tqdm.write(f'Best overall reward updated {self.best_mean_overall_reward:.3f} -> {self.best_mean_interim_reward:.3f}')
@@ -105,23 +126,23 @@ def _on_step(self) -> bool:
	def _on_training_end(self) -> None:
		self.tqdm_instance.close()
		if self.best_mean_interim_reward is not None:
- 			finished_episodes = self.num_timesteps // config.episode_length
- 			self.save_parameters(finished_episodes)
+ 			self.save_parameters(self.last_finished_episode)

		# analyze trained agents
		if len(self.saved_parameter_paths) == 0:
			print('No agents saved! Nothing to monitor.')
			return
		monitor = Monitor()
		agent_list = [(self.agent_class, [parameter_path]) for parameter_path in self.saved_parameter_paths]
- 		monitor.configurator.setup_monitoring(False, 250, 250, self.marketplace_class, agent_list, support_continuous_action_space=True)
+ 		monitor.configurator.setup_monitoring(False, 250, 250, self.marketplace_class, agent_list,
+ 			support_continuous_action_space=hasattr(self.model, 'env'))
		rewards = monitor.run_marketplace()
		episode_numbers = [int(parameter_path[-9:][:5]) for parameter_path in self.saved_parameter_paths]
		Evaluator(monitor.configurator).evaluate_session(rewards, episode_numbers)

	def save_parameters(self, finished_episodes: int):
		assert isinstance(finished_episodes, int)
- 		path_to_parameters = os.path.join(self.save_path, f'{self.signature}_{finished_episodes:05d}.zip')
+ 		path_to_parameters = os.path.join(self.save_path, f'{self.signature}_{finished_episodes:05d}.{self.file_ending}')
		os.rename(self.tmp_parameters, path_to_parameters)
		self.saved_parameter_paths.append(path_to_parameters)
		tqdm.write(f'I write the interim model after {finished_episodes} episodes to the disk.')
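
For context, a minimal usage sketch of the renamed callback; this is a sketch only, and `StableBaselinesPPO`, `CircularEconomyRebuyPriceDuopoly`, `model`, `finished` and `episode_return` are illustrative names that are not taken from this diff:

callback = RecommerceCallback(
	agent_class=StableBaselinesPPO,                      # must subclass ReinforcementLearningAgent
	marketplace_class=CircularEconomyRebuyPriceDuopoly,  # must subclass SimMarket
	training_steps=10000,
	iteration_length=500,
	file_ending='zip',
	signature='training')

# Driven by a Stable-Baselines3 model: learn() attaches the callback (setting callback.model)
# and triggers _on_step() without arguments, so episodes and mean return are derived from timesteps.
model.learn(total_timesteps=10000, callback=callback)

# A custom training loop can instead report its progress explicitly:
callback._on_step(finished_episodes=finished, mean_return=float(episode_return))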