Improve gamemode weighting #34

Open · wants to merge 10 commits into base: master
42 changes: 28 additions & 14 deletions rocket_learn/rollout_generator/redis/redis_rollout_worker.py
@@ -2,6 +2,7 @@
import itertools
import os
import time
import copy
from threading import Thread
from uuid import uuid4

@@ -43,6 +44,7 @@ class RedisRolloutWorker:
:param auto_minimize: automatically minimize the launched rocket league instance
:param local_cache_name: name of local database used for model caching. If None, caching is not used
:param gamemode_weights: dict of desired experience share per gamemode. If None, experience is split equally across modes
:param gamemode_weight_ema_alpha: smoothing factor (alpha) of the exponential moving average that tracks experience granted per gamemode
"""

def __init__(self, redis: Redis, name: str, match: Match,
@@ -51,7 +53,9 @@ def __init__(self, redis: Redis, name: str, match: Match,
send_obs=True, scoreboard=None, pretrained_agents=None,
human_agent=None, force_paging=False, auto_minimize=True,
local_cache_name=None,
gamemode_weights=None,):
gamemode_weights=None,
gamemode_weight_ema_alpha=0.02,
):
# TODO model or config+params so workers can recreate just from redis connection?
self.redis = redis
self.name = name
@@ -80,8 +84,18 @@ def __init__(self, redis: Redis, name: str, match: Match,
self.send_obs = send_obs
self.dynamic_gm = dynamic_gm
self.gamemode_weights = gamemode_weights
if self.gamemode_weights is not None:
assert sum(self.gamemode_weights.values()) == 1, "gamemode_weights must sum to 1"
if self.gamemode_weights is None:
self.gamemode_weights = {'1v1': 1/3, '2v2': 1/3, '3v3': 1/3}
assert abs(sum(self.gamemode_weights.values()) - 1) < 1e-6, "gamemode_weights must sum to 1"
self.target_weights = copy.copy(self.gamemode_weights)
# convert weights from desired share of experience into the approximate share of episode selections needed
self.current_weights = copy.copy(self.gamemode_weights)
for k in self.current_weights.keys():
b, o = k.split("v")
self.current_weights[k] /= int(b)
total = sum(self.current_weights.values()) + 1e-8
self.current_weights = {k: v / total for k, v in self.current_weights.items()}
self.mean_exp_grant = {'1v1': 1000, '2v2': 2000, '3v3': 3000}  # initial per-episode grant estimates, EMA-updated in run()
self.ema_alpha = gamemode_weight_ema_alpha
self.local_cache_name = local_cache_name

self.uuid = str(uuid4())
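
For reference, the conversion above reduces to the following standalone sketch (illustrative variable names, default equal weights assumed; not code from the PR):

```python
# Desired experience share -> approximate share of episode selections.
gamemode_weights = {'1v1': 1 / 3, '2v2': 1 / 3, '3v3': 1 / 3}

# An episode's experience scales with team size, so divide each weight
# by players-per-team before normalizing.
raw = {k: w / int(k.split("v")[0]) for k, w in gamemode_weights.items()}
total = sum(raw.values())
current_weights = {k: w / total for k, w in raw.items()}

print(current_weights)  # {'1v1': 0.5454..., '2v2': 0.2727..., '3v3': 0.1818...}
```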
@@ -210,19 +224,15 @@ def _get_past_model(self, version):
return model

def select_gamemode(self):
mode_exp = {m.decode("utf-8"): int(v) for m, v in self.redis.hgetall(EXPERIENCE_PER_MODE).items()}
if self.gamemode_weights is None:
mode = min(mode_exp, key=mode_exp.get)
else:
total = sum(mode_exp.values()) + 1e-8
mode_exp = {k: mode_exp[k] / total for k in mode_exp.keys()}
# find exp which is farthest below desired exp
diff = {k: self.gamemode_weights[k] - mode_exp[k] for k in mode_exp.keys()}
mode = max(diff, key=diff.get)

# weight each mode's target share by the inverse of its expected experience grant,
# then normalize into sampling probabilities
emp_weight = {k: self.mean_exp_grant[k] / (sum(self.mean_exp_grant.values()) + 1e-8)
for k in self.mean_exp_grant.keys()}
cor_weight = {k: self.gamemode_weights[k] / emp_weight[k] for k in self.gamemode_weights.keys()}
self.current_weights = {k: cor_weight[k] / (sum(cor_weight.values()) + 1e-8) for k in cor_weight}
mode = np.random.choice(list(self.current_weights.keys()), p=list(self.current_weights.values()))
b, o = mode.split("v")
return int(b), int(o)
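
The corrected sampling can be sanity-checked in isolation. A minimal sketch under the default targets and initial grant estimates follows; the names `target`, `emp_weight`, and `probs` are illustrative, not part of the codebase:

```python
import numpy as np

target = {'1v1': 1 / 3, '2v2': 1 / 3, '3v3': 1 / 3}       # desired experience share
mean_exp_grant = {'1v1': 1000, '2v2': 2000, '3v3': 3000}  # EMA of steps per episode

# Share of experience each mode would receive if modes were picked uniformly.
total_grant = sum(mean_exp_grant.values())
emp_weight = {k: v / total_grant for k, v in mean_exp_grant.items()}

# Picking mode k with probability p_k yields experience share ~ p_k * grant_k,
# so p_k proportional to target_k / grant_k hits the targets in expectation.
cor_weight = {k: target[k] / emp_weight[k] for k in target}
total_cor = sum(cor_weight.values())
probs = {k: v / total_cor for k, v in cor_weight.items()}

mode = np.random.choice(list(probs), p=list(probs.values()))
print(probs)  # {'1v1': 0.5454..., '2v2': 0.2727..., '3v3': 0.1818...}
```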


def run(self): # Mimics Thread
"""
begin processing in already launched match and push to redis
@@ -310,7 +320,11 @@ def run(self):  # Mimics Thread
state = rollouts[0].infos[-2]["state"]
goal_speed = np.linalg.norm(state.ball.linear_velocity) * 0.036 # kph
str_result = ('+' if result > 0 else "") + str(result)
self.total_steps_generated += len(rollouts[0].observations) * len(rollouts)
episode_exp = len(rollouts[0].observations) * len(rollouts)
self.total_steps_generated += episode_exp
if self.dynamic_gm:
# EMA of experience granted per episode of this gamemode
old_exp = self.mean_exp_grant[f"{blue}v{orange}"]
self.mean_exp_grant[f"{blue}v{orange}"] = ((episode_exp - old_exp) * self.ema_alpha) + old_exp
post_stats = f"Rollout finished after {len(rollouts[0].observations)} steps ({self.total_steps_generated} total steps), result was {str_result}"
if result != 0:
post_stats += f", goal speed: {goal_speed:.2f} kph"
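
The update above is a standard exponential moving average, new = old + alpha * (observed - old). A small worked example, assuming the default alpha of 0.02 and an illustrative 1v1 episode:

```python
alpha = 0.02
mean_exp_grant = {'1v1': 1000, '2v2': 2000, '3v3': 3000}

episode_exp = 1500  # steps * agents from a finished 1v1 episode
old_exp = mean_exp_grant['1v1']
mean_exp_grant['1v1'] = ((episode_exp - old_exp) * alpha) + old_exp
print(mean_exp_grant['1v1'])  # 1000 + 0.02 * 500 = 1010.0
```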