Skip to content

Commit

Permalink
fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuhu-kth committed May 30, 2024
1 parent 540ebe9 commit 2b996bf
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def test_is_state_terminal(self) -> None:
assert not env.is_state_terminal(state_tuple)

with pytest.raises(ValueError):
env.is_state_terminal([1, 2, 3]) # type: ignore
env.is_state_terminal([1, 2, 3]) # type: ignore

def test_get_observation_from_history(self) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from csle_common.dao.training.policy import Policy
from csle_common.dao.training.random_policy import RandomPolicy
from csle_common.dao.training.player_type import PlayerType
from csle_common.dao.simulation_config.action import Action
import pytest
from unittest.mock import MagicMock
import numpy as np
Expand Down Expand Up @@ -146,7 +147,7 @@ def test_set_state(self) -> None:
)

env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
assert not env.set_state(1) # type: ignore
assert not env.set_state(1) # type: ignore

def test_calculate_stage_policy(self) -> None:
"""
Expand Down Expand Up @@ -192,7 +193,7 @@ def test_get_attacker_dist(self) -> None:
def test_render(self) -> None:
"""
Tests the function for rendering the environment
:return: None
"""
defender_strategy = MagicMock(spec=Policy)
Expand Down Expand Up @@ -328,26 +329,32 @@ def test_get_actions_from_particles(self) -> None:
def test_step(self) -> None:
"""
Tests the function for taking a step in the environment by executing the given action
:return: None
"""
defender_stage_strategy = np.zeros((3, 2))
defender_stage_strategy[0][0] = 0.9
defender_stage_strategy[0][1] = 0.1
defender_stage_strategy[1][0] = 0.9
defender_stage_strategy[1][1] = 0.1
defender_stage_strategy[2] = defender_stage_strategy[1]
defender_strategy = RandomPolicy(actions=list(self.config.A1), player_type=PlayerType.DEFENDER,
stage_policy_tensor=list(defender_stage_strategy))
defender_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A1))
defender_strategy = RandomPolicy(
actions=defender_actions,
player_type=PlayerType.DEFENDER,
stage_policy_tensor=list(defender_stage_strategy),
)
attacker_mdp_config = StoppingGameAttackerMdpConfig(
env_name="test_env",
stopping_game_config=self.config,
defender_strategy=defender_strategy,
stopping_game_name="csle-stopping-game-v1",
)

env = StoppingGameMdpAttackerEnv(config=attacker_mdp_config)
env.reset()
pi2 = env.calculate_stage_policy(o=list(env.latest_attacker_obs),a2=0)
attacker_obs, r[1], d, d, info = env.step(pi2)
assert isinstance(attacker_obs[0], float) # type: ignore
pi2 = env.calculate_stage_policy(o=list(env.latest_attacker_obs), a2=0) # type: ignore
attacker_obs, reward, terminated, truncated, info = env.step(pi2)
assert isinstance(attacker_obs[0], float) # type: ignore
assert isinstance(terminated, bool) # type: ignore
assert isinstance(truncated, bool) # type: ignore
assert isinstance(reward, float) # type: ignore
assert isinstance(info, dict) # type: ignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import StoppingGamePomdpDefenderEnv
from gym_csle_stopping_game.envs.stopping_game_pomdp_defender_env import (
StoppingGamePomdpDefenderEnv,
)
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import StoppingGameDefenderPomdpConfig
from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import (
StoppingGameDefenderPomdpConfig,
)
from gym_csle_stopping_game.envs.stopping_game_env import StoppingGameEnv
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
from csle_common.dao.training.policy import Policy
from csle_common.dao.simulation_config.action import Action
from csle_common.dao.training.random_policy import RandomPolicy
from csle_common.dao.training.player_type import PlayerType
import pytest
Expand Down Expand Up @@ -219,7 +224,7 @@ def test_set_state(self) -> None:
stopping_game_name="csle-stopping-game-v1",
)
env = StoppingGamePomdpDefenderEnv(config=defender_pomdp_config)
assert env.set_state(1) is None # type: ignore
assert env.set_state(1) is None # type: ignore

def test_get_observation_from_history(self) -> None:
"""
Expand Down Expand Up @@ -301,7 +306,10 @@ def test_get_actions_from_particles(self) -> None:
t = 0
observation = 0
expected_actions = [0, 1]
assert env.get_actions_from_particles(particles, t, observation) == expected_actions
assert (
env.get_actions_from_particles(particles, t, observation)
== expected_actions
)

def test_step(self) -> None:
"""
Expand All @@ -315,8 +323,12 @@ def test_step(self) -> None:
attacker_stage_strategy[1][0] = 0.9
attacker_stage_strategy[1][1] = 0.1
attacker_stage_strategy[2] = attacker_stage_strategy[1]
attacker_strategy = RandomPolicy(actions=list(self.config.A2), player_type=PlayerType.ATTACKER,
stage_policy_tensor=list(attacker_stage_strategy))
attacker_actions = list(map(lambda x: Action(id=x, descr=""), self.config.A2))
attacker_strategy = RandomPolicy(
actions=attacker_actions,
player_type=PlayerType.ATTACKER,
stage_policy_tensor=list(attacker_stage_strategy),
)
defender_pomdp_config = StoppingGameDefenderPomdpConfig(
env_name="test_env",
stopping_game_config=self.config,
Expand All @@ -328,9 +340,9 @@ def test_step(self) -> None:
env.reset()
defender_obs, reward, terminated, truncated, info = env.step(a1)
assert len(defender_obs) == 2
assert isinstance(defender_obs[0], float) # type: ignore
assert isinstance(defender_obs[1], float) # type: ignore
assert isinstance(reward, float) # type: ignore
assert isinstance(terminated, bool) # type: ignore
assert isinstance(truncated, bool) # type: ignore
assert isinstance(info, dict) # type: ignore
assert isinstance(defender_obs[0], float) # type: ignore
assert isinstance(defender_obs[1], float) # type: ignore
assert isinstance(reward, float) # type: ignore
assert isinstance(terminated, bool) # type: ignore
assert isinstance(truncated, bool) # type: ignore
assert isinstance(info, dict) # type: ignore

0 comments on commit 2b996bf

Please sign in to comment.