Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add include_hash_inv arg to ChessEnv #2766

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3709,6 +3709,74 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san
if include_san:
assert "san_hash" in env.observation_spec.keys()

# Test that `include_hash_inv=True` allows us to specify the board state
# with just the "fen_hash" or "pgn_hash", not "fen" or "pgn", when taking a
# step in the env.
@pytest.mark.parametrize(
    "include_fen,include_pgn",
    [[True, False], [False, True]],
)
@pytest.mark.parametrize("stateful", [True, False])
def test_env_hash_inv(self, include_fen, include_pgn, stateful):
    """Step a hash-inverting ``ChessEnv`` using only the hash entries.

    The env is built with ``include_hash=True`` and ``include_hash_inv=True``.
    Before every step the "fen" and "pgn" entries are removed from the input,
    so only the hash entries remain to identify the board. Each result is
    cross-checked against a second ``ChessEnv`` built without hashes that is
    fed the explicit "fen"/"pgn" value and the same action.
    """
    env = ChessEnv(
        include_fen=include_fen,
        include_pgn=include_pgn,
        include_hash=True,
        include_hash_inv=True,
        stateful=stateful,
    )
    env.check_env_specs()

    current = env.reset()

    # One (state key, hash key, reference env) triple per enabled
    # representation; the reference envs are created without hashes.
    references = []
    if include_fen:
        references.append(
            ("fen", "fen_hash", ChessEnv(include_fen=True, stateful=stateful))
        )
    if include_pgn:
        references.append(
            ("pgn", "pgn_hash", ChessEnv(include_pgn=True, stateful=stateful))
        )

    for _ in range(8):
        # Strip the explicit board descriptions from a copy of the state
        # before stepping, leaving the hash entries in place.
        hashed_only = current.clone().exclude("fen").exclude("pgn")
        stepped = env.rand_step(hashed_only)

        # Confirm that fen/pgn was not used to determine the board state
        assert "fen" not in stepped.keys()
        assert "pgn" not in stepped.keys()

        for state_key, hash_key, reference_env in references:
            assert (stepped[hash_key] == current[hash_key]).all()
            assert state_key in stepped["next"]

            # Replay the same action from the same starting state in the
            # reference env, this time supplying the explicit state entry,
            # and require the same next state.
            replay_input = (
                stepped.clone()
                .exclude("next")
                .update({state_key: current[state_key]})
            )
            assert (
                reference_env.step(replay_input)["next", state_key]
                == stepped["next", state_key]
            )

        current = stepped["next"]

@pytest.mark.skipif(not _has_tv, reason="torchvision not found.")
@pytest.mark.skipif(not _has_cairosvg, reason="cairosvg not found.")
@pytest.mark.parametrize("stateful", [False, True])
Expand Down
40 changes: 29 additions & 11 deletions torchrl/envs/custom/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,38 @@
class _ChessMeta(_EnvPostInit):
def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
if kwargs.get("include_hash"):
include_hash = kwargs.get("include_hash")
include_hash_inv = kwargs.get("include_hash_inv")
if include_hash:
from torchrl.envs import Hash

in_keys = []
out_keys = []
if instance.include_san:
in_keys.append("san")
out_keys.append("san_hash")
if instance.include_fen:
in_keys.append("fen")
out_keys.append("fen_hash")
if instance.include_pgn:
in_keys.append("pgn")
out_keys.append("pgn_hash")
instance = instance.append_transform(Hash(in_keys, out_keys))
in_keys_inv = [] if include_hash_inv else None
out_keys_inv = [] if include_hash_inv else None

def maybe_add_keys(condition, in_key, out_key):
if condition:
in_keys.append(in_key)
out_keys.append(out_key)
if include_hash_inv:
in_keys_inv.append(in_key)
out_keys_inv.append(out_key)

maybe_add_keys(instance.include_san, "san", "san_hash")
maybe_add_keys(instance.include_fen, "fen", "fen_hash")
maybe_add_keys(instance.include_pgn, "pgn", "pgn_hash")

instance = instance.append_transform(
Hash(in_keys, out_keys, in_keys_inv, out_keys_inv)
)
elif include_hash_inv:
raise ValueError(
(
"'include_hash_inv=True' can only be set if"
f"'include_hash=True', but got 'include_hash={include_hash}'."
)
)
if kwargs.get("mask_actions", True):
from torchrl.envs import ActionMask

Expand Down Expand Up @@ -265,6 +282,7 @@ def __init__(
include_pgn: bool = False,
include_legal_moves: bool = False,
include_hash: bool = False,
include_hash_inv: bool = False,
mask_actions: bool = True,
pixels: bool = False,
):
Expand Down
Loading