diff --git a/include/models/ActivityDrivenModel.hpp b/include/models/ActivityDrivenModel.hpp index 64a1223..9a1f6dd 100644 --- a/include/models/ActivityDrivenModel.hpp +++ b/include/models/ActivityDrivenModel.hpp @@ -291,15 +291,17 @@ class ActivityDrivenModelAbstract : public Model double prob_contact_ji = contact_prob_list[j][idx_agent]; // Set the incoming agent weight, j-i in weight list - double & win_ji = network.get_weights( j )[idx_agent]; + double win_ji = network.get_edge_weight( j ,idx_agent); win_ji += prob_contact_ij; + network.set_edge_weight(j, idx_agent, win_ji); // Handle the reciprocity for j->i // Update incoming weight i-j - double & win_ij = network.get_weights( idx_agent )[j]; + double win_ij = network.get_weights( idx_agent )[j]; // The probability of reciprocating is win_ij += ( 1.0 - prob_contact_ji ) * reciprocity * prob_contact_ij; + network.set_edge_weight(idx_agent, j, win_ij); } } } diff --git a/include/network.hpp b/include/network.hpp index 8056b62..7aeafd0 100644 --- a/include/network.hpp +++ b/include/network.hpp @@ -121,11 +121,6 @@ class Network return std::span( neighbour_list[agent_idx].data(), neighbour_list[agent_idx].size() ); } - [[nodiscard]] std::span get_neighbours( std::size_t agent_idx ) - { - return std::span( neighbour_list[agent_idx].data(), neighbour_list[agent_idx].size() ); - } - /* Gives a view into the edge weights going out/coming in at agent_idx */ @@ -134,13 +129,8 @@ class Network return std::span( weight_list[agent_idx].data(), weight_list[agent_idx].size() ); } - [[nodiscard]] std::span get_weights( std::size_t agent_idx ) - { - return std::span( weight_list[agent_idx].data(), weight_list[agent_idx].size() ); - } - /* - Gives a view into the edge weights going out/coming in at agent_idx + Set the edge weights going out/coming in at agent_idx */ void set_weights( std::size_t agent_idx, const std::span weights ) { @@ -151,6 +141,15 @@ class Network weight_list[agent_idx].assign( weights.begin(), weights.end() ); } + /* + Sets the neighbour indices + */ + void set_edge( + std::size_t agent_idx, std::size_t index_neighbour, std::size_t agent_jdx ) + { + neighbour_list[agent_idx][index_neighbour]=agent_jdx; + } + /* Sets the neighbour indices and sets the weight to a constant value at agent_idx */ @@ -178,6 +177,21 @@ class Network weight_list[agent_idx].assign( buffer_weights.begin(), buffer_weights.end() ); } + /* + Sets the weight for agent_idx, for a neighbour index + */ + void set_edge_weight(std::size_t agent_idx, std::size_t index_neighbour, WeightT weight){ + weight_list[agent_idx][index_neighbour] = weight; + } + + /* + Gets the weight for agent_idx, for a neighbour index + */ + const WeightT get_edge_weight(std::size_t agent_idx, std::size_t index_neighbour) const + { + return weight_list[agent_idx][index_neighbour]; + } + /* Adds an edge between agent_idx_i and agent_idx_j with weight w */ diff --git a/test/test_network.cpp b/test/test_network.cpp index f1d2bfa..218060c 100644 --- a/test/test_network.cpp +++ b/test/test_network.cpp @@ -40,14 +40,17 @@ TEST_CASE( "Testing the network class" ) REQUIRE_THAT( weight, Catch::Matchers::UnorderedRangeEquals( buffer_w_get ) ); REQUIRE( network.n_edges( 3 ) == 2 ); - size_t & n = network.get_neighbours( 3 )[0]; + size_t n = network.get_neighbours( 3 )[0]; REQUIRE( n == neigh[0] ); n = 2; + // Set the neighbour + network.set_edge(3, 0, n ); REQUIRE( network.get_neighbours( 3 )[0] == 2 ); - Network::WeightT & w = network.get_weights( 3 )[1]; + Network::WeightT w = network.get_weights( 3 )[1]; REQUIRE( w == 0.55 ); w = 0.9; + network.set_edge_weight(3, 1, w); REQUIRE( network.get_weights( 3 )[1] == w ); }