Skip to content

Commit

Permalink
working on pre-eval
Browse files Browse the repository at this point in the history
  • Loading branch information
jrenaud90 committed Jul 22, 2024
1 parent 965e928 commit e8d6ce9
Show file tree
Hide file tree
Showing 16 changed files with 335 additions and 84 deletions.
5 changes: 4 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
C++ Backend:
* Changed optional args from double pointer to void pointers to allow for arbitrary objects to be passed in.
* Added description of this feature to "Documentation/Advanced CySolver.md" documentation and "Demos/Advanced CySolver Examples.ipynb" jupyter notebook.
* Allow users to specify a "Pre-Eval" function that can be passed to the differential equation. This function should take in time, y, and args and update an output pointer which can then be used by the diffeq to solve for dydt.

`cysolve_ivp`:
* Change call signature to accept new `pre_eval_func` function.
* Added more differential equations to tests.
* Added tests to check new void arg system.
* Added tests to check new void arg feature.
* Added tests to check new pre-eval function feature.

### v0.10.0 (2024-07-17)

Expand Down
6 changes: 5 additions & 1 deletion CyRK/cy/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ static const double DYNAMIC_GROWTH_RATE = 1.618;
static constexpr double SIZE_MAX_DBL = 0.99 * SIZE_MAX;


typedef void (*DiffeqFuncType)(double*, double, double*, const void*);

typedef void (*PreEvalFunc)(void*, double, double*, const void*);

typedef void (*DiffeqFuncType)(double*, double, double*, const void*, PreEvalFunc);



