Skip to content

Commit b8b8155

Browse files
committed
Add Nexto & its toxic variant
1 parent 72e8c03 commit b8b8155

11 files changed

+1032
-1
lines changed

tests/atba/atba.bot.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ fun_fact = "This is a test bot"
2121
source_link = "https://github.com/RLBot/RLBot"
2222
developer = "BotMaker"
2323
language = "Python 3"
24-
# ALL POSSIBLE TAGS: 1v1, teamplay, goalie, hoops, dropshot, snow-day, spike-rush, heatseaker, memebot
24+
# ALL POSSIBLE TAGS: 1v1, teamplay, goalie, hoops, dropshot, snow-day, spike-rush, heatseeker, memebot
2525
# NOTE: Only add the goalie tag if your bot only plays as a goalie; this directly contrasts with the teamplay tag!
2626
# NOTE: Only add a tag for a special game mode if you bot properly supports it
2727
tags = []

tests/nexto/agent.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import math
2+
import os
3+
4+
import numpy as np
5+
import torch
6+
import torch.nn.functional as F
7+
from torch.distributions import Categorical
8+
9+
10+
def make_lookup_table():
11+
actions = []
12+
# Ground
13+
for throttle in (-1, 0, 1):
14+
for steer in (-1, 0, 1):
15+
for boost in (0, 1):
16+
for handbrake in (0, 1):
17+
if boost == 1 and throttle != 1:
18+
continue
19+
actions.append(
20+
[throttle or boost, steer, 0, steer, 0, 0, boost, handbrake]
21+
)
22+
# Aerial
23+
for pitch in (-1, 0, 1):
24+
for yaw in (-1, 0, 1):
25+
for roll in (-1, 0, 1):
26+
for jump in (0, 1):
27+
for boost in (0, 1):
28+
if jump == 1 and yaw != 0: # Only need roll for sideflip
29+
continue
30+
if pitch == roll == jump == 0: # Duplicate with ground
31+
continue
32+
# Enable handbrake for potential wavedashes
33+
handbrake = jump == 1 and (pitch != 0 or yaw != 0 or roll != 0)
34+
actions.append(
35+
[boost, yaw, pitch, yaw, roll, jump, boost, handbrake]
36+
)
37+
actions = np.array(actions)
38+
return actions
39+
40+
41+
class Agent:
42+
_lookup_table = make_lookup_table()
43+
state = None
44+
45+
def __init__(self):
46+
cur_dir = os.path.dirname(os.path.realpath(__file__))
47+
with open(os.path.join(cur_dir, "nexto-model.pt"), "rb") as f:
48+
self.actor = torch.jit.load(f)
49+
torch.set_num_threads(1)
50+
51+
def act(self, state, beta):
52+
state = tuple(torch.from_numpy(s).float() for s in state)
53+
54+
with torch.no_grad():
55+
out, weights = self.actor(state)
56+
self.state = state
57+
58+
out = (out,)
59+
max_shape = max(o.shape[-1] for o in out)
60+
logits = torch.stack(
61+
[
62+
(
63+
l
64+
if l.shape[-1] == max_shape
65+
else F.pad(l, pad=(0, max_shape - l.shape[-1]), value=float("-inf"))
66+
)
67+
for l in out
68+
],
69+
dim=1,
70+
)
71+
72+
# beta = 0.5
73+
if beta == 1:
74+
actions = np.argmax(logits, axis=-1)
75+
elif beta == -1:
76+
actions = np.argmin(logits, axis=-1)
77+
else:
78+
if beta == 0:
79+
logits[torch.isfinite(logits)] = 0
80+
else:
81+
logits *= math.log((beta + 1) / (1 - beta), 3)
82+
dist = Categorical(logits=logits)
83+
actions = dist.sample()
84+
85+
# print(Categorical(logits=logits).sample())
86+
parsed = self._lookup_table[actions.numpy().item()]
87+
88+
return parsed, weights

0 commit comments

Comments
 (0)