Skip to content

Commit

Permalink
refactor: #73 default use builtin random as generator
Browse files Browse the repository at this point in the history
  • Loading branch information
zyr17 committed Feb 23, 2024
1 parent c4222b7 commit f412112
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 15 deletions.
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ classifiers = [
"Programming Language :: Python :: 3.10",
]
dependencies = [
"numpy",
"pydantic==1.10.14",
"typing_extensions",
"dictdiffer",
Expand All @@ -25,7 +24,15 @@ dependencies = [
]

[project.optional-dependencies]
dev = ["setuptools-scm", "pytest", "pytest-cov", "pytest-xdist", "pyright", "build"]
dev = [
"build",
"numpy",
"setuptools-scm",
"pytest",
"pytest-cov",
"pytest-xdist",
"pyright",
]

[project.urls]
"Homepage" = "https://github.com/LPSim/backend"
Expand Down
11 changes: 9 additions & 2 deletions src/lpsim/agents/random_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import numpy as np
from typing import Any
from pydantic import PrivateAttr
from .agent_base import AgentBase
from ..server.match import Match
Expand All @@ -23,14 +23,20 @@
from ..server.consts import DieColor


try:
import numpy as np
except ImportError:
print("numpy is not installed, RandomAgent will not work.")


class RandomAgent(AgentBase):
"""
Agent that randomly choose one type of requests, randomly choose one of
selected type, and randomly choose one of available options.
"""

random_seed: int | None = None
_random_state: np.random.RandomState = PrivateAttr(np.random.RandomState())
_random_state: Any = PrivateAttr(None)
random_state_set: bool = False

def json(self, *argv, **kwargs):
Expand All @@ -39,6 +45,7 @@ def json(self, *argv, **kwargs):
def random(self) -> float:
"""
Return a random float between 0 and 1.
TODO: should save random state after random call to ensure reproducibility.
"""
if not self.random_state_set:
self._random_state = np.random.RandomState()
Expand Down
3 changes: 2 additions & 1 deletion src/lpsim/network/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def get_new_match(
Args:
decks: The decks of players. If its length is zero, will not set decks
or start the match.
seed: The random seed. It should follow the format of numpy.random.
seed: The random seed. It should follow the format of
numpy.RandomState.get_state(legacy=True) or random.Random.getstate().
rich_mode: If True, use rich mode, at round start, players is given
16 omni dice. Mainly used in code testing.
match_config: The config of the match. If None, use default config.
Expand Down
51 changes: 41 additions & 10 deletions src/lpsim/server/match.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import logging
import copy
import numpy as np
import random
from typing import Literal, List, Any, Dict, Tuple
from enum import Enum
from pydantic import PrivateAttr, validator
Expand Down Expand Up @@ -130,6 +130,15 @@
from .event_handler import SystemEventHandlerBase, SystemEventHandler


try:
import numpy as np
except ImportError: # pragma: no cover
# in legacy version, we use numpy as random generator. And in newer version, we
# use random as random generator. If numpy is not installed, we cannot set random
# state with numpy version states.
pass


class MatchState(str, Enum):
"""
Enum representing the state of a match.
Expand Down Expand Up @@ -332,7 +341,7 @@ class Match(BaseModel):

# random state
random_state: List[Any] = []
_random_state: np.random.RandomState = PrivateAttr(np.random.RandomState())
_random_state: Any = PrivateAttr(None)

# event handlers to implement special rules.
event_handlers: List[SystemEventHandlerBase] = [
Expand Down Expand Up @@ -433,13 +442,26 @@ def _init_random_state(self):
# no need to init random state
return
if self.random_state:
random_state = self.random_state[:]
random_state[1] = np.array(random_state[1], dtype="uint32")
self._random_state.set_state(random_state) # type: ignore
if self.random_state[0] == "MT19937":
assert "np" in globals(), (
"numpy is not installed, cannot set random state with numpy "
"version states."
)
random_state = self.random_state[:]
random_state[1] = np.array(random_state[1], dtype="uint32")
self._random_state = np.random.RandomState()
self._random_state.set_state(random_state)
else:
random_state = (
self.random_state[0],
tuple(self.random_state[1]),
self.random_state[2],
)
self._random_state = random.Random()
self._random_state.setstate(random_state)
else:
# random state not set, re-new self._random_state to avoid
# affecting other matches.
self._random_state = np.random.RandomState()
# random state not set, create new random state
self._random_state = random.Random()
self._save_random_state()

def _save_history(self) -> None:
Expand Down Expand Up @@ -577,8 +599,17 @@ def _save_random_state(self):
"""
Save the random state.
"""
self.random_state = list(self._random_state.get_state(legacy=True))
self.random_state[1] = self.random_state[1].tolist()
if isinstance(self._random_state, random.Random):
self.random_state = list(self._random_state.getstate())
self.random_state[1] = list(self.random_state[1])
return
elif isinstance(self._random_state, np.random.RandomState):
self.random_state = list(self._random_state.get_state(legacy=True))
self.random_state[1] = self.random_state[1].tolist()
else:
raise AssertionError(
f"Random state type {type(self._random_state)} not recognized."
)

def _random(self):
"""
Expand Down
40 changes: 40 additions & 0 deletions tests/server/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,45 @@ def test_use_card_event_serialize():
assert card_event.action.card_position.area == ObjectPositionType.HAND


def test_different_random_state():
deck = Deck.from_str(
"""
default_version:4.0
character:Rhodeia of Loch
character:Kamisato Ayaka
Traveler's Handy Sword*5
Gambler's Earrings*5
Kanten Senmyou Blessing*5
Sweet Madame*5
Abyssal Summons*5
Fatui Conspiracy*5
Timmie*5
"""
)
# numpy random state
match = Match(random_state=get_random_state())
match.set_deck([deck, deck])
match.config.max_same_card_number = 30
match.config.card_number = None
match.config.character_number = None
match.config.check_deck_restriction = False
assert match.start()[0]
assert match.random_state[0] == "MT19937"
assert match.random_state[3] == 0
assert match.random_state[4] == 0

# python random state
match = Match()
match.set_deck([deck, deck])
match.config.max_same_card_number = 30
match.config.card_number = None
match.config.character_number = None
match.config.check_deck_restriction = False
assert match.start()[0]
assert match.random_state[0] == 3
assert match.random_state[2] is None


if __name__ == "__main__":
# test_match_pipeline()
# test_save_load()
Expand All @@ -1377,3 +1416,4 @@ def test_use_card_event_serialize():
# test_version_validation()
# test_round_end_all_lose()
test_use_card_event_serialize()
test_different_random_state()

0 comments on commit f412112

Please sign in to comment.