Skip to content

Commit

Permalink
Merge pull request #32 from latorc/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
latorc committed Apr 27, 2024
2 parents 1a4e981 + 75e898a commit 1d0fc88
Show file tree
Hide file tree
Showing 30 changed files with 674 additions and 364 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ MahjongCopilot.spec
/settings*.json
libriichi3p/*.pyd
libriichi3p/*.so
chrome_ext/
2 changes: 1 addition & 1 deletion bot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
""" mjai protocol bot implementations"""
from .common import *
from .factory import *
from .bot import *
35 changes: 35 additions & 0 deletions bot/akagiot/bot_akagiot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
""" Bot for Akagi online-trained API"""
import requests
from common.log_helper import LOGGER
from bot.bot import BotMjai, GameMode
from bot.akagiot.engine import MortalEngineAkagiOt


class BotAkagiOt(BotMjai):
""" Bot implementation for Akagi online-trained API """

def __init__(self, url:str, apikey:str) -> None:
super().__init__("Akagi OT API Bot")
self.url = url
self.apikey = apikey
self._check()

def _check(self):
# check authorization
headers = {
'Authorization': self.apikey,
}
r = requests.post(f"{self.url}/check", headers=headers, timeout=5)
r_json = r.json()
if r_json["result"] == "success":
LOGGER.info("Akagi OT API check success")

@property
def supported_modes(self) -> list[GameMode]:
""" return suported game modes"""
return [GameMode.MJ4P, GameMode.MJ3P]


def _get_engine(self, mode: GameMode):
engine = MortalEngineAkagiOt(self.apikey, self.url, mode)
return engine
126 changes: 126 additions & 0 deletions bot/akagiot/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
""" Engine for Akagi OT model"""

import json
import gzip
import requests

from common.log_helper import LOGGER
from common.utils import BotNotSupportingMode, GameMode


class MortalEngineAkagiOt:
""" Mortal Engine for Akagi OT"""
def __init__(
self,
api_key:str = None, server:str = None,
mode:GameMode=GameMode.MJ4P,
timeout:int=3, retries:int=3):

self.name = "MortalEngineAkagiOt"
self.is_oracle = False
self.version = 4
self.enable_quick_eval = False
self.enable_rule_based_agari_guard = False

self.api_key = api_key
self.server = server
self.mode = mode
self.timeout = timeout
self.retries = retries

if self.mode == GameMode.MJ4P:
self.api_path = r"/react_batch"
elif self.mode == GameMode.MJ3P:
self.api_path = r"/react_batch_3p"
else:
raise BotNotSupportingMode(self.mode)


def react_batch(self, obs, masks, _invisible_obs):
""" react_batch for mjai.Bot to call"""
list_obs = [o.tolist() for o in obs]
list_masks = [m.tolist() for m in masks]
post_data = {
'obs': list_obs,
'masks': list_masks,
}
data = json.dumps(post_data, separators=(',', ':'))
compressed_data = gzip.compress(data.encode('utf-8'))
headers = {
'Authorization': self.api_key,
'Content-Encoding': 'gzip',
}

# retry multiple times to post and get response
for attempt in range(self.retries):
try:
r = requests.post(f'{self.server}{self.api_path}',
headers=headers,
data=compressed_data,
timeout=self.timeout)
except requests.exceptions.Timeout:
LOGGER.warning("AkagiOT api timeout, retry %d/%d", attempt+1, self.retries)
r = None
continue

if r is None:
raise RuntimeError("AkagiOT API all retries failed.")

if r.status_code != 200:
r.raise_for_status()
r_json = r.json()
return r_json['actions'], r_json['q_out'], r_json['masks'], r_json['is_greedy']

# Mortal Engine Parameters:
#
## boltzmann_temp:
# 1
# brain:
# Brain(
# (encoder): ResNet(
# (net): Sequential(
# (0): Conv1d(1012, 256, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
# (1): ResBlock(
# (res_unit): Sequential(
# (0): BatchNorm1d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
# (1): Mish(inplace=True)
# (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
# (3): BatchNorm1d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
# (4): Mish(inplace=True)
# (5): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
# )
# (ca): ChannelAttention(
# (shared_mlp): Sequential(
# (0): Linear(in_features=256, out_features=16, bias=True)
# (1): Mish(inplace=True)
# (2): Linear(in_features=16, out_features=256, bias=True)
# )
# )
# )
# (2): ResBlock(
# (res_unit): Sequential(
# (0): BatchNorm1d(256, eps=0.001, momentum=0.01, a...
# device:
# device(type='cpu')
# dqn:
# DQN(
# (net): Linear(in_features=1024, out_features=47, bias=True)
# )
# enable_amp:
# False
# enable_quick_eval:
# False
# enable_rule_based_agari_guard:
# False
# engine_type:
# 'mortal'
# is_oracle:
# False
# name:
# 'mortal'
# stochastic_latent:
# False
# top_p:
# 1
# version:
# 4
80 changes: 71 additions & 9 deletions bot/bot.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
""" Bot represents a mjai protocol bot
implement wrappers for supportting different bot types
"""
from enum import Enum
import json
from abc import ABC, abstractmethod

import common.mj_helper as mj_helper
from common.log_helper import LOGGER
from common.mj_helper import meta_to_options, MjaiType
from common.utils import GameMode, BotNotSupportingMode

class BotType(Enum):
""" Model type for bot"""
LOCAL = "Local"
MJAPI = "MJAPI"

def reaction_convert_meta(reaction:dict, is_3p:bool=False):
""" add meta_options to reaction """
if 'meta' in reaction:
meta = reaction['meta']
reaction['meta_options'] = mj_helper.meta_to_options(meta, is_3p)
reaction['meta_options'] = meta_to_options(meta, is_3p)

class Bot(ABC):
""" Bot Interface class
Expand All @@ -26,8 +23,7 @@ class Bot(ABC):
which is a 'dahai' msg, representing the subsequent dahai action after reach
"""

def __init__(self, bot_type:BotType, name:str="Bot") -> None:
self.type = bot_type
def __init__(self, name:str="Bot") -> None:
self.name = name
self._initialized:bool = False
self.seat:int = None
Expand Down Expand Up @@ -78,3 +74,69 @@ def react_batch(self, input_list:list[dict]) -> dict | None:
last_reaction = self.react(input_list[-1])
return last_reaction


class BotMjai(Bot):
""" base class for libriichi.mjai Bots"""
def __init__(self, name:str) -> None:
super().__init__(name)

self.mjai_bot = None
self.ignore_next_turn_self_reach:bool = False


@property
def info_str(self) -> str:
return f"{self.name}: [{','.join([m.value for m in self.supported_modes])}]"


def _get_engine(self, mode:GameMode):
# return MortalEngine object
raise NotImplementedError("Subclass must implement this method")


def _init_bot_impl(self, mode:GameMode=GameMode.MJ4P):
engine = self._get_engine(mode)
if not engine:
raise BotNotSupportingMode(mode)
if mode == GameMode.MJ4P:
try:
import libriichi
except:
import riichi as libriichi
self.mjai_bot = libriichi.mjai.Bot(engine, self.seat)
elif mode == GameMode.MJ3P:
import libriichi3p
self.mjai_bot = libriichi3p.mjai.Bot(engine, self.seat)
else:
raise BotNotSupportingMode(mode)


def react(self, input_msg:dict) -> dict:
if self.mjai_bot is None:
return None
if self.ignore_next_turn_self_reach: # ignore repetitive self reach. only for the very next msg
if input_msg['type'] == MjaiType.REACH and input_msg['actor'] == self.seat:
LOGGER.debug("Ignoring repetitive self reach msg, reach msg already sent to AI last turn")
return None
self.ignore_next_turn_self_reach = False

str_input = json.dumps(input_msg)

react_str = self.mjai_bot.react(str_input)
if react_str is None:
return None
reaction = json.loads(react_str)
# Special treatment for self reach output msg
# mjai only outputs dahai msg after the reach msg
if reaction['type'] == MjaiType.REACH and reaction['actor'] == self.seat: # Self reach
# get the subsequent dahai message,
# appeding it to the reach reaction msg as 'reach_dahai' key
LOGGER.debug("Send reach msg to get reach_dahai. Cannot go back to unreach!")
# TODO make a clone of mjai_bot so reach can be tested to get dahai without affecting the game

reach_msg = {'type': MjaiType.REACH, 'actor': self.seat}
reach_dahai_str = self.mjai_bot.react(json.dumps(reach_msg))
reach_dahai = json.loads(reach_dahai_str)
reaction['reach_dahai'] = reach_dahai
self.ignore_next_turn_self_reach = True # ignore very next reach msg
return reaction
108 changes: 0 additions & 108 deletions bot/bot_local.py

This file was deleted.

Loading

0 comments on commit 1d0fc88

Please sign in to comment.