Skip to content

Commit

Permalink
Refactor TCPReadHandle and TCPWriteHandle control flow (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro authored Feb 23, 2025
1 parent 60fb823 commit 09b573f
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 60 deletions.
2 changes: 1 addition & 1 deletion rloop/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ def _ensure_fd_no_transport(self, fd):
try:
fileno = int(fileno.fileno())
except (AttributeError, TypeError, ValueError):
raise ValueError(f"Invalid file object: {fd!r}") from None
raise ValueError(f'Invalid file object: {fd!r}') from None
if self._tcp_stream_bound(fileno):
raise RuntimeError(f'File descriptor {fd!r} is used by transport')

Expand Down
63 changes: 47 additions & 16 deletions src/event_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,15 +290,9 @@ impl EventLoop {
fn handle_io_tcps(&self, event: &event::Event, handles_ready: &mut VecDeque<BoxedHandle>) {
let fd = event.token().0;
if event.is_readable() {
handles_ready.push_back(Box::new(TCPReadHandle {
fd,
closed: event.is_read_closed(),
}));
handles_ready.push_back(Box::new(TCPReadHandle { fd }));
} else if event.is_writable() {
handles_ready.push_back(Box::new(TCPWriteHandle {
fd,
closed: event.is_write_closed(),
}));
handles_ready.push_back(Box::new(TCPWriteHandle { fd }));
}
}

Expand Down Expand Up @@ -433,6 +427,7 @@ impl EventLoop {
v
});
}
transport.drop_ref(py);
}
}

Expand Down Expand Up @@ -461,6 +456,7 @@ impl EventLoop {
)
}

