Skip to content

Commit

Permalink
further test additions, added enable_shared_from_this to istatetree
Browse files Browse the repository at this point in the history
  • Loading branch information
maichmueller committed Apr 21, 2023
1 parent 3da130d commit 7fbb749
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 27 deletions.
3 changes: 2 additions & 1 deletion open_spiel/algorithms/infostate_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,8 @@ absl::optional<DecisionId> InfostateTree::DecisionIdForSequence(
}
}
absl::optional<InfostateNode*> InfostateTree::DecisionForSequence(
const SequenceId& sequence_id) {
const SequenceId& sequence_id) const
{
SPIEL_DCHECK_TRUE(sequence_id.BelongsToTree(this));
InfostateNode* node = sequences_.at(sequence_id.id());
SPIEL_DCHECK_TRUE(node);
Expand Down
8 changes: 6 additions & 2 deletions open_spiel/algorithms/infostate_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ std::shared_ptr<InfostateTree> MakeInfostateTree(
const std::vector<InfostateNode*>& start_nodes,
int max_move_ahead_limit = 1000);

class InfostateTree final {
class InfostateTree final : public std::enable_shared_from_this<InfostateTree> {
// Note that only MakeInfostateTree is allowed to call the constructor
// to ensure the trees are always allocated on heap. We do this so that all
// the collected pointers are valid throughout the tree's lifetime even if
Expand All @@ -308,6 +308,10 @@ class InfostateTree final {
const std::vector<const InfostateNode*>&, int);

public:
// -- gain shared ownership of the allocated infostate object
std::shared_ptr< InfostateTree > shared_ptr() { return shared_from_this(); }
std::shared_ptr< const InfostateTree > shared_ptr() const { return shared_from_this(); }

// -- Root accessors ---------------------------------------------------------
const InfostateNode& root() const { return *root_; }
InfostateNode* mutable_root() { return root_.get(); }
Expand Down Expand Up @@ -347,7 +351,7 @@ class InfostateTree final {
// Returns `None` if the sequence is the empty sequence.
absl::optional<DecisionId> DecisionIdForSequence(const SequenceId&) const;
// Returns `None` if the sequence is the empty sequence.
absl::optional<InfostateNode*> DecisionForSequence(const SequenceId&);
absl::optional<InfostateNode*> DecisionForSequence(const SequenceId& sequence_id) const;
// Returns whether the sequence ends with the last action the player can make.
bool IsLeafSequence(const SequenceId&) const;

Expand Down
64 changes: 44 additions & 20 deletions open_spiel/python/pybind11/algorithms_infostate_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

namespace py = ::pybind11;


namespace open_spiel {

using namespace algorithms;
Expand All @@ -31,7 +32,7 @@ using const_node_uniq_ptr = MockUniquePtr< const InfostateNode >;
void init_pyspiel_infostate_node(::pybind11::module &m)
{
py::class_< InfostateNode, infostatenode_holder_ptr >(m, "InfostateNode", py::is_final())
.def("tree", &InfostateNode::tree, py::return_value_policy::reference_internal)
.def("tree", [](const InfostateNode &node) { return node.tree().shared_ptr(); })
.def(
"parent", [](const InfostateNode &node) { return infostatenode_holder_ptr{node.parent()}; }
)
Expand Down Expand Up @@ -77,6 +78,7 @@ void init_pyspiel_infostate_node(::pybind11::module &m)
},
py::arg("index")
)
.def("make_certificate", &InfostateNode::MakeCertificate)
.def(
"__copy__",
[](const InfostateNode &node) {
Expand All @@ -89,14 +91,22 @@ void init_pyspiel_infostate_node(::pybind11::module &m)
);
}
)
.def("__deepcopy__", [](const InfostateNode &node) {
throw ForbiddenException(
"InfostateNode cannot be copied, because its "
"lifetime is managed by the owning "
"InfostateTree. Store a variable naming the "
"associated tree to ensure the node's "
"lifetime."
);
.def(
"__deepcopy__",
[](const InfostateNode &node) {
throw ForbiddenException(
"InfostateNode cannot be copied, because its "
"lifetime is managed by the owning "
"InfostateTree. Store a variable naming the "
"associated tree to ensure the node's "
"lifetime."
);
}
)
.def("address_str", [](const InfostateNode &node) {
std::stringstream ss;
ss << &node;
return ss.str();
});

py::enum_< InfostateNodeType >(m, "InfostateNodeType")
Expand Down Expand Up @@ -162,7 +172,7 @@ void init_pyspiel_infostate_tree(::pybind11::module &m)
m, "InfostateNodeVector2D"
);

py::class_< InfostateTree, std::shared_ptr< InfostateTree > >(m, "InfostateTree")
py::class_< InfostateTree, std::shared_ptr< InfostateTree > >(m, "InfostateTree", py::is_final())
.def(
py::init([](const Game &game, Player acting_player, int max_move_limit) {
return MakeInfostateTree(game, acting_player, max_move_limit);
Expand Down Expand Up @@ -240,12 +250,6 @@ void init_pyspiel_infostate_tree(::pybind11::module &m)
.def("is_leaf_sequence", &InfostateTree::IsLeafSequence)
.def(
"decision_infostate",
[](InfostateTree &tree, const DecisionId &id) {
return infostatenode_holder_ptr{tree.decision_infostate(id)};
}
)
.def(
"decision_infostate_view",
[](const InfostateTree &tree, const DecisionId &id) {
return const_node_uniq_ptr{tree.decision_infostate(id)};
}
Expand Down Expand Up @@ -308,10 +312,30 @@ void init_pyspiel_infostate_tree(::pybind11::module &m)
)
.def("best_response", &InfostateTree::BestResponse, py::arg("gradient"))
.def("best_response_value", &InfostateTree::BestResponseValue, py::arg("gradient"))
.def("__repr__", [](const InfostateTree &tree) {
std::ostringstream oss;
oss << tree;
return oss.str();
.def(
"__repr__",
[](const InfostateTree &tree) {
std::ostringstream oss;
oss << tree;
return oss.str();
}
)
.def(
"__copy__",
[](const InfostateTree &) {
throw ForbiddenException(
"InfostateTree cannot be copied, because its "
"internal structure is entangled during construction. "
"Create a new tree instead."
);
}
)
.def("__deepcopy__", [](const InfostateTree &) {
throw ForbiddenException(
"InfostateTree cannot be copied, because its "
"internal structure is entangled during construction. "
"Create a new tree instead."
);
});
}

Expand Down
219 changes: 215 additions & 4 deletions open_spiel/python/tests/infostate_tree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,223 @@

"""Test Python bindings for infostate tree and related classes."""

from absl.testing import absltest
from absl.testing import absltest, parameterized

import pyspiel
import gc
from copy import copy, deepcopy
import weakref


class InfostateTreeTest(absltest.TestCase):
class InfostateTreeTest(parameterized.TestCase):
def test_tree_binding(self):
game = pyspiel.load_game("kuhn_poker")
tree = pyspiel.InfostateTree(game, 0)
self.assertEqual(tree.num_sequences(), 13)

def test_binding(self):
return False
# disallowing copying is enforced
with self.assertRaises(pyspiel.ForbiddenError) as context:
copy(tree)
deepcopy(tree)

def test_node_tree_lifetime_management(self):
root0 = tree.root()
# let's maintain a weak ref to the tree to see when the tree object is deallocated
wptr = weakref.ref(tree)
# ensure we can get a shared_ptr from root that keeps tree alive if we lose the 'tree' name
wptr_node = weakref.ref(root0)
tree_sptr = root0.tree()
# grab the tree id
id_tree0 = id(tree)
# now delete the initial tree ptr
del tree
# ensure that we still hold the object
gc.collect() # force garbage collection
self.assertIsNotNone(wptr())
self.assertEqual(id(tree_sptr), id_tree0)
# now delete the last pointer as well
del tree_sptr
gc.collect() # force garbage collection
self.assertIsNone(wptr())

@parameterized.parameters(
[
# test for matrix mp
dict(
game=pyspiel.load_game("matrix_mp"),
players=[0, 1],
expected_certificate="([" "({}{})" "({}{})" "])",
),
# test for imperfect info goofspiel
dict(
game=pyspiel.load_game(
"goofspiel",
{"num_cards": 2, "imp_info": True, "points_order": "ascending"},
),
players=[0, 1],
expected_certificate="([" "({}{})" "({}{})" "])",
),
# test for kuhn poker (0 player only)
dict(
game=pyspiel.load_game("kuhn_poker"),
players=[0],
expected_certificate=(
"((" # Root node, 1st is getting a card
"(" # 2nd is getting card
"[" # 1st acts
"((" # 1st bet, and 2nd acts
"(({}))"
"(({}))"
"(({}))"
"(({}))"
"))"
"((" # 1st checks, and 2nd acts
# 2nd checked
"(({}))"
"(({}))"
# 2nd betted
"[({}"
"{})"
"({}"
"{})]"
"))"
"]"
")"
# Just 2 more copies.
"([(((({}))(({}))(({}))(({}))))(((({}))(({}))[({}{})({}{})]))])"
"([(((({}))(({}))(({}))(({}))))(((({}))(({}))[({}{})({}{})]))])"
"))"
),
),
]
)
def test_root_certificates(self, game, players, expected_certificate):
for i in players:
tree = pyspiel.InfostateTree(game, i)
self.assertEqual(tree.root().make_certificate(), expected_certificate)

def check_tree_leaves(self, tree, move_limit):
for leaf_node in tree.leaf_nodes():
self.assertTrue(leaf_node.is_leaf_node())
self.assertTrue(leaf_node.has_infostate_string())
self.assertNotEmpty(leaf_node.corresponding_states())

num_states = len(leaf_node.corresponding_states())
terminal_cnt = 0
max_move_number = float("-inf")
min_move_number = float("inf")
for state in leaf_node.corresponding_states():
if state.is_terminal():
terminal_cnt += 1
max_move_number = max(max_move_number, state.move_number())
min_move_number = min(min_move_number, state.move_number())
self.assertTrue(terminal_cnt == 0 or terminal_cnt == num_states)
self.assertTrue(max_move_number == min_move_number)
if terminal_cnt == 0:
self.assertEqual(max_move_number, move_limit)
else:
self.assertLessEqual(max_move_number, move_limit)

def check_continuation(self, tree):
leaves = tree.nodes_at_depth(tree.tree_height())
continuation = pyspiel.InfostateTree(leaves)
self.assertEqual(continuation.root_branching_factor(), len(leaves))
for i in range(len(leaves)):
leaf_node = leaves[i]
root_node = continuation.root().child_at(i)
self.assertTrue(leaf_node.is_leaf_node())
if leaf_node.type() != pyspiel.InfostateNodeType.terminal:
self.assertEqual(leaf_node.type(), root_node.type())
self.assertEqual(
leaf_node.has_infostate_string(), root_node.has_infostate_string()
)
if leaf_node.has_infostate_string():
self.assertEqual(
leaf_node.infostate_string(), root_node.infostate_string()
)
else:
terminal_continuation = continuation.root().child_at(i)
while (
terminal_continuation.type()
== pyspiel.InfostateNodeType.observation
):
self.assertFalse(terminal_continuation.is_leaf_node())
self.assertEqual(terminal_continuation.num_children(), 1)
terminal_continuation = terminal_continuation.child_at(0)
self.assertEqual(
terminal_continuation.type(), pyspiel.InfostateNodeType.terminal
)
self.assertEqual(
leaf_node.has_infostate_string(),
terminal_continuation.has_infostate_string(),
)
if leaf_node.has_infostate_string():
self.assertEqual(
leaf_node.infostate_string(),
terminal_continuation.infostate_string(),
)
self.assertEqual(
leaf_node.terminal_utility(),
terminal_continuation.terminal_utility(),
)
self.assertEqual(
leaf_node.terminal_chance_reach_prob(),
terminal_continuation.terminal_chance_reach_prob(),
)
self.assertEqual(
leaf_node.terminal_history(),
terminal_continuation.terminal_history(),
)

def test_depth_limited_tree_kuhn_poker(self):
# Test MakeTree for Kuhn Poker with depth limit 2
expected_certificate = (
"(" # <dummy>
"(" # 1st is getting a card
"(" # 2nd is getting card
"[" # 1st acts - Node J
# Depth cutoff.
"]"
")"
# Repeat the same for the two other cards.
"([])" # Node Q
"([])" # Node K
")"
")" # </dummy>
)
tree = pyspiel.InfostateTree(pyspiel.load_game("kuhn_poker"), 0, 2)
self.assertEqual(tree.root().make_certificate(), expected_certificate)

# Test leaf nodes in Kuhn Poker tree
for acting in tree.leaf_nodes():
self.assertTrue(acting.is_leaf_node())
self.assertEqual(acting.type(), pyspiel.InfostateNodeType.decision)
self.assertEqual(len(acting.corresponding_states()), 2)
self.assertTrue(acting.has_infostate_string())

@parameterized.parameters(
[
"kuhn_poker",
"kuhn_poker(players=3)",
"leduc_poker",
"goofspiel(players=2,num_cards=3,imp_info=True)",
"goofspiel(players=3,num_cards=3,imp_info=True)",
]
)
def test_depth_limited_trees_all_depths(self, game_name):
game = pyspiel.load_game(game_name)
max_moves = game.max_move_number()
for move_limit in range(max_moves):
for pl in range(game.num_players()):
tree = pyspiel.InfostateTree(game, pl, move_limit)
self.check_tree_leaves(tree, move_limit)
self.check_continuation(tree)

def test_node_binding(self):
with self.assertRaises(TypeError) as context:
pyspiel.InfostateNode()
self.assertTrue("No constructor defined" in context.exception)


if __name__ == "__main__":
absltest.main()

0 comments on commit 7fbb749

Please sign in to comment.