From 12b7139d47d5c5f790a95b7f90061733605f9535 Mon Sep 17 00:00:00 2001 From: Michael Aichmueller Date: Fri, 21 Apr 2023 13:01:44 +0200 Subject: [PATCH] further test additions, added enable_shared_from_this to istatetree --- open_spiel/algorithms/infostate_tree.cc | 3 +- open_spiel/algorithms/infostate_tree.h | 8 +- .../pybind11/algorithms_infostate_tree.cc | 64 +++-- .../python/tests/infostate_tree_test.py | 219 +++++++++++++++++- 4 files changed, 267 insertions(+), 27 deletions(-) diff --git a/open_spiel/algorithms/infostate_tree.cc b/open_spiel/algorithms/infostate_tree.cc index b01b3ebb9a..df0e0c6435 100644 --- a/open_spiel/algorithms/infostate_tree.cc +++ b/open_spiel/algorithms/infostate_tree.cc @@ -485,7 +485,8 @@ absl::optional InfostateTree::DecisionIdForSequence( } } absl::optional InfostateTree::DecisionForSequence( - const SequenceId& sequence_id) { + const SequenceId& sequence_id) const +{ SPIEL_DCHECK_TRUE(sequence_id.BelongsToTree(this)); InfostateNode* node = sequences_.at(sequence_id.id()); SPIEL_DCHECK_TRUE(node); diff --git a/open_spiel/algorithms/infostate_tree.h b/open_spiel/algorithms/infostate_tree.h index efed0c9bd9..f4f8a58931 100644 --- a/open_spiel/algorithms/infostate_tree.h +++ b/open_spiel/algorithms/infostate_tree.h @@ -288,7 +288,7 @@ std::shared_ptr MakeInfostateTree( const std::vector& start_nodes, int max_move_ahead_limit = 1000); -class InfostateTree final { +class InfostateTree final : public std::enable_shared_from_this { // Note that only MakeInfostateTree is allowed to call the constructor // to ensure the trees are always allocated on heap. We do this so that all // the collected pointers are valid throughout the tree's lifetime even if @@ -308,6 +308,10 @@ class InfostateTree final { const std::vector&, int); public: + // -- gain shared ownership of the allocated infostate object + std::shared_ptr< InfostateTree > shared_ptr() { return shared_from_this(); } + std::shared_ptr< const InfostateTree > shared_ptr() const { return shared_from_this(); } + // -- Root accessors --------------------------------------------------------- const InfostateNode& root() const { return *root_; } InfostateNode* mutable_root() { return root_.get(); } @@ -347,7 +351,7 @@ class InfostateTree final { // Returns `None` if the sequence is the empty sequence. absl::optional DecisionIdForSequence(const SequenceId&) const; // Returns `None` if the sequence is the empty sequence. - absl::optional DecisionForSequence(const SequenceId&); + absl::optional DecisionForSequence(const SequenceId& sequence_id) const; // Returns whether the sequence ends with the last action the player can make. bool IsLeafSequence(const SequenceId&) const; diff --git a/open_spiel/python/pybind11/algorithms_infostate_tree.cc b/open_spiel/python/pybind11/algorithms_infostate_tree.cc index 6adb411fc1..75791a8e65 100644 --- a/open_spiel/python/pybind11/algorithms_infostate_tree.cc +++ b/open_spiel/python/pybind11/algorithms_infostate_tree.cc @@ -21,6 +21,7 @@ namespace py = ::pybind11; + namespace open_spiel { using namespace algorithms; @@ -31,7 +32,7 @@ using const_node_uniq_ptr = MockUniquePtr< const InfostateNode >; void init_pyspiel_infostate_node(::pybind11::module &m) { py::class_< InfostateNode, infostatenode_holder_ptr >(m, "InfostateNode", py::is_final()) - .def("tree", &InfostateNode::tree, py::return_value_policy::reference_internal) + .def("tree", [](const InfostateNode &node) { return node.tree().shared_ptr(); }) .def( "parent", [](const InfostateNode &node) { return infostatenode_holder_ptr{node.parent()}; } ) @@ -77,6 +78,7 @@ void init_pyspiel_infostate_node(::pybind11::module &m) }, py::arg("index") ) + .def("make_certificate", &InfostateNode::MakeCertificate) .def( "__copy__", [](const InfostateNode &node) { @@ -89,14 +91,22 @@ void init_pyspiel_infostate_node(::pybind11::module &m) ); } ) - .def("__deepcopy__", [](const InfostateNode &node) { - throw ForbiddenException( - "InfostateNode cannot be copied, because its " - "lifetime is managed by the owning " - "InfostateTree. Store a variable naming the " - "associated tree to ensure the node's " - "lifetime." - ); + .def( + "__deepcopy__", + [](const InfostateNode &node) { + throw ForbiddenException( + "InfostateNode cannot be copied, because its " + "lifetime is managed by the owning " + "InfostateTree. Store a variable naming the " + "associated tree to ensure the node's " + "lifetime." + ); + } + ) + .def("address_str", [](const InfostateNode &node) { + std::stringstream ss; + ss << &node; + return ss.str(); }); py::enum_< InfostateNodeType >(m, "InfostateNodeType") @@ -162,7 +172,7 @@ void init_pyspiel_infostate_tree(::pybind11::module &m) m, "InfostateNodeVector2D" ); - py::class_< InfostateTree, std::shared_ptr< InfostateTree > >(m, "InfostateTree") + py::class_< InfostateTree, std::shared_ptr< InfostateTree > >(m, "InfostateTree", py::is_final()) .def( py::init([](const Game &game, Player acting_player, int max_move_limit) { return MakeInfostateTree(game, acting_player, max_move_limit); @@ -240,12 +250,6 @@ void init_pyspiel_infostate_tree(::pybind11::module &m) .def("is_leaf_sequence", &InfostateTree::IsLeafSequence) .def( "decision_infostate", - [](InfostateTree &tree, const DecisionId &id) { - return infostatenode_holder_ptr{tree.decision_infostate(id)}; - } - ) - .def( - "decision_infostate_view", [](const InfostateTree &tree, const DecisionId &id) { return const_node_uniq_ptr{tree.decision_infostate(id)}; } @@ -308,10 +312,30 @@ void init_pyspiel_infostate_tree(::pybind11::module &m) ) .def("best_response", &InfostateTree::BestResponse, py::arg("gradient")) .def("best_response_value", &InfostateTree::BestResponseValue, py::arg("gradient")) - .def("__repr__", [](const InfostateTree &tree) { - std::ostringstream oss; - oss << tree; - return oss.str(); + .def( + "__repr__", + [](const InfostateTree &tree) { + std::ostringstream oss; + oss << tree; + return oss.str(); + } + ) + .def( + "__copy__", + [](const InfostateTree &) { + throw ForbiddenException( + "InfostateTree cannot be copied, because its " + "internal structure is entangled during construction. " + "Create a new tree instead." + ); + } + ) + .def("__deepcopy__", [](const InfostateTree &) { + throw ForbiddenException( + "InfostateTree cannot be copied, because its " + "internal structure is entangled during construction. " + "Create a new tree instead." + ); }); } diff --git a/open_spiel/python/tests/infostate_tree_test.py b/open_spiel/python/tests/infostate_tree_test.py index f30e6345c8..1a1ec0a033 100644 --- a/open_spiel/python/tests/infostate_tree_test.py +++ b/open_spiel/python/tests/infostate_tree_test.py @@ -14,12 +14,223 @@ """Test Python bindings for infostate tree and related classes.""" -from absl.testing import absltest +from absl.testing import absltest, parameterized import pyspiel +import gc +from copy import copy, deepcopy +import weakref -class InfostateTreeTest(absltest.TestCase): +class InfostateTreeTest(parameterized.TestCase): + def test_tree_binding(self): + game = pyspiel.load_game("kuhn_poker") + tree = pyspiel.InfostateTree(game, 0) + self.assertEqual(tree.num_sequences(), 13) - def test_binding(self): - return False + # disallowing copying is enforced + with self.assertRaises(pyspiel.ForbiddenError) as context: + copy(tree) + deepcopy(tree) + + def test_node_tree_lifetime_management(self): + root0 = tree.root() + # let's maintain a weak ref to the tree to see when the tree object is deallocated + wptr = weakref.ref(tree) + # ensure we can get a shared_ptr from root that keeps tree alive if we lose the 'tree' name + wptr_node = weakref.ref(root0) + tree_sptr = root0.tree() + # grab the tree id + id_tree0 = id(tree) + # now delete the initial tree ptr + del tree + # ensure that we still hold the object + gc.collect() # force garbage collection + self.assertIsNotNone(wptr()) + self.assertEqual(id(tree_sptr), id_tree0) + # now delete the last pointer as well + del tree_sptr + gc.collect() # force garbage collection + self.assertIsNone(wptr()) + + @parameterized.parameters( + [ + # test for matrix mp + dict( + game=pyspiel.load_game("matrix_mp"), + players=[0, 1], + expected_certificate="([" "({}{})" "({}{})" "])", + ), + # test for imperfect info goofspiel + dict( + game=pyspiel.load_game( + "goofspiel", + {"num_cards": 2, "imp_info": True, "points_order": "ascending"}, + ), + players=[0, 1], + expected_certificate="([" "({}{})" "({}{})" "])", + ), + # test for kuhn poker (0 player only) + dict( + game=pyspiel.load_game("kuhn_poker"), + players=[0], + expected_certificate=( + "((" # Root node, 1st is getting a card + "(" # 2nd is getting card + "[" # 1st acts + "((" # 1st bet, and 2nd acts + "(({}))" + "(({}))" + "(({}))" + "(({}))" + "))" + "((" # 1st checks, and 2nd acts + # 2nd checked + "(({}))" + "(({}))" + # 2nd betted + "[({}" + "{})" + "({}" + "{})]" + "))" + "]" + ")" + # Just 2 more copies. + "([(((({}))(({}))(({}))(({}))))(((({}))(({}))[({}{})({}{})]))])" + "([(((({}))(({}))(({}))(({}))))(((({}))(({}))[({}{})({}{})]))])" + "))" + ), + ), + ] + ) + def test_root_certificates(self, game, players, expected_certificate): + for i in players: + tree = pyspiel.InfostateTree(game, i) + self.assertEqual(tree.root().make_certificate(), expected_certificate) + + def check_tree_leaves(self, tree, move_limit): + for leaf_node in tree.leaf_nodes(): + self.assertTrue(leaf_node.is_leaf_node()) + self.assertTrue(leaf_node.has_infostate_string()) + self.assertNotEmpty(leaf_node.corresponding_states()) + + num_states = len(leaf_node.corresponding_states()) + terminal_cnt = 0 + max_move_number = float("-inf") + min_move_number = float("inf") + for state in leaf_node.corresponding_states(): + if state.is_terminal(): + terminal_cnt += 1 + max_move_number = max(max_move_number, state.move_number()) + min_move_number = min(min_move_number, state.move_number()) + self.assertTrue(terminal_cnt == 0 or terminal_cnt == num_states) + self.assertTrue(max_move_number == min_move_number) + if terminal_cnt == 0: + self.assertEqual(max_move_number, move_limit) + else: + self.assertLessEqual(max_move_number, move_limit) + + def check_continuation(self, tree): + leaves = tree.nodes_at_depth(tree.tree_height()) + continuation = pyspiel.InfostateTree(leaves) + self.assertEqual(continuation.root_branching_factor(), len(leaves)) + for i in range(len(leaves)): + leaf_node = leaves[i] + root_node = continuation.root().child_at(i) + self.assertTrue(leaf_node.is_leaf_node()) + if leaf_node.type() != pyspiel.InfostateNodeType.terminal: + self.assertEqual(leaf_node.type(), root_node.type()) + self.assertEqual( + leaf_node.has_infostate_string(), root_node.has_infostate_string() + ) + if leaf_node.has_infostate_string(): + self.assertEqual( + leaf_node.infostate_string(), root_node.infostate_string() + ) + else: + terminal_continuation = continuation.root().child_at(i) + while ( + terminal_continuation.type() + == pyspiel.InfostateNodeType.observation + ): + self.assertFalse(terminal_continuation.is_leaf_node()) + self.assertEqual(terminal_continuation.num_children(), 1) + terminal_continuation = terminal_continuation.child_at(0) + self.assertEqual( + terminal_continuation.type(), pyspiel.InfostateNodeType.terminal + ) + self.assertEqual( + leaf_node.has_infostate_string(), + terminal_continuation.has_infostate_string(), + ) + if leaf_node.has_infostate_string(): + self.assertEqual( + leaf_node.infostate_string(), + terminal_continuation.infostate_string(), + ) + self.assertEqual( + leaf_node.terminal_utility(), + terminal_continuation.terminal_utility(), + ) + self.assertEqual( + leaf_node.terminal_chance_reach_prob(), + terminal_continuation.terminal_chance_reach_prob(), + ) + self.assertEqual( + leaf_node.terminal_history(), + terminal_continuation.terminal_history(), + ) + + def test_depth_limited_tree_kuhn_poker(self): + # Test MakeTree for Kuhn Poker with depth limit 2 + expected_certificate = ( + "(" # + "(" # 1st is getting a card + "(" # 2nd is getting card + "[" # 1st acts - Node J + # Depth cutoff. + "]" + ")" + # Repeat the same for the two other cards. + "([])" # Node Q + "([])" # Node K + ")" + ")" # + ) + tree = pyspiel.InfostateTree(pyspiel.load_game("kuhn_poker"), 0, 2) + self.assertEqual(tree.root().make_certificate(), expected_certificate) + + # Test leaf nodes in Kuhn Poker tree + for acting in tree.leaf_nodes(): + self.assertTrue(acting.is_leaf_node()) + self.assertEqual(acting.type(), pyspiel.InfostateNodeType.decision) + self.assertEqual(len(acting.corresponding_states()), 2) + self.assertTrue(acting.has_infostate_string()) + + @parameterized.parameters( + [ + "kuhn_poker", + "kuhn_poker(players=3)", + "leduc_poker", + "goofspiel(players=2,num_cards=3,imp_info=True)", + "goofspiel(players=3,num_cards=3,imp_info=True)", + ] + ) + def test_depth_limited_trees_all_depths(self, game_name): + game = pyspiel.load_game(game_name) + max_moves = game.max_move_number() + for move_limit in range(max_moves): + for pl in range(game.num_players()): + tree = pyspiel.InfostateTree(game, pl, move_limit) + self.check_tree_leaves(tree, move_limit) + self.check_continuation(tree) + + def test_node_binding(self): + with self.assertRaises(TypeError) as context: + pyspiel.InfostateNode() + self.assertTrue("No constructor defined" in context.exception) + + +if __name__ == "__main__": + absltest.main()