Skip to content

Commit

Permalink
fix: take RunOptions by reference
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Jun 30, 2024
1 parent 2140455 commit 27676fc
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/training/trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl Trainer {
&'s self,
inputs: impl Into<SessionInputs<'i1, 'v1, N1>>,
labels: impl Into<SessionInputs<'i2, 'v2, N2>>
) -> Result<SessionOutputs<'s>> {
) -> Result<SessionOutputs<'_, 's>> {
match inputs.into() {
SessionInputs::ValueSlice(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None),
Expand All @@ -112,11 +112,11 @@ impl Trainer {
}
}

fn step_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
fn step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
&'s self,
input_values: impl Iterator<Item = &'i1 SessionInputValue<'v1>>,
run_options: Option<Arc<RunOptions>>
) -> Result<SessionOutputs<'s>> {
run_options: Option<&'r RunOptions>
) -> Result<SessionOutputs<'r, 's>> {
let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()];

let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect();
Expand Down Expand Up @@ -145,7 +145,7 @@ impl Trainer {
&'s self,
inputs: impl Into<SessionInputs<'i1, 'v1, N1>>,
labels: impl Into<SessionInputs<'i2, 'v2, N2>>
) -> Result<SessionOutputs<'s>> {
) -> Result<SessionOutputs<'_, 's>> {
match inputs.into() {
SessionInputs::ValueSlice(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None),
Expand All @@ -161,11 +161,11 @@ impl Trainer {
}
}

fn eval_step_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
fn eval_step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>(
&'s self,
input_values: impl Iterator<Item = &'i1 SessionInputValue<'v1>>,
run_options: Option<Arc<RunOptions>>
) -> Result<SessionOutputs<'s>> {
run_options: Option<&'r RunOptions>
) -> Result<SessionOutputs<'r, 's>> {
let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()];

let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect();
Expand Down

0 comments on commit 27676fc

Please sign in to comment.