121 changes: 121 additions & 0 deletions envs/connect4_env/rubrics.py
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Trajectory-based rubrics for Connect4 environment.

This module demonstrates the TrajectoryRubric pattern from RFC 004
for games where the reward signal is only known at the terminal state.

Connect4 is ideal for demonstrating trajectory rubrics because:
- Win/loss is only known at game end
- Clear semantics: 1.0 for win, 0.0 for loss, 0.5 for draw
- Discounting can assign more credit to decisive late-game moves
"""

from typing import Any, Dict, List, Tuple

from openenv.core.rubrics import ExponentialDiscountingTrajectoryRubric


class Connect4WinLossRubric(ExponentialDiscountingTrajectoryRubric):
"""Trajectory rubric that scores Connect4 games based on outcome.

    Scores:
    - 1.0 for a win (the player made four in a row)
    - 0.0 for a loss (the opponent made four in a row, or the player made an invalid move)
    - 0.5 for a draw (board full, no winner)

With exponential discounting, later moves (closer to the decisive
outcome) receive higher rewards. This helps credit assignment:
the move that completes 4-in-a-row gets the most credit.

Usage:
rubric = Connect4WinLossRubric(gamma=0.95)
env = Connect4Environment(rubric=rubric)

obs = env.reset()
while not obs.done:
action = agent.act(obs)
obs = env.step(action)

# Get per-step rewards for training
step_rewards = rubric.compute_step_rewards()
# step_rewards[i] = gamma^(T-1-i) * final_score
"""

def __init__(
self,
gamma: float = 0.95,
invalid_move_penalty: float = 0.0,
player_id: int = 1,
):
"""Initialize Connect4 trajectory rubric.

Args:
gamma: Discount factor for credit assignment. 0.95 gives
more credit to later (decisive) moves.
invalid_move_penalty: Score when player makes invalid move.
Default 0.0 (treat as loss).
player_id: Which player we're scoring for (1 or -1).
"""
super().__init__(gamma=gamma, intermediate_reward=0.0)
self.invalid_move_penalty = invalid_move_penalty
self.player_id = player_id

def score_trajectory(self, trajectory: List[Tuple[Any, Any]]) -> float:
"""Score based on game outcome.

Returns:
1.0 for win, 0.0 for loss, 0.5 for draw.
"""
if not trajectory:
return 0.0

_, final_obs = trajectory[-1]

        # Only score trajectories that reached a terminal observation
if not getattr(final_obs, "done", False):
return 0.0

# Get reward from observation
reward = getattr(final_obs, "reward", 0.0)
if reward is None:
reward = 0.0

        # Interpret the environment's terminal reward:
        #   -1.0 = invalid move (loss)
        #    1.0 = win
        #    0.0 with done = draw
        if reward == -1.0:
return self.invalid_move_penalty
elif reward == 1.0:
return 1.0
elif reward == 0.0:
# Draw (board full, no winner)
return 0.5
else:
# Unexpected value, treat as loss
return 0.0

def state_dict(self) -> Dict[str, Any]:
"""Serialize configuration."""
state = super().state_dict()
state["invalid_move_penalty"] = self.invalid_move_penalty
state["player_id"] = self.player_id
return state

def load_state_dict(self, state: Dict[str, Any]) -> None:
"""Load configuration from checkpoint."""
super().load_state_dict(state)
if "invalid_move_penalty" in state:
self.invalid_move_penalty = state["invalid_move_penalty"]
if "player_id" in state:
self.player_id = state["player_id"]


__all__ = [
"Connect4WinLossRubric",
]
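
The per-step discounting quoted in the docstrings is easy to check in isolation. A minimal sketch, assuming compute_step_rewards applies step_rewards[i] = gamma^(T-1-i) * final_score across a T-step trajectory as the docstring states (the helper below is illustrative, not part of this module):

def discounted_step_rewards(final_score: float, num_steps: int, gamma: float = 0.95) -> list:
    # The last (decisive) move gets the undiscounted final score;
    # each earlier move is discounted by one more factor of gamma.
    return [gamma ** (num_steps - 1 - i) * final_score for i in range(num_steps)]

# A seven-move win with gamma=0.95:
print(discounted_step_rewards(1.0, 7))
# ~[0.735, 0.774, 0.815, 0.857, 0.903, 0.950, 1.000]

With gamma=1.0 every move would share the final score equally; lowering gamma concentrates credit on the end-game moves that decided the outcome.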
61 changes: 50 additions & 11 deletions envs/connect4_env/server/connect4_environment.py
@@ -1,28 +1,63 @@
import uuid
from typing import Optional
import numpy as np
from openenv.core.env_server import Environment
from openenv.core.rubrics import Rubric

from ..models import Connect4Action, Connect4Observation, Connect4State


class Connect4Environment(Environment):
"""Connect4 game environment with optional rubric-based scoring.

This environment demonstrates the rubric integration pattern from RFC 004.
When a rubric is provided (e.g., Connect4WinLossRubric), it can be used
for trajectory-based reward computation for training.

Usage without rubric:
env = Connect4Environment()
obs = env.step(action) # reward is computed inline

Usage with trajectory rubric:
from connect4_env.rubrics import Connect4WinLossRubric

rubric = Connect4WinLossRubric(gamma=0.95)
env = Connect4Environment(rubric=rubric)

obs = env.reset()
while not obs.done:
action = agent.act(obs)
obs = env.step(action)

# Get per-step rewards with discounting for training
step_rewards = env.rubric.compute_step_rewards()
"""

ROWS = 6
COLUMNS = 7

def __init__(self, opponent=None):
super().__init__()
def __init__(
self,
opponent=None,
rubric: Optional[Rubric] = None,
):
super().__init__(rubric=rubric)
self._opponent = opponent
self.reset()

def reset(self):
def reset(self, seed=None, episode_id=None, **kwargs):
# Reset rubric state for new episode (RFC 004)
self._reset_rubric()

self.board = np.zeros((self.ROWS, self.COLUMNS), dtype=np.int8)
self.next_player = 1
self.invalid_move_played = False

self._state = Connect4State(
board=self.board.copy().tolist(),
next_player=self.next_player,
episode_id=str(uuid.uuid4()),
step_count=0
episode_id=episode_id or str(uuid.uuid4()),
step_count=0,
)
return self._make_observation()

@@ -47,12 +82,12 @@ def step(self, action: Connect4Action):
reward, done = self._check_win_or_draw(row, col)

self.next_player *= -1

self._state = Connect4State(
board=self.board.copy().tolist(),
next_player=self.next_player,
episode_id=self._state.episode_id,
step_count=self._state.step_count + 1
step_count=self._state.step_count + 1,
)

return self._make_observation(reward, done)
@@ -64,18 +99,22 @@ def _make_observation(self, reward=0.0, done=False):
legal_actions=legal_actions,
reward=reward,
done=done,
metadata={"next_player": self.next_player}
metadata={"next_player": self.next_player},
)

def _check_win_or_draw(self, row, col):
        # Check for four in a row through the just-placed disc in all four directions
player = self.board[row, col]
directions = [(1,0),(0,1),(1,1),(1,-1)]
directions = [(1, 0), (0, 1), (1, 1), (1, -1)]
for dr, dc in directions:
count = 0
for step in range(-3, 4):
r, c = row + step*dr, col + step*dc
if 0 <= r < self.ROWS and 0 <= c < self.COLUMNS and self.board[r,c] == player:
r, c = row + step * dr, col + step * dc
if (
0 <= r < self.ROWS
and 0 <= c < self.COLUMNS
and self.board[r, c] == player
):
                    count += 1
                    if count >= 4:
                        return 1.0, True
                else:
                    # A gap breaks the run; only consecutive discs count
                    count = 0
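
The count reset in the scan above is what keeps the check honest: without it, a split pattern such as X X . X X through the placed disc would accumulate four non-consecutive matches and be miscounted as a win. A standalone sketch of the same consecutive-run check (the helper is illustrative, not part of this PR):

def has_four_in_a_row(cells, player=1):
    # cells: the 7 values along one direction, centered on the placed disc
    count = 0
    for value in cells:
        if value == player:
            count += 1
            if count >= 4:
                return True
        else:
            count = 0  # the run is broken; start counting again
    return False

assert not has_four_in_a_row([1, 1, 0, 1, 1, 0, 0])  # split run: not a win
assert has_four_in_a_row([0, 1, 1, 1, 1, 0, 0])      # four in a row: win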