-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_actorcritic_training.py
39 lines (35 loc) · 2.6 KB
/
test_actorcritic_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import pytest
from attrdict import AttrDict
import recommerce.market.circular.circular_sim_market as circular_market
import recommerce.market.linear.linear_sim_market as linear_market
import recommerce.rl.actorcritic.actorcritic_agent as actorcritic_agent
from recommerce.configuration.hyperparameter_config import HyperparameterConfigLoader
from recommerce.rl.actorcritic.actorcritic_training import ActorCriticTrainer
# Parametrization table for test_training_configurations: one
# (market_class, agent_class, verbose) tuple per scenario.
# It is laid out as a grid: each market row lists the `verbose` flag to use
# with each agent column, in the column order given by the inner tuple
# (discrete, fixed-one-std continuous, estimating-std continuous).
# The comprehension expands the grid market-by-market, so the resulting
# list order is: all agents for the first market, then all agents for the
# next market, and so on.
test_scenarios = [
    (market_class, agent_class, verbose)
    for market_class, verbose_by_agent in (
        (linear_market.LinearEconomyDuopoly, (True, True, False)),
        (linear_market.LinearEconomyOligopoly, (False, False, True)),
        (circular_market.CircularEconomyMonopoly, (True, False, True)),
        (circular_market.CircularEconomyRebuyPriceMonopoly, (True, False, True)),
        (circular_market.CircularEconomyRebuyPriceDuopoly, (False, True, False)),
    )
    for agent_class, verbose in zip(
        (
            actorcritic_agent.DiscreteActorCriticAgent,
            actorcritic_agent.ContinuousActorCriticAgentFixedOneStd,
            actorcritic_agent.ContinuousActorCriticAgentEstimatingStd,
        ),
        verbose_by_agent,
    )
]
@pytest.mark.training
@pytest.mark.slow
@pytest.mark.parametrize('market_class, agent_class, verbose', test_scenarios)
def test_training_configurations(market_class, agent_class, verbose):
    """Run a short actor-critic training for one market/agent pairing.

    The test makes no assertions: it passes as long as constructing the
    trainer and calling ``train_agent`` completes without raising.

    Args:
        market_class: market class handed to ``ActorCriticTrainer``.
        agent_class: actor-critic agent class handed to ``ActorCriticTrainer``.
        verbose: forwarded unchanged to ``train_agent``.
    """
    # NOTE(review): both configs are loaded for fixed classes rather than the
    # parametrized `market_class` / `agent_class`. Presumably this is so every
    # scenario shares one config that satisfies the most demanding class.
    # TODO confirm against HyperparameterConfigLoader's validation rules.
    market_config: AttrDict = HyperparameterConfigLoader.load(
        'market_config', circular_market.CircularEconomyRebuyPriceMonopoly)
    rl_config: AttrDict = HyperparameterConfigLoader.load(
        'actor_critic_config', actorcritic_agent.ContinuousActorCriticAgentFixedOneStd)

    # Override the loaded batch size with a small value (8), presumably to keep
    # this slow test cheap.
    rl_config.batch_size = 8

    trainer = ActorCriticTrainer(market_class, agent_class, market_config, rl_config)
    trainer.train_agent(
        verbose=verbose,
        number_of_training_steps=120,
        total_envs=64,
    )