Skip to content

Commit

Permalink
Fix to avoid write after closes (fixes #16)
Browse files Browse the repository at this point in the history
  • Loading branch information
polytypic committed Feb 4, 2024
1 parent 95e1d73 commit 57eae80
Showing 1 changed file with 85 additions and 52 deletions.
137 changes: 85 additions & 52 deletions src/Domain_local_timeout.ml
Original file line number Diff line number Diff line change
Expand Up @@ -26,101 +26,135 @@ end

module Q = Psq.Make (Int) (Entry)

exception Running
exception Stopped

let shared_byte = Bytes.create 1

let system_on_current_domain (module Thread : Thread) (module Unix : Unix) =
let error = ref None in
let check () = match !error with None -> () | Some exn -> raise exn in
let running = ref true in
let needs_wakeup = ref true in
let reading, writing = Unix.pipe () in
let[@poll error] [@inline never] wakeup_needed_atomically () =
!needs_wakeup && !error == None
let open struct
type state = {
mutable needs_wakeup : bool;
mutable counter : int;
mutable status : exn;
reading : Unix.file_descr;
writing : Unix.file_descr;
timeouts : Q.t Atomic.t;
}
end in
let s =
let reading, writing = Unix.pipe () in
{
needs_wakeup = true;
counter = 0;
status = Running;
reading;
writing;
timeouts = Atomic.make Q.empty;
}
in
let[@poll error] [@inline never] wakeup_needed_atomically s status =
s.needs_wakeup && s.status == status
&& begin
needs_wakeup := false;
s.needs_wakeup <- false;
true
end
in
let wakeup () =
if wakeup_needed_atomically () then begin
let n = Unix.write writing (Bytes.create 1) 0 1 in
let wakeup s status =
if wakeup_needed_atomically s status then begin
let n = Unix.write s.writing shared_byte 0 1 in
assert (n = 1)
end
in
let counter = ref 0 in
let[@poll error] [@inline never] next_id_atomically () =
let id = !counter + 1 in
counter := id;
let[@poll error] [@inline never] next_id_atomically s =
let id = s.counter + 1 in
s.counter <- id;
id
in
let timeouts = Atomic.make Q.empty in
let[@poll error] [@inline never] running_atomically () =
!running
let[@poll error] [@inline never] stop_atomically s =
s.status == Running
&& begin
needs_wakeup := true;
s.status <- Stopped;
true
end
in
let rec timeout_thread next =
if running_atomically () then begin
begin
match Unix.select [ reading ] [] [] next with
let[@poll error] [@inline never] running_atomically s =
let running = s.status == Running in
s.needs_wakeup <- running;
running
in
let rec timeout_thread s ts_old next =
if running_atomically s then begin
if ts_old == Atomic.get s.timeouts then begin
match Unix.select [ s.reading ] [] [] next with
| [ reading ], _, _ ->
let n = Unix.read reading (Bytes.create 1) 0 1 in
assert (n = 1)
| _, _, _ -> ()
end;
let rec loop () =
let ts_old = Atomic.get timeouts in
s.needs_wakeup <- false;
let rec loop s =
let ts_old = Atomic.get s.timeouts in
match Q.pop ts_old with
| None -> -1.0
| None -> timeout_thread s ts_old (-1.0)
| Some ((_, t), ts) ->
let elapsed = Mtime_clock.elapsed () in
if Mtime.Span.compare t.time elapsed <= 0 then begin
if Atomic.compare_and_set timeouts ts_old ts then t.action ();
loop ()
if Atomic.compare_and_set s.timeouts ts_old ts then t.action ();
loop s
end
else
Mtime.Span.to_float_ns (Mtime.Span.abs_diff t.time elapsed)
*. (1. /. 1_000_000_000.)
let next =
Mtime.Span.to_float_ns (Mtime.Span.abs_diff t.time elapsed)
*. (1. /. 1_000_000_000.)
in
timeout_thread s ts_old next
in
timeout_thread (loop ())
loop s
end
in
let timeout_thread () =
let timeout_thread s =
begin
match timeout_thread (-1.0) with
match timeout_thread s Q.empty (-1.0) with
| () -> ()
| exception exn -> error := Some exn
| exception exn -> s.status <- exn
end;
Unix.close reading;
Unix.close writing
(* At this point [needs_wakeup = false]. *)
Atomic.set s.timeouts Q.empty
in
let tid = Thread.create timeout_thread () in
let tid = Thread.create timeout_thread s in
let stop () =
running := false;
wakeup ();
if stop_atomically s then wakeup s Stopped;
Thread.join tid;
check ()
Unix.close s.reading;
Unix.close s.writing;
match s.status with Stopped -> () | exn -> raise exn
in
let set_timeoutf seconds action =
match Mtime.Span.of_float_ns (seconds *. 1_000_000_000.) with
| None ->
invalid_arg "timeout should be between 0 to pow(2, 53) nanoseconds"
| Some span ->
check ();
let time = Mtime.Span.add (Mtime_clock.elapsed ()) span in
let e' = Entry.{ time; action } in
let id = next_id_atomically () in
let rec insert_loop () =
let ts = Atomic.get timeouts in
let id = next_id_atomically s in
let rec insert_loop s id e' =
let ts = Atomic.get s.timeouts in
let ts' = Q.add id e' ts in
if not (Atomic.compare_and_set timeouts ts ts') then insert_loop ()
else match Q.min ts' with Some (id', _) -> id = id' | None -> false
match s.status with
| Running ->
if not (Atomic.compare_and_set s.timeouts ts ts') then
insert_loop s id e'
else begin
match Q.min ts' with Some (id', _) -> id = id' | None -> false
end
| exn -> raise exn
in
if insert_loop () then wakeup ();
if insert_loop s id e' then wakeup s Running;
let rec cancel () =
let ts = Atomic.get timeouts in
let ts = Atomic.get s.timeouts in
let ts' = Q.remove id ts in
if not (Atomic.compare_and_set timeouts ts ts') then cancel ()
if not (Atomic.compare_and_set s.timeouts ts ts') then cancel ()
in
cancel
in
Expand All @@ -144,9 +178,8 @@ let try_system = ref unimplemented
let default seconds action = !try_system seconds action
let key = Domain.DLS.new_key @@ fun () -> Per_domain { set_timeoutf = default }

let[@poll error] [@inline never] update_set_timeoutf_atomically state
set_timeoutf =
match state with
let[@poll error] [@inline never] update_set_timeoutf_atomically s set_timeoutf =
match s with
| Per_domain r ->
let current = r.set_timeoutf in
if current == default then begin
Expand Down

0 comments on commit 57eae80

Please sign in to comment.