Skip to content

Commit

Permalink
CySolver now installs; tests added.
Browse files Browse the repository at this point in the history
  • Loading branch information
jrenaud90 committed Jul 27, 2023
1 parent 92945bc commit 44799a3
Show file tree
Hide file tree
Showing 14 changed files with 1,365 additions and 329 deletions.
3 changes: 2 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
### v0.6.0

New Features
- Created the `CyRKSolver` class which is more efficient than the `cyrk_ode` function.
- Created the `CySolver` class which is more efficient than the `cyrk_ode` function.
- New functions in `CyRK.cy.cysolvertest` to help test and check performance of `CySolver`.

Bug Fixes:
- Fixed compile error with `cyrk_ode` "complex types are unordered".
Expand Down
2 changes: 1 addition & 1 deletion CyRK/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
from .helper import nb2cy, cy2nb

# Import test functions
from ._test import test_cyrk, test_nbrk
from ._test import test_cyrk, test_nbrk, test_cysolver
14 changes: 14 additions & 0 deletions CyRK/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,17 @@ def test_nbrk():
assert y_results.shape[0] == 2

print("CyRK's nbrk_ode was tested successfully.")


def test_cysolver():

from CyRK.cy.cysolvertest import CySolverTester

# TODO: Currently CySolver only works with floats not complex
CySolverTesterInst = CySolverTester(time_span, np.asarray(initial_conds, dtype=np.float64))
CySolverTesterInst.solve()

assert CySolverTesterInst.success
assert type(CySolverTesterInst.solution_t) == np.ndarray
assert type(CySolverTesterInst.solution_y) == np.ndarray
assert CySolverTesterInst.solution_y.shape[0] == 2
Empty file added CyRK/cy/__init__.pxd
Empty file.
Empty file added CyRK/cy/__init__.pyx
Empty file.
58 changes: 58 additions & 0 deletions CyRK/cy/cysolver.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
cimport numpy as np
from libcpp cimport bool as bool_cpp_t
cdef class CySolver:

# Class attributes
# -- Live variables
cdef double t_new, t_old
cdef unsigned int len_t
cdef double[:] y_new_view, y_old_view, dy_new_view, dy_old_view
cdef double[:] extra_output_view, extra_output_init_view

# -- Dependent (y0) variable information
cdef unsigned short y_size
cdef double y_size_dbl, y_size_sqrt
cdef const double[:] y0_view

# -- RK method information
cdef unsigned char rk_method
cdef unsigned char rk_order, error_order, rk_n_stages, rk_n_stages_plus1, rk_n_stages_extended
cdef double error_expo
cdef unsigned char len_C
cdef double[:] B_view, E_view, E3_view, E5_view, E_tmp_view, E3_tmp_view, E5_tmp_view, C_view
cdef double[:, :] A_view, K_view

# -- Integration information
cdef public char status
cdef public str message
cdef public bool_cpp_t success
cdef double t_start, t_end, t_delta, t_delta_abs, direction, direction_inf
cdef double rtol, atol
cdef double step_size, max_step
cdef unsigned int expected_size
cdef unsigned int num_concats

# -- Optional args info
cdef unsigned short num_args
cdef double[:] arg_array_view

# -- Extra output info
cdef bool_cpp_t capture_extra
cdef unsigned short num_extra

# -- Interpolation info
cdef bool_cpp_t run_interpolation
cdef bool_cpp_t interpolate_extra
cdef unsigned int len_t_eval
cdef double[:] t_eval_view

# -- Solution variables
cdef double[:, :] solution_y_view, solution_extra_view
cdef double[:] solution_t_view

# Class functions
cdef double calc_first_step(self)
cpdef void solve(self)
cdef void _solve(self)
cdef void interpolate(self)
cdef void diffeq(self)
75 changes: 15 additions & 60 deletions CyRK/cy/_cyrk_class.pyx → CyRK/cy/cysolver.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -28,63 +28,7 @@ cdef double EPS = np.finfo(np.float64).eps
cdef double EPS_10 = EPS * 10.
cdef double EPS_100 = EPS * 100.

cdef class CyRKSolver:

# TODO:
# - y_old and dy_old do not need to be class variables. they can be inside the solver alone.
# - Add @cython.exceptval(check=False) and @cython.initializedcheck(False) everywhere.


# def solve_complex

# def solve

# -- Live variables
cdef double t_new, t_old
cdef unsigned int len_t
cdef double[:] y_new_view, y_old_view, dy_new_view, dy_old_view
cdef double[:] extra_output_view, extra_output_init_view

# -- Dependent (y0) variable information
cdef unsigned short y_size
cdef double y_size_dbl, y_size_sqrt
cdef const double[:] y0_view

# -- RK method information
cdef unsigned char rk_method
cdef unsigned char rk_order, error_order, rk_n_stages, rk_n_stages_plus1, rk_n_stages_extended
cdef double error_expo
cdef unsigned char len_C
cdef double[:] B_view, E_view, E3_view, E5_view, E_tmp_view, E3_tmp_view, E5_tmp_view, C_view
cdef double[:, :] A_view, K_view

