@@ -19,23 +19,41 @@ TEST_CASE( "Testing the network class" )
19
19
// Does n_agents work?
20
20
REQUIRE ( network->n_agents () == n_agents );
21
21
22
+ // Check that the function for setting neighbours and a single weight work
23
+ // Agent 3
24
+ std::vector<size_t > buffer_n_get{}; // buffer for getting neighbours
25
+ std::vector<Seldon::Network::WeightT> buffer_w_get{}; // buffer for getting the weights
26
+ std::vector<size_t > neigh{ { 0 , 10 } }; // new neighbours
27
+ std::vector<Seldon::Network::WeightT> weight{ 0.5 , 0.5 }; // new weights (const)
28
+ network->set_neighbours_and_weights ( 3 , neigh, 0.5 );
29
+ network->get_weights ( 3 , buffer_w_get );
30
+ REQUIRE_THAT ( buffer_w_get, Catch::Matchers::UnorderedRangeEquals ( buffer_w_get ) );
31
+
22
32
// Change the connections for agent 3
23
- std::vector<size_t > buffer_n{ { 0 , 10 , 15 } };
24
- std::vector<Seldon::Network::WeightT> buffer_w{ 0.1 , 0.2 , 0.3 };
33
+ std::vector<size_t > buffer_n{ { 0 , 10 , 15 } }; // new neighbours
34
+ std::vector<Seldon::Network::WeightT> buffer_w{ 0.1 , 0.2 , 0.3 }; // new weights
25
35
network->set_neighbours_and_weights ( 3 , buffer_n, buffer_w );
26
36
27
37
// Make sure the changes worked
28
- std::vector<size_t > buffer_n_get{};
29
- std::vector<Seldon::Network::WeightT> buffer_w_get{};
30
38
network->get_neighbours ( 3 , buffer_n_get );
31
39
network->get_weights ( 3 , buffer_w_get );
32
40
33
41
REQUIRE_THAT ( buffer_n_get, Catch::Matchers::UnorderedRangeEquals ( buffer_n ) );
34
42
REQUIRE_THAT ( buffer_w_get, Catch::Matchers::UnorderedRangeEquals ( buffer_w ) );
35
43
44
+ // Check that the push_back function works for agent 3
45
+ buffer_n.push_back ( 5 ); // new neighbour
46
+ buffer_w.push_back ( 1.0 ); // new weight for this new connection
47
+ network->push_back_neighbour_and_weight ( 3 , 5 , 1.0 ); // new connection added with weight
48
+ // Check that the change worked for the push_back function
49
+ network->get_neighbours ( 3 , buffer_n_get );
50
+ network->get_weights ( 3 , buffer_w_get );
51
+ REQUIRE_THAT ( buffer_n_get, Catch::Matchers::UnorderedRangeEquals ( buffer_n ) );
52
+ REQUIRE_THAT ( buffer_w_get, Catch::Matchers::UnorderedRangeEquals ( buffer_w ) );
53
+
36
54
// Now we test the transpose() function
37
55
38
- // First record all the old edges as tupels (i,j,w) where this edge goes from j -> i with weight w
56
+ // First record all the old edges as tuples (i,j,w) where this edge goes from j -> i with weight w
39
57
std::set<std::tuple<size_t , size_t , Network::WeightT>> old_edges;
40
58
for ( size_t i_agent = 0 ; i_agent < network->n_agents (); i_agent++ )
41
59
{
0 commit comments