diff --git a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_env.py b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_env.py index 10e566ff9..67bf3ceb0 100644 --- a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_env.py +++ b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_env.py @@ -3,6 +3,7 @@ from unittest.mock import patch, MagicMock from gymnasium.spaces import Box, Discrete import numpy as np +from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig from gym_csle_stopping_game.dao.stopping_game_state import StoppingGameState @@ -23,19 +24,19 @@ def setup_env(self) -> None: :return: None """ env_name = "test_env" - T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]]) - O = np.array([0, 1]) - Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]]) + T = StoppingGameUtil.transition_tensor(L=3, p=0) + O = StoppingGameUtil.observation_space(n=100) + Z = StoppingGameUtil.observation_tensor(n=100) R = np.zeros((2, 3, 3, 3)) - S = np.array([0, 1, 2]) - A1 = np.array([0, 1, 2]) - A2 = np.array([0, 1, 2]) + S = StoppingGameUtil.state_space() + A1 = StoppingGameUtil.defender_actions() + A2 = StoppingGameUtil.attacker_actions() L = 2 R_INT = 1 R_COST = 2 R_SLA = 3 R_ST = 4 - b1 = np.array([0.6, 0.4]) + b1 = StoppingGameUtil.b1() save_dir = "save_directory" checkpoint_traces_freq = 100 gamma = 0.9 @@ -69,12 +70,12 @@ def test_stopping_game_init_(self) -> None: :return: None """ - T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]]) - O = np.array([0, 1]) - A1 = np.array([0, 1, 2]) - A2 = np.array([0, 1, 2]) + T = StoppingGameUtil.transition_tensor(L=3, p=0) + O = StoppingGameUtil.observation_space(n=100) + A1 = StoppingGameUtil.defender_actions() + A2 = StoppingGameUtil.attacker_actions() L = 2 - b1 = np.array([0.6, 0.4]) + b1 = StoppingGameUtil.b1() attacker_observation_space = Box( low=np.array([0.0, 0.0, 0.0]), high=np.array([float(L), 1.0, 2.0]), @@ -304,7 +305,7 @@ def test_is_state_terminal(self) -> None: assert not env.is_state_terminal(state_tuple) with pytest.raises(ValueError): - env.is_state_terminal([1, 2, 3]) # type: ignore + env.is_state_terminal([1, 2, 3]) # type: ignore def test_get_observation_from_history(self) -> None: """ @@ -346,26 +347,6 @@ def test_step(self) -> None: :return: None """ env = StoppingGameEnv(self.config) - env.state = MagicMock() - env.state.s = 1 - env.state.l = 2 - env.state.t = 0 - env.state.attacker_observation.return_value = np.array([1, 2, 3]) - env.state.defender_observation.return_value = np.array([4, 5, 6]) - env.state.b = np.array([0.5, 0.5, 0.0]) - - env.trace = MagicMock() - env.trace.defender_rewards = [] - env.trace.attacker_rewards = [] - env.trace.attacker_actions = [] - env.trace.defender_actions = [] - env.trace.infos = [] - env.trace.states = [] - env.trace.beliefs = [] - env.trace.infrastructure_metrics = [] - env.trace.attacker_observations = [] - env.trace.defender_observations = [] - with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_state", return_value=2): with patch("gym_csle_stopping_game.util.stopping_game_util.StoppingGameUtil.sample_next_observation", @@ -376,7 +357,7 @@ def test_step(self) -> None: 1, ( np.array( - [[0.2, 0.8, 0.0], [0.6, 0.4, 0.0], [0.5, 0.5, 0.0]] + [[0.2, 0.8], [0.6, 0.4], [0.5, 0.5]] ), 2, ), @@ -384,24 +365,12 @@ def test_step(self) -> None: observations, rewards, terminated, truncated, info = env.step( action_profile ) - - assert (observations[0] == np.array([4, 5, 6])).all(), "Incorrect defender observations" - assert (observations[1] == np.array([1, 2, 3])).all(), "Incorrect attacker observations" + assert observations[0].all() == np.array([1, 0.7]).all(), "Incorrect defender observations" + assert observations[1].all() == np.array([1, 2, 3]).all(), "Incorrect attacker observations" assert rewards == (0, 0) assert not terminated assert not truncated - assert env.trace.defender_rewards[-1] == 0 - assert env.trace.attacker_rewards[-1] == 0 - assert env.trace.attacker_actions[-1] == 2 - assert env.trace.defender_actions[-1] == 1 - assert env.trace.infos[-1] == info - assert env.trace.states[-1] == 2 - print(env.trace.beliefs) - assert env.trace.beliefs[-1] == 0.7 - assert env.trace.infrastructure_metrics[-1] == 1 - assert (env.trace.attacker_observations[-1] == np.array([1, 2, 3])).all() - assert (env.trace.defender_observations[-1] == np.array([4, 5, 6])).all() - + def test_info(self) -> None: """ Tests the function of adding the cumulative reward and episode length to the info dict diff --git a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_mdp_attacker_env.py b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_mdp_attacker_env.py index df461c511..40fbd6260 100644 --- a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_mdp_attacker_env.py +++ b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_mdp_attacker_env.py @@ -5,8 +5,12 @@ from gym_csle_stopping_game.dao.stopping_game_attacker_mdp_config import ( StoppingGameAttackerMdpConfig, ) +from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv from csle_common.dao.training.policy import Policy +from csle_common.dao.training.random_policy import RandomPolicy +from csle_common.dao.training.player_type import PlayerType +from csle_common.dao.simulation_config.action import Action import pytest from unittest.mock import MagicMock import numpy as np @@ -25,19 +29,19 @@ def setup_env(self) -> None: :return: None """ env_name = "test_env" - T = np.array([[[0.1, 0.9], [0.4, 0.6]], [[0.7, 0.3], [0.2, 0.8]]]) - O = np.array([0, 1]) - Z = np.array([[[0.8, 0.2], [0.5, 0.5]], [[0.4, 0.6], [0.9, 0.1]]]) + T = StoppingGameUtil.transition_tensor(L=3, p=0) + O = StoppingGameUtil.observation_space(n=100) + Z = StoppingGameUtil.observation_tensor(n=100) R = np.zeros((2, 3, 3, 3)) - S = np.array([0, 1, 2]) - A1 = np.array([0, 1, 2]) - A2 = np.array([0, 1, 2]) + S = StoppingGameUtil.state_space() + A1 = StoppingGameUtil.defender_actions() + A2 = StoppingGameUtil.attacker_actions() L = 2 R_INT = 1 R_COST = 2 R_SLA = 3 R_ST = 4 - b1 = np.array([0.6, 0.4]) + b1 = StoppingGameUtil.b1() save_dir = "save_directory" checkpoint_traces_freq = 100 gamma = 0.9 @@ -107,9 +111,8 @@ def test_reset(self) -> None: ) env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config) - attacker_obs, info = env.reset() - assert env.latest_defender_obs.all() == np.array([2, 0.4]).all() # type: ignore - assert info == {} + info = env.reset() + assert info[-1] == {} def test_set_model(self) -> None: """ @@ -144,7 +147,7 @@ def test_set_state(self) -> None: ) env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config) - assert not env.set_state(1) # type: ignore + assert not env.set_state(1) # type: ignore def test_calculate_stage_policy(self) -> None: """ @@ -190,7 +193,7 @@ def test_get_attacker_dist(self) -> None: def test_render(self) -> None: """ Tests the function for rendering the environment - + :return: None """ defender_strategy = MagicMock(spec=Policy) @@ -317,7 +320,7 @@ def test_get_actions_from_particles(self) -> None: particles = [1, 2, 3] t = 0 observation = 0 - expected_actions = [0, 1, 2] + expected_actions = [0, 1] assert ( env.get_actions_from_particles(particles, t, observation) == expected_actions @@ -326,18 +329,32 @@ def test_get_actions_from_particles(self) -> None: def test_step(self) -> None: """ Tests the function for taking a step in the environment by executing the given action - + :return: None """ - defender_strategy = MagicMock(spec=Policy) + defender_stage_strategy = np.zeros((3, 2)) + defender_stage_strategy[0][0] = 0.9 + defender_stage_strategy[0][1] = 0.1 + defender_stage_strategy[1][0] = 0.9 + defender_stage_strategy[1][1] = 0.1 + defender_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A1)) + defender_strategy = RandomPolicy( + actions=defender_actions, + player_type=PlayerType.DEFENDER, + stage_policy_tensor=list(defender_stage_strategy), + ) attacker_mdp_config = StoppingGameAttackerMdpConfig( env_name="test_env", stopping_game_config=self.config, defender_strategy=defender_strategy, stopping_game_name="csle-stopping-game-v1", ) - env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config) - pi2 = np.array([[0.5, 0.5]]) - with pytest.raises(AssertionError): - env.step(pi2) + env.reset() + pi2 = env.calculate_stage_policy(o=list(env.latest_attacker_obs), a2=0) # type: ignore + attacker_obs, reward, terminated, truncated, info = env.step(pi2) + assert isinstance(attacker_obs[0], float) # type: ignore + assert isinstance(terminated, bool) # type: ignore + assert isinstance(truncated, bool) # type: ignore + assert isinstance(reward, float) # type: ignore + assert isinstance(info, dict) # type: ignore diff --git a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_pomdp_defender_env.py b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_pomdp_defender_env.py index c3c8da4c1..0f83ac013 100644 --- a/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_pomdp_defender_env.py +++ b/simulation-system/libs/gym-csle-stopping-game/tests/test_stopping_game_pomdp_defender_env.py @@ -1,9 +1,14 @@ -from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import StoppingGamePomdpDefenderEnv +from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import ( + StoppingGamePomdpDefenderEnv, +) from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig -from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import StoppingGameDefenderPomdpConfig +from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import ( + StoppingGameDefenderPomdpConfig, +) from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil from csle_common.dao.training.policy import Policy +from csle_common.dao.simulation_config.action import Action from csle_common.dao.training.random_policy import RandomPolicy from csle_common.dao.training.player_type import PlayerType import pytest @@ -219,7 +224,7 @@ def test_set_state(self) -> None: stopping_game_name="csle-stopping-game-v1", ) env = StoppingGamePomdpDefenderEnv(config=defender_pomdp_config) - assert env.set_state(1) is None # type: ignore + assert env.set_state(1) is None # type: ignore def test_get_observation_from_history(self) -> None: """ @@ -301,7 +306,10 @@ def test_get_actions_from_particles(self) -> None: t = 0 observation = 0 expected_actions = [0, 1] - assert env.get_actions_from_particles(particles, t, observation) == expected_actions + assert ( + env.get_actions_from_particles(particles, t, observation) + == expected_actions + ) def test_step(self) -> None: """ @@ -315,8 +323,12 @@ def test_step(self) -> None: attacker_stage_strategy[1][0] = 0.9 attacker_stage_strategy[1][1] = 0.1 attacker_stage_strategy[2] = attacker_stage_strategy[1] - attacker_strategy = RandomPolicy(actions=list(self.config.A2), player_type=PlayerType.ATTACKER, - stage_policy_tensor=list(attacker_stage_strategy)) + attacker_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A2)) + attacker_strategy = RandomPolicy( + actions=attacker_actions, + player_type=PlayerType.ATTACKER, + stage_policy_tensor=list(attacker_stage_strategy), + ) defender_pomdp_config = StoppingGameDefenderPomdpConfig( env_name="test_env", stopping_game_config=self.config, @@ -328,9 +340,9 @@ def test_step(self) -> None: env.reset() defender_obs, reward, terminated, truncated, info = env.step(a1) assert len(defender_obs) == 2 - assert isinstance(defender_obs[0], float) # type: ignore - assert isinstance(defender_obs[1], float) # type: ignore - assert isinstance(reward, float) # type: ignore - assert isinstance(terminated, bool) # type: ignore - assert isinstance(truncated, bool) # type: ignore - assert isinstance(info, dict) # type: ignore + assert isinstance(defender_obs[0], float) # type: ignore + assert isinstance(defender_obs[1], float) # type: ignore + assert isinstance(reward, float) # type: ignore + assert isinstance(terminated, bool) # type: ignore + assert isinstance(truncated, bool) # type: ignore + assert isinstance(info, dict) # type: ignore