#[allow(clippy::missing_errors_doc)]
pub fn schedule0(&self, callback: PyObject, context: Option<PyObject>) -> Result<()> {
let handle = Python::with_gil(|py| {
Py::new(
Expand All @@ -469,7 +465,10 @@ impl EventLoop {
)
})?;
{
let mut guard = self.handles_ready.lock().unwrap();
let mut guard = self
.handles_ready
.lock()
.map_err(|_| anyhow::anyhow!("lock acquisition failed"))?;
guard.push_back(Box::new(handle));
}
self.counter_ready.fetch_add(1, atomic::Ordering::Release);
Expand All @@ -480,6 +479,7 @@ impl EventLoop {
Ok(())
}

#[allow(clippy::missing_errors_doc)]
pub fn schedule1(&self, callback: PyObject, arg: PyObject, context: Option<PyObject>) -> Result<()> {
let handle = Python::with_gil(|py| {
Py::new(
Expand All @@ -488,7 +488,10 @@ impl EventLoop {
)
})?;
{
let mut guard = self.handles_ready.lock().unwrap();
let mut guard = self
.handles_ready
.lock()
.map_err(|_| anyhow::anyhow!("lock acquisition failed"))?;
guard.push_back(Box::new(handle));
}
self.counter_ready.fetch_add(1, atomic::Ordering::Release);
Expand All @@ -499,6 +502,7 @@ impl EventLoop {
Ok(())
}

#[allow(clippy::missing_errors_doc)]
pub fn schedule(&self, callback: PyObject, args: PyObject, context: Option<PyObject>) -> Result<()> {
let handle = Python::with_gil(|py| {
Py::new(
Expand All @@ -507,7 +511,10 @@ impl EventLoop {
)
})?;
{
let mut guard = self.handles_ready.lock().unwrap();
let mut guard = self
.handles_ready
.lock()
.map_err(|_| anyhow::anyhow!("lock acquisition failed"))?;
guard.push_back(Box::new(handle));
}
self.counter_ready.fetch_add(1, atomic::Ordering::Release);
Expand All @@ -518,6 +525,7 @@ impl EventLoop {
Ok(())
}

#[allow(clippy::missing_errors_doc)]
pub fn schedule_later0(&self, delay: Duration, callback: PyObject, context: Option<PyObject>) -> Result<()> {
let when = (Instant::now().duration_since(self.epoch) + delay).as_micros();
let handle = Python::with_gil(|py| {
Expand All @@ -531,7 +539,10 @@ impl EventLoop {
when,
};
{
let mut guard = self.handles_sched.lock().unwrap();
let mut guard = self
.handles_sched
.lock()
.map_err(|_| anyhow::anyhow!("lock acquisition failed"))?;
guard.push(timer);
}
if self.idle.load(atomic::Ordering::Acquire) {
Expand All @@ -541,7 +552,14 @@ impl EventLoop {
Ok(())
}

pub fn schedule_later1(&self, delay: Duration, callback: PyObject, arg: PyObject, context: Option<PyObject>) -> Result<()> {
#[allow(clippy::missing_errors_doc)]
pub fn schedule_later1(
&self,
delay: Duration,
callback: PyObject,
arg: PyObject,
context: Option<PyObject>,
) -> Result<()> {
let when = (Instant::now().duration_since(self.epoch) + delay).as_micros();
let handle = Python::with_gil(|py| {
Py::new(
Expand All @@ -554,7 +572,10 @@ impl EventLoop {
when,
};
{
let mut guard = self.handles_sched.lock().unwrap();
let mut guard = self
.handles_sched
.lock()
.map_err(|_| anyhow::anyhow!("lock acquisition failed"))?;
guard.push(timer);
}
if self.idle.load(atomic::Ordering::Acquire) {
Expand All @@ -564,7 +585,14 @@ impl EventLoop {
Ok(())
}

pub fn schedule_later(&self, delay: Duration, callback: PyObject, args: PyObject, context: Option<PyObject>) -> Result<()> {
#[allow(clippy::missing_errors_doc)]
pub fn schedule_later(
&self,
delay: Duration,
callback: PyObject,
args: PyObject,
context: Option<PyObject>,
) -> Result<()> {
let when = (Instant::now().duration_since(self.epoch) + delay).as_micros();
let handle = Python::with_gil(|py| {
Py::new(
Expand All @@ -577,7 +605,10 @@ impl EventLoop {
when,
};
{
let mut guard = self.handles_sched.lock().unwrap();
let mut guard = self
.handles_sched
.lock()
.map_err(|_| anyhow::anyhow!("lock acquisition failed"))?;
guard.push(timer);
}
if self.idle.load(atomic::Ordering::Acquire) {
Expand Down
93 changes: 50 additions & 43 deletions src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,16 @@ impl TCPTransport {
fn close_from_read_handle(&self, py: Python, event_loop: &EventLoop) -> bool {
if self
.closing
.compare_exchange(false, true, atomic::Ordering::Release, atomic::Ordering::Relaxed)
.compare_exchange(false, true, atomic::Ordering::Relaxed, atomic::Ordering::Relaxed)
.is_err()
{
return false;
}

if !self.state.borrow_mut().write_buf.is_empty() {
return false;
}

event_loop.tcp_stream_rem(self.fd, Interest::WRITABLE);
_ = self.protom_conn_lost.call1(py, (py.None(),));
true
Expand Down Expand Up @@ -551,60 +555,64 @@ impl TCPTransport {

pub(crate) struct TCPReadHandle {
pub fd: usize,
pub closed: bool,
}

impl TCPReadHandle {
#[inline]
fn recv_direct(&self, py: Python, transport: &TCPTransport, buf: &mut [u8]) -> Option<(PyObject, bool)> {
match self.read_into(&mut transport.state.borrow_mut().stream, buf) {
0 => None,
read => {
let rbuf = &buf[..read];
let pydata = unsafe { PyBytes::from_ptr(py, rbuf.as_ptr(), read) };
Some((pydata.into_any().unbind(), read == buf.len()))
}
fn recv_direct(&self, py: Python, transport: &TCPTransport, buf: &mut [u8]) -> (Option<PyObject>, bool) {
let (read, closed) = self.read_into(&mut transport.state.borrow_mut().stream, buf);
if read > 0 {
let rbuf = &buf[..read];
let pydata = unsafe { PyBytes::from_ptr(py, rbuf.as_ptr(), read) };
return (Some(pydata.into_any().unbind()), closed);
}
(None, closed)
}

#[inline]
fn recv_buffered(&self, py: Python, transport: &TCPTransport) -> Option<(PyObject, bool)> {
fn recv_buffered(&self, py: Python, transport: &TCPTransport) -> (Option<PyObject>, bool) {
// NOTE: `PuBuffer.as_mut_slice` exists, but it returns a slice of `Cell<u8>`,
// which is smth we can't really use to read from `TcpStream`.
// So even if this sucks, we copy data back and forth, at least until
// we figure out a way to actually use `PyBuffer` directly.
let pybuf: PyBuffer<u8> = PyBuffer::get(&transport.protom_buf_get.bind(py).call1((-1,)).unwrap()).unwrap();
let mut vbuf = pybuf.to_vec(py).unwrap();
match self.read_into(&mut transport.state.borrow_mut().stream, vbuf.as_mut_slice()) {
0 => None,
read => {
_ = pybuf.copy_from_slice(py, &vbuf[..]);
Some((read.into_py_any(py).unwrap(), read == vbuf.len()))
}
let (read, closed) = self.read_into(&mut transport.state.borrow_mut().stream, vbuf.as_mut_slice());
if read > 0 {
_ = pybuf.copy_from_slice(py, &vbuf[..]);
return (Some(read.into_py_any(py).unwrap()), closed);
}
(None, closed)
}

#[inline(always)]
fn read_into(&self, stream: &mut TcpStream, buf: &mut [u8]) -> usize {
fn read_into(&self, stream: &mut TcpStream, buf: &mut [u8]) -> (usize, bool) {
let mut len = 0;
let mut closed = false;

loop {
match stream.read(&mut buf[len..]) {
Ok(readn) if readn != 0 => len += readn,
Ok(0) => {
if len < buf.len() {
closed = true;
}
break;
}
Ok(readn) => len += readn,
Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
_ => break,
}
}
len

(len, closed)
}

#[inline]
fn recv_eof(&self, py: Python, event_loop: &EventLoop, transport: &TCPTransport) -> bool {
event_loop.tcp_stream_rem(self.fd, Interest::READABLE);
if let Ok(pyr) = transport.proto.call_method0(py, pyo3::intern!(py, "eof_received")) {
if let Ok(false) = pyr.is_truthy(py) {
if !transport.state.borrow().write_buf.is_empty() {
return false;
}
if let Ok(true) = pyr.is_truthy(py) {
return false;
}
}
transport.close_from_read_handle(py, event_loop)
Expand All @@ -620,17 +628,19 @@ impl Handle for TCPReadHandle {
// otherwise we won't get another readable event from the poller
let mut close = false;
loop {
if let Some((data, more)) = match transport.proto_buffered {
let (data, eof) = match transport.proto_buffered {
true => self.recv_buffered(py, &transport),
false => self.recv_direct(py, &transport, &mut state.read_buf),
} {
};

if let Some(data) = data {
_ = transport.protom_recv_data.call1(py, (data,));
if more {
if !eof {
continue;
}
}

if self.closed {
if eof {
close = self.recv_eof(py, event_loop, &transport);
}

Expand All @@ -645,33 +655,37 @@ impl Handle for TCPReadHandle {

pub(crate) struct TCPWriteHandle {
pub fd: usize,
pub closed: bool,
}

impl TCPWriteHandle {
#[inline]
fn write(&self, transport: &TCPTransport) -> Option<usize> {
#[allow(clippy::cast_possible_wrap)]
let fd = self.fd as i32;
let mut ret = 0;
let mut state = transport.state.borrow_mut();
while let Some(data) = state.write_buf.pop_front() {
match syscall!(write(fd, data.as_ptr().cast(), data.len())) {
Ok(written) if written as usize != data.len() => {
Ok(written) if (written as usize) < data.len() => {
let written = written as usize;
state.write_buf.push_front((&data[written..]).into());
ret += written;
break;
}
Ok(written) => ret += written as usize,
Err(err) if err.kind() != std::io::ErrorKind::Interrupted => {
state.write_buf.clear();
state.write_buf_dsize = 0;
return None;
Err(err) if err.kind() == std::io::ErrorKind::Interrupted => {
state.write_buf.push_front(data);
continue;
}
_ => {
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
state.write_buf.push_front(data);
break;
}
_ => {
state.write_buf.clear();
state.write_buf_dsize = 0;
return None;
}
}
}
state.write_buf_dsize -= ret;
Expand All @@ -685,14 +699,7 @@ impl Handle for TCPWriteHandle {
let transport = pytransport.borrow(py);
let stream_close;

if self.closed {
{
let mut transport_state = transport.state.borrow_mut();
transport_state.write_buf.clear();
transport_state.write_buf_dsize = 0;
}
stream_close = transport.close_from_write_handle(py, false);
} else if let Some(written) = self.write(&transport) {
if let Some(written) = self.write(&transport) {
if written > 0 {
TCPTransport::write_buf_size_decr(&pytransport, py);
}
Expand Down
3 changes: 3 additions & 0 deletions tests/tcp/test_tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,6 @@ def client(addr):
loop.run_until_complete(main())
assert proto.state == 'CLOSED'
assert state['data'] == msg


# TODO: test buffered proto
File renamed without changes.

0 comments on commit 09b573f

Please sign in to comment.