struct MaxNumStepsOutput
Expand Down
13 changes: 10 additions & 3 deletions CyRK/cy/cysolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ std::shared_ptr<CySolverResult> baseline_cysolve_ivp(
const bool dense_output,
const double* t_eval,
const size_t len_t_eval,
PreEvalFunc pre_eval_func,
// rk optional arguments
const double rtol,
const double atol,
Expand Down Expand Up @@ -85,7 +86,7 @@ std::shared_ptr<CySolverResult> baseline_cysolve_ivp(
solver = new RK23(
// Common Inputs
diffeq_ptr, solution_ptr, t_start, t_end, y0_ptr, num_y, num_extra, args_ptr, max_num_steps, max_ram_MB,
dense_output, t_eval, len_t_eval,
dense_output, t_eval, len_t_eval, pre_eval_func,
// RK Inputs
rtol, atol, rtols_ptr, atols_ptr, max_step_size, first_step_size
);
Expand All @@ -95,7 +96,7 @@ std::shared_ptr<CySolverResult> baseline_cysolve_ivp(
solver = new RK45(
// Common Inputs
diffeq_ptr, solution_ptr, t_start, t_end, y0_ptr, num_y, num_extra, args_ptr, max_num_steps, max_ram_MB,
dense_output, t_eval, len_t_eval,
dense_output, t_eval, len_t_eval, pre_eval_func,
// RK Inputs
rtol, atol, rtols_ptr, atols_ptr, max_step_size, first_step_size
);
Expand All @@ -105,7 +106,7 @@ std::shared_ptr<CySolverResult> baseline_cysolve_ivp(
solver = new DOP853(
// Common Inputs
diffeq_ptr, solution_ptr, t_start, t_end, y0_ptr, num_y, num_extra, args_ptr, max_num_steps, max_ram_MB,
dense_output, t_eval, len_t_eval,
dense_output, t_eval, len_t_eval, pre_eval_func,
// RK Inputs
rtol, atol, rtols_ptr, atols_ptr, max_step_size, first_step_size
);
Expand Down Expand Up @@ -181,6 +182,9 @@ PySolver::PySolver(
// We need to pass a fake diffeq pointer (diffeq ptr is unused in python-based solver)
DiffeqFuncType diffeq_ptr = nullptr;

// We also need to pass a fake pre-eval function
PreEvalFunc pre_eval_func = nullptr;

// Build the solver class. This must be heap allocated to take advantage of polymorphism.
switch (this->integration_method)
{
Expand All @@ -201,6 +205,7 @@ PySolver::PySolver(
dense_output,
t_eval,
len_t_eval,
pre_eval_func,
// rk optional arguments
rtol,
atol,
Expand All @@ -226,6 +231,7 @@ PySolver::PySolver(
dense_output,
t_eval,
len_t_eval,
pre_eval_func,
// rk optional arguments
rtol,
atol,
Expand All @@ -251,6 +257,7 @@ PySolver::PySolver(
dense_output,
t_eval,
len_t_eval,
pre_eval_func,
// rk optional arguments
rtol,
atol,
Expand Down
1 change: 1 addition & 0 deletions CyRK/cy/cysolve.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ std::shared_ptr<CySolverResult> baseline_cysolve_ivp(
const bool dense_output = false,
const double* t_eval = nullptr,
const size_t len_t_eval = 0,
PreEvalFunc pre_eval_func = nullptr,
// rk optional arguments
const double rtol = 1.0e-3,
const double atol = 1.0e-6,
Expand Down
8 changes: 5 additions & 3 deletions CyRK/cy/cysolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ CySolverBase::CySolverBase(
const size_t max_ram_MB,
const bool use_dense_output,
const double* t_eval,
const size_t len_t_eval) :
const size_t len_t_eval,
PreEvalFunc pre_eval_func) :
status(0),
num_y(num_y),
num_extra(num_extra),
Expand All @@ -48,7 +49,8 @@ CySolverBase::CySolverBase(
diffeq_ptr(diffeq_ptr),
args_ptr(args_ptr),
use_dense_output(use_dense_output),
len_t_eval(len_t_eval)
len_t_eval(len_t_eval),
pre_eval_func(pre_eval_func)
{
// Parse inputs
this->capture_extra = num_extra > 0;
Expand Down Expand Up @@ -194,7 +196,7 @@ bool CySolverBase::check_status() const
void CySolverBase::cy_diffeq() noexcept
{
// Call c function
this->diffeq_ptr(this->dy_now_ptr, this->t_now_ptr[0], this->y_now_ptr, this->args_ptr);
this->diffeq_ptr(this->dy_now_ptr, this->t_now_ptr[0], this->y_now_ptr, this->args_ptr, this->pre_eval_func);
}

void CySolverBase::reset()
Expand Down
6 changes: 5 additions & 1 deletion CyRK/cy/cysolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class CySolverBase {
const size_t max_ram_MB = 2000,
const bool use_dense_output = false,
const double* t_eval = nullptr,
const size_t len_t_eval = 0
const size_t len_t_eval = 0,
PreEvalFunc pre_eval_func = nullptr
);

void change_storage(std::shared_ptr<CySolverResult> new_storage_ptr, bool auto_reset = true);
Expand Down Expand Up @@ -114,6 +115,9 @@ class CySolverBase {
// Information on capturing extra information during integration.
int num_extra = 0;

// Function to send to diffeq which is called before dy is calculated
PreEvalFunc pre_eval_func = nullptr;

// Keep bools together to reduce size
bool direction_flag = false;
bool reset_called = false;
Expand Down
13 changes: 11 additions & 2 deletions CyRK/cy/cysolverNew.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ cdef extern from "common.cpp" nogil:
const unsigned int DY_LIMIT
const double MAX_STEP

ctypedef void (*DiffeqFuncType)(double*, double, double*, const void*)
ctypedef void (*PreEvalFunc)(void*, double, double*, const void*)
ctypedef void (*DiffeqFuncType)(double*, double, double*, const void*, PreEvalFunc)

cdef size_t find_expected_size(
int num_y,
Expand Down Expand Up @@ -121,7 +122,8 @@ cdef extern from "cysolver.cpp" nogil:
const size_t max_ram_MB,
const cpp_bool use_dense_output,
const double* t_eval,
const size_t len_t_eval
const size_t len_t_eval,
PreEvalFunc pre_eval_func
)

shared_ptr[CySolverResult] storage_ptr
Expand Down Expand Up @@ -190,6 +192,7 @@ cdef extern from "rk.cpp" nogil:
const cpp_bool use_dense_output,
const double* t_eval,
const size_t len_t_eval,
PreEvalFunc pre_eval_func,
const double rtol,
const double atol,
const double* rtols_ptr,
Expand Down Expand Up @@ -220,6 +223,7 @@ cdef extern from "rk.cpp" nogil:
const cpp_bool use_dense_output,
const double* t_eval,
const size_t len_t_eval,
PreEvalFunc pre_eval_func,
const double rtol,
const double atol,
const double* rtols_ptr,
Expand All @@ -246,6 +250,7 @@ cdef extern from "rk.cpp" nogil:
const cpp_bool use_dense_output,
const double* t_eval,
const size_t len_t_eval,
PreEvalFunc pre_eval_func,
const double rtol,
const double atol,
const double* rtols_ptr,
Expand All @@ -272,6 +277,7 @@ cdef extern from "rk.cpp" nogil:
const cpp_bool use_dense_output,
const double* t_eval,
const size_t len_t_eval,
PreEvalFunc pre_eval_func,
const double rtol,
const double atol,
const double* rtols_ptr,
Expand Down Expand Up @@ -301,6 +307,7 @@ cdef extern from "cysolve.cpp" nogil:
const cpp_bool dense_output,
const double* t_eval,
const size_t len_t_eval,
PreEvalFunc pre_eval_func,
const double rtol,
const double atol,
const double* rtols_ptr,
Expand Down Expand Up @@ -361,6 +368,7 @@ cdef CySolveOutput cysolve_ivp(
bint dense_output = *,
double* t_eval = *,
size_t len_t_eval = *,
PreEvalFunc pre_eval_func = *,
double* rtols_ptr = *,
double* atols_ptr = *,
double max_step = *,
Expand All @@ -383,6 +391,7 @@ cdef CySolveOutput cysolve_ivp_gil(
bint dense_output = *,
double* t_eval = *,
size_t len_t_eval = *,
PreEvalFunc pre_eval_func = *,
double* rtols_ptr = *,
double* atols_ptr = *,
double max_step = *,
Expand Down
4 changes: 4 additions & 0 deletions CyRK/cy/cysolverNew.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ cdef CySolveOutput cysolve_ivp(
bint dense_output = False,
double* t_eval = NULL,
size_t len_t_eval = 0,
PreEvalFunc pre_eval_func = NULL,
double* rtols_ptr = NULL,
double* atols_ptr = NULL,
double max_step = MAX_STEP,
Expand All @@ -124,6 +125,7 @@ cdef CySolveOutput cysolve_ivp(
dense_output,
t_eval,
len_t_eval,
pre_eval_func,
rtol,
atol,
rtols_ptr,
Expand All @@ -149,6 +151,7 @@ cdef CySolveOutput cysolve_ivp_gil(
bint dense_output = False,
double* t_eval = NULL,
size_t len_t_eval = 0,
PreEvalFunc pre_eval_func = NULL,
double* rtols_ptr = NULL,
double* atols_ptr = NULL,
double max_step = MAX_STEP,
Expand All @@ -170,6 +173,7 @@ cdef CySolveOutput cysolve_ivp_gil(
dense_output,
t_eval,
len_t_eval,
pre_eval_func,
rtol,
atol,
rtols_ptr,
Expand Down
18 changes: 11 additions & 7 deletions CyRK/cy/cysolverNew_test.pxd
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from libcpp cimport bool as cpp_bool

cdef void baseline_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr) noexcept nogil
cdef void accuracy_test_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr) noexcept nogil
cdef void extraoutput_test_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr) noexcept nogil
cdef void lorenz_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr) noexcept nogil
cdef void lorenz_extraoutput_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr) noexcept nogil
cdef void lotkavolterra_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr) noexcept nogil
cdef void pendulum_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr) noexcept nogil
from libc.math cimport sin, cos, fabs, fmin, fmax

from CyRK.cy.cysolverNew cimport cysolve_ivp, WrapCySolverResult, DiffeqFuncType,MAX_STEP, CySolveOutput, PreEvalFunc

cdef void baseline_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr, PreEvalFunc pre_eval_func) noexcept nogil
cdef void accuracy_test_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr, PreEvalFunc pre_eval_func) noexcept nogil
cdef void extraoutput_test_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr, PreEvalFunc pre_eval_func) noexcept nogil
cdef void lorenz_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr, PreEvalFunc pre_eval_func) noexcept nogil
cdef void lorenz_extraoutput_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr, PreEvalFunc pre_eval_func) noexcept nogil
cdef void lotkavolterra_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr, PreEvalFunc pre_eval_func) noexcept nogil
cdef void pendulum_diffeq(double* dy_ptr, double t, double* y_ptr, const void* args_ptr, PreEvalFunc pre_eval_func) noexcept nogil
Loading

0 comments on commit e8d6ce9

Please sign in to comment.