# -- Integration information
cdef public char status
cdef public str message
cdef public bool_cpp_t success
cdef double t_start, t_end, t_delta, t_delta_abs, direction, direction_inf
cdef double rtol, atol
cdef double step_size, max_step
cdef unsigned int expected_size
cdef unsigned int num_concats

# -- Optional args info
cdef unsigned short num_args
cdef double[:] arg_array_view

# -- Extra output info
cdef bool_cpp_t capture_extra
cdef unsigned short num_extra

# -- Interpolation info
cdef bool_cpp_t run_interpolation
cdef bool_cpp_t interpolate_extra
cdef unsigned int len_t_eval
cdef double[:] t_eval_view

# -- Solution variables
cdef double[:, :] solution_y_view, solution_extra_view
cdef double[:] solution_t_view
cdef class CySolver:

def __init__(self,
(double, double) t_span,
Expand All @@ -108,6 +52,16 @@ cdef class CyRKSolver:
self.status = -3 # Status code to indicate that integration has not started.
self.message = 'Integration has not started.'
self.success = False

# Declare public variables to avoid memory access violations if solve() is not called.
cdef np.ndarray[np.float64_t, ndim=2, mode='c'] solution_extra_fake, solution_y_fake
cdef np.ndarray[np.float64_t, ndim=1, mode='c'] solution_t_fake
solution_extra_fake = np.nan * np.ones((1, 1), dtype=np.float64, order='C')
solution_y_fake = np.nan * np.ones((1, 1), dtype=np.float64, order='C')
solution_t_fake = np.nan * np.ones(1, dtype=np.float64, order='C')
self.solution_t_view = solution_t_fake
self.solution_extra_view = solution_extra_fake
self.solution_y_view = solution_y_fake

# Expected size of output arrays.
self.expected_size = expected_size
Expand Down Expand Up @@ -364,7 +318,7 @@ cdef class CyRKSolver:
self.max_step = max_step

@cython.exceptval(check=False)
cdef double calc_first_step(self) nogil:
cdef double calc_first_step(self):
""" Determine initial step size. """

cdef double step_size, d0, d1, d2, d0_abs, d1_abs, d2_abs, h0, h1, scale
Expand Down Expand Up @@ -423,7 +377,7 @@ cdef class CyRKSolver:
return step_size

@cython.exceptval(check=False)
def solve(self):
cpdef void solve(self):
self._solve()

@cython.exceptval(check=False)
Expand Down Expand Up @@ -724,6 +678,7 @@ cdef class CyRKSolver:
# No longer need the old arrays. Change where the view is pointing and delete them.
y_results_array_view = y_results_array_new
time_domain_array_view = time_domain_array_new
# TODO
# del y_results_array
# del time_domain_array
if self.capture_extra:
Expand Down Expand Up @@ -908,7 +863,7 @@ cdef class CyRKSolver:


@cython.exceptval(check=False)
cdef void diffeq(self) nogil:
cdef void diffeq(self):
# This is a template function that should be overriden by the user's subclass.

# The diffeq can use live variables:
Expand Down
58 changes: 58 additions & 0 deletions CyRK/cy/cysolvertest.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# distutils: language = c++
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True, initializedcheck=False

import cython
import numpy as np
cimport numpy as np
np.import_array()
from libc.math cimport sin, cos

from CyRK.cy.cysolver cimport CySolver


cdef class CySolverTester(CySolver):

@cython.exceptval(check=False)
cdef void diffeq(self):

# Unpack y
cdef double y0, y1
y0 = self.y_new_view[0]
y1 = self.y_new_view[1]

self.dy_new_view[0] = (1. - 0.01 * y1) * y0
self.dy_new_view[1] = (0.02 * y0 - 1.) * y1


cdef class CySolverAccuracyTest(CySolver):

@cython.exceptval(check=False)
cdef void diffeq(self):

# Unpack y
cdef double y0, y1
y0 = self.y_new_view[0]
y1 = self.y_new_view[1]

self.dy_new_view[0] = sin(self.t_new) - y1
self.dy_new_view[1] = cos(self.t_new) + y0


cdef class CySolverExtraTest(CySolver):

@cython.exceptval(check=False)
cdef void diffeq(self):

# Unpack y
cdef double y0, y1, extra_0, extra_1
y0 = self.y_new_view[0]
y1 = self.y_new_view[1]

extra_0 = (1. - 0.01 * y1)
extra_1 = (0.02 * y0 - 1.)

self.dy_new_view[0] = extra_0 * y0
self.dy_new_view[1] = extra_1 * y1

self.extra_output_view[0] = extra_0
self.extra_output_view[1] = extra_1
Loading

0 comments on commit 44799a3

Please sign in to comment.