Skip to content

Commit 63a9cb4

Browse files
committed
working
1 parent dc55164 commit 63a9cb4

File tree

5 files changed

+94
-51
lines changed

5 files changed

+94
-51
lines changed

arbiter-core/src/agent.rs

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,15 @@ pub struct Agent<L: LifeCycle, C: Connection> {
1515
inner: L,
1616
pub(crate) controller: Arc<Controller>,
1717
handlers: HashMap<TypeId, MessageHandlerFn<C>>,
18-
transport: Transport<C>,
18+
pub(crate) transport: Transport<C>,
19+
}
20+
21+
pub struct RunningAgent<C: Connection> {
22+
pub name: Option<String>,
23+
pub(crate) task: JoinHandle<Box<dyn RuntimeAgent<C>>>,
24+
pub(crate) controller: Arc<Controller>,
25+
pub(crate) outbound_connections: Arc<Mutex<HashMap<C::Address, C::Sender>>>,
26+
pub(crate) sender: C::Sender,
1927
}
2028

2129
pub struct Controller {
@@ -86,11 +94,6 @@ impl<L: LifeCycle, C: Connection> Agent<L, C> {
8694

8795
pub fn clear_name(&mut self) { self.name = None; }
8896

89-
// Access to inner L, requires locking.
90-
pub fn inner_mut(&mut self) -> &mut L { &mut self.inner }
91-
92-
pub fn inner(&self) -> &L { &self.inner }
93-
9497
pub fn with_handler<M>(mut self) -> Self
9598
where
9699
M: Message,
@@ -137,10 +140,8 @@ pub trait RuntimeAgent<C: Connection>: Send + Sync + Any {
137140
// Methods to signal the agent's desired state
138141
fn signal_start(&self);
139142
fn signal_stop(&self);
140-
fn current_loop_state(&self) -> State;
141143

142-
fn process(self) -> JoinHandle<Self>
143-
where Self: Sized;
144+
fn process(self) -> RunningAgent<C>;
144145

145146
#[cfg(test)]
146147
fn inner_as_any(&self) -> &dyn Any;
@@ -175,12 +176,16 @@ where C::Payload: Package<L::StartMessage> + Package<L::StopMessage>
175176
}
176177
}
177178

178-
fn current_loop_state(&self) -> State { *self.controller.state.lock().unwrap() }
179+
fn process(mut self) -> RunningAgent<C> {
180+
let name = self.name.clone();
181+
let outbound_connections = self.transport.outbound_connections.clone();
182+
let controller = self.controller.clone();
183+
let (_, sender) = self.transport.create_inbound_connection();
179184

180-
fn process(mut self) -> JoinHandle<Self> {
181-
std::thread::spawn(move || {
185+
let handle = std::thread::spawn(move || {
182186
let start_message = self.inner.on_start();
183187
self.transport.broadcast(Envelope::package(start_message));
188+
println!("Agent {}: Sent start message.", self.address());
184189

185190
loop {
186191
println!("Agent {}: Loop iteration.", self.address());
@@ -269,8 +274,9 @@ where C::Payload: Package<L::StartMessage> + Package<L::StopMessage>
269274

270275
self.inner.on_stop();
271276
println!("Agent {}: Executed on_stop. Main loop finished.", self.address());
272-
self
273-
})
277+
Box::new(self) as Box<dyn RuntimeAgent<C>>
278+
});
279+
RunningAgent { name, task: handle, controller, outbound_connections, sender }
274280
}
275281

276282
#[cfg(test)]
@@ -287,20 +293,21 @@ mod tests {
287293
fn test_agent_lifecycle() {
288294
let logger = Logger { name: "TestLogger".to_string(), message_count: 0 };
289295
let agent = Agent::<Logger, InMemory>::new(logger);
290-
assert_eq!(agent.current_loop_state(), State::Stopped);
296+
assert_eq!(*agent.controller.state.lock().unwrap(), State::Stopped);
291297

292298
// Start agent
293299
let controller = agent.controller.clone();
294-
let handle = agent.process();
300+
let running_agent = agent.process();
295301

296302
controller.signal_start();
297303
std::thread::sleep(std::time::Duration::from_millis(50)); // Give time for agent to start
298304
assert_eq!(controller.state.lock().unwrap().clone(), State::Running);
299305

300306
controller.signal_stop();
301-
let joined_agent = handle.join().unwrap();
302-
assert_eq!(joined_agent.inner().message_count, 0);
303-
assert_eq!(joined_agent.current_loop_state(), State::Stopped);
307+
let joined_agent = running_agent.task.join().unwrap();
308+
let agent = joined_agent.inner_as_any().downcast_ref::<Logger>().unwrap();
309+
assert_eq!(agent.message_count, 0);
310+
assert_eq!(*running_agent.controller.state.lock().unwrap(), State::Stopped);
304311
}
305312

306313
#[test]
@@ -310,7 +317,7 @@ mod tests {
310317
let sender = agent.transport.inbound_connection.sender.clone();
311318

312319
let controller = agent.controller.clone();
313-
let handle = agent.process();
320+
let running_agent = agent.process();
314321

315322
controller.signal_start();
316323
std::thread::sleep(std::time::Duration::from_millis(50)); // Allow agent to start and enter wait
@@ -321,9 +328,9 @@ mod tests {
321328
assert_eq!(controller.state.lock().unwrap().clone(), State::Running);
322329

323330
controller.signal_stop();
324-
let result_agent = handle.join().unwrap();
325-
326-
assert_eq!(result_agent.inner().message_count, 1);
331+
let result_agent = running_agent.task.join().unwrap();
332+
let agent = result_agent.inner_as_any().downcast_ref::<Logger>().unwrap();
333+
assert_eq!(agent.message_count, 1);
327334
}
328335

329336
#[test]
@@ -335,10 +342,10 @@ mod tests {
335342
agent_struct = agent_struct.with_handler::<TextMessage>().with_handler::<NumberMessage>();
336343
let sender = agent_struct.transport.inbound_connection.sender.clone();
337344

338-
assert_eq!(agent_struct.current_loop_state(), State::Stopped);
345+
assert_eq!(*agent_struct.controller.state.lock().unwrap(), State::Stopped);
339346

340347
let controller = agent_struct.controller.clone();
341-
let handle = agent_struct.process();
348+
let running_agent = agent_struct.process();
342349

343350
controller.signal_start();
344351
std::thread::sleep(std::time::Duration::from_millis(50));
@@ -352,7 +359,8 @@ mod tests {
352359
std::thread::sleep(std::time::Duration::from_millis(50));
353360

354361
controller.signal_stop();
355-
let result_agent = handle.join().unwrap();
356-
assert_eq!(result_agent.inner().message_count, 2);
362+
let result_agent = running_agent.task.join().unwrap();
363+
let agent = result_agent.inner_as_any().downcast_ref::<Logger>().unwrap();
364+
assert_eq!(agent.message_count, 2);
357365
}
358366
}

arbiter-core/src/connection/memory.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::sync::Arc;
22

33
use crate::{
44
agent::AgentIdentity,
5-
connection::{Connection, Receiver, Sender},
5+
connection::{Connection, GetNew, Receiver, Sender},
66
handler::{Envelope, Message},
77
};
88

@@ -13,6 +13,10 @@ pub struct InMemory {
1313
receiver: flume::Receiver<Envelope<Self>>,
1414
}
1515

16+
impl GetNew for flume::Sender<Envelope<InMemory>> {
17+
fn get_new(&self) -> Self { self.clone() }
18+
}
19+
1620
impl Sender for flume::Sender<Envelope<InMemory>> {
1721
type Connection = InMemory;
1822

arbiter-core/src/connection/mod.rs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
use std::{collections::HashMap, future::Future, hash::Hash, pin::Pin};
1+
use std::{
2+
collections::HashMap,
3+
future::Future,
4+
hash::Hash,
5+
pin::Pin,
6+
sync::{Arc, Mutex},
7+
};
28

39
use crate::handler::{Envelope, Message, Package};
410

@@ -26,8 +32,12 @@ impl Runtime for AsyncRuntime {
2632
fn wrap<T: 'static>(value: T) -> Self::Output<T> { Box::pin(async move { value }) }
2733
}
2834

35+
pub trait GetNew {
36+
fn get_new(&self) -> Self;
37+
}
38+
2939
// TODO: Need to have results for send and receive.
30-
pub trait Sender: Send + Sync + 'static {
40+
pub trait Sender: GetNew + Send + Sync + 'static {
3141
type Connection: Connection;
3242
fn send(&self, envelope: Envelope<Self::Connection>);
3343
}
@@ -56,27 +66,35 @@ pub trait Connection: Send + Sync + 'static {
5666

5767
pub struct Transport<C: Connection> {
5868
pub(crate) inbound_connection: C,
59-
pub(crate) outbound_connections: HashMap<C::Address, C::Sender>,
69+
pub(crate) outbound_connections: Arc<Mutex<HashMap<C::Address, C::Sender>>>,
6070
}
6171

6272
// TODO: These should return results
6373
impl<C: Connection> Transport<C> {
6474
pub fn new() -> Self {
65-
Self { inbound_connection: C::new(), outbound_connections: HashMap::new() }
75+
Self {
76+
inbound_connection: C::new(),
77+
outbound_connections: Arc::new(Mutex::new(HashMap::new())),
78+
}
6679
}
6780

6881
pub fn send(&mut self, envelope: Envelope<C>, address: C::Address) {
69-
self.outbound_connections.get_mut(&address).map(|sender| sender.send(envelope));
82+
self.outbound_connections.lock().unwrap().get_mut(&address).map(|sender| sender.send(envelope));
7083
}
7184

7285
pub fn broadcast(&mut self, envelope: Envelope<C>) {
73-
self.outbound_connections.values_mut().for_each(|sender| sender.send(envelope.clone()));
86+
self
87+
.outbound_connections
88+
.lock()
89+
.unwrap()
90+
.values_mut()
91+
.for_each(|sender| sender.send(envelope.clone()));
7492
}
7593

7694
pub fn receive(&mut self) -> Option<Envelope<C>> { self.inbound_connection.receiver().receive() }
7795

7896
pub fn add_outbound_connection(&mut self, address: C::Address, sender: C::Sender) {
79-
self.outbound_connections.insert(address, sender);
97+
self.outbound_connections.lock().unwrap().insert(address, sender);
8098
}
8199

82100
pub fn create_inbound_connection(&self) -> (C::Address, C::Sender) {

arbiter-core/src/connection/tcp.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@ use std::{
66
};
77

88
use crate::{
9-
connection::{Connection, Receiver, Sender},
9+
connection::{Connection, GetNew, Receiver, Sender},
1010
handler::Envelope,
1111
};
1212

1313
pub struct Tcp {
1414
pub(crate) stream: TcpStream,
1515
}
1616

17+
impl GetNew for TcpStream {
18+
fn get_new(&self) -> Self { self.try_clone().unwrap() }
19+
}
20+
1721
impl Sender for TcpStream {
1822
type Connection = Tcp;
1923

arbiter-core/src/fabric.rs

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
use std::collections::HashMap;
1+
use std::{
2+
collections::HashMap,
3+
sync::{Arc, Mutex},
4+
thread::JoinHandle,
5+
};
26

37
use crate::{
4-
agent::{Agent, LifeCycle, RuntimeAgent, State},
5-
connection::{memory::InMemory, Connection},
8+
agent::{Agent, LifeCycle, RunningAgent, RuntimeAgent, State},
9+
connection::{memory::InMemory, Connection, GetNew, Transport},
610
handler::Package,
711
};
812

@@ -11,7 +15,7 @@ use crate::{
1115
// and an atomic for their state)
1216
pub struct Fabric<C: Connection> {
1317
id: FabricId,
14-
agents: HashMap<C::Address, Box<dyn RuntimeAgent<C>>>,
18+
agents: HashMap<C::Address, RunningAgent<C>>,
1519
name_to_id: HashMap<String, C::Address>,
1620
}
1721

@@ -40,18 +44,20 @@ impl<C: Connection> Fabric<C> {
4044
}
4145
self.name_to_id.insert(name, new_agent.address());
4246
}
47+
4348
let id = new_agent.address();
44-
for agent in self.agents.values_mut() {
49+
for (agent_address, agent) in self.agents.iter_mut() {
4550
// Get the new agent's inbound connection and give it to all the other agents as an outbound
4651
// connection.
47-
let (address, sender) = new_agent.transport().create_inbound_connection();
48-
agent.add_outbound_connection(address, sender);
52+
let (address, sender) = new_agent.transport.create_inbound_connection();
53+
agent.outbound_connections.lock().unwrap().insert(address, sender);
54+
4955
// Get the other agent's inbound connection and give it to the new agent as an outbound
5056
// connection.
51-
let (address, sender) = agent.transport().create_inbound_connection();
52-
new_agent.add_outbound_connection(address, sender);
57+
new_agent.add_outbound_connection(*agent_address, agent.sender.get_new());
5358
}
54-
self.agents.insert(id, Box::new(new_agent));
59+
let new_agent = new_agent.process();
60+
self.agents.insert(id, new_agent);
5561
id
5662
}
5763

@@ -65,7 +71,7 @@ impl<C: Connection> Fabric<C> {
6571
A: LifeCycle,
6672
C::Payload: Package<A::StartMessage> + Package<A::StopMessage>,
6773
{
68-
agent.name = Some(name.into());
74+
agent.set_name(name);
6975
Ok(self.register_agent(agent))
7076
}
7177

@@ -104,7 +110,7 @@ impl<C: Connection> Fabric<C> {
104110
self.agents.get_mut(&agent_id).map_or_else(
105111
|| Err(format!("Agent with ID {agent_id} not found")),
106112
|agent| {
107-
agent.signal_start();
113+
agent.controller.signal_start();
108114
Ok(())
109115
},
110116
)
@@ -124,7 +130,7 @@ found"
124130
// TODO: This is a bit clunky. Perhaps we can combine these states.
125131
/// Get the state of an agent by ID
126132
pub fn agent_state_by_id(&self, agent_id: C::Address) -> Option<State> {
127-
self.agents.get(&agent_id).map(|agent| agent.current_loop_state())
133+
self.agents.get(&agent_id).map(|agent| *agent.controller.state.lock().unwrap())
128134
}
129135

130136
/// Get agent count
@@ -138,7 +144,7 @@ impl<C: Connection> Fabric<C> {
138144
/// Execute a single fabric step: poll transport and process messages
139145
pub fn start(&mut self) {
140146
for agent in self.agents.values_mut() {
141-
agent.signal_start();
147+
agent.controller.signal_start();
142148
}
143149
}
144150
}
@@ -230,7 +236,10 @@ mod tests {
230236
.unwrap();
231237

232238
fabric.start();
233-
let logger = fabric.agents.get(&logger_id).unwrap();
239+
std::thread::sleep(std::time::Duration::from_millis(50));
240+
let running_logger = fabric.agents.remove(&logger_id).unwrap();
241+
running_logger.controller.signal_stop();
242+
let logger = running_logger.task.join().unwrap();
234243
let logger = logger.inner_as_any().downcast_ref::<Logger>().unwrap();
235244
assert_eq!(logger.message_count, 1);
236245
}

0 commit comments

Comments
 (0)