Source code for CyRK.helper

""" Helper function to convert conventional solve_ivp or numba diffeq's into a format that cyrk can accept. """
import numpy as np
from numba import njit

from CyRK.cy.common import CyrkErrorCodes
from CyRK.cy.cysolver_api import ODEMethod

def nb2cy(diffeq: callable, use_njit: bool = True, cache_njit: bool = False) -> callable:
    """ Convert numba/scipy differential equation functions to the cyrk format.

    The input function is called as ``diffeq(t, y, *args)`` and returns the
    derivatives. The returned wrapper is called as
    ``diffeq_cyrk(dy, t, y, *args)`` and writes those derivatives into ``dy``
    in place; it returns nothing.

    Parameters
    ----------
    diffeq : callable
        Differential equation function with signature ``diffeq(t, y, *args)``.
    use_njit : bool = True
        If True, the final function will be njited.
    cache_njit : bool = False
        If True, then the njit-compiled function will be cached.

    Returns
    -------
    diffeq_cyrk : callable
        cyrk-safe differential equation function.
    """

    if use_njit:
        # A single call covers both cache settings.
        njit_ = njit(cache=bool(cache_njit))
    else:
        def njit_(func):
            # No-op decorator so the wrapper below is built the same way either way.
            return func

    @njit_
    def diffeq_cyrk(dy, t, y, *args):
        # The wrapped function returns its derivatives rather than filling a buffer.
        dy_ = diffeq(t, y, *args)

        # Copy the result element by element into the caller-provided dy.
        for i in range(y.size):
            dy[i] = dy_[i]

    return diffeq_cyrk
def cy2nb(diffeq: callable, use_njit: bool = True, cache_njit: bool = False) -> callable:
    """ Convert cyrk differential equation functions to the numba/scipy format.

    The input function is called as ``diffeq(dy, t, y, *args)`` and is
    expected to fill ``dy`` in place. The returned wrapper is called as
    ``diffeq_nbrk(t, y, *args)``; it allocates a new array shaped like ``y``,
    passes it to ``diffeq`` and returns it.

    Parameters
    ----------
    diffeq : callable
        Differential equation function with signature
        ``diffeq(dy, t, y, *args)``.
    use_njit : bool = True
        If True, the final function will be njited.
    cache_njit : bool = False
        If True, then the njit-compiled function will be cached.

    Returns
    -------
    diffeq_nbrk : callable
        numba/scipy-safe differential equation function.
    """

    if use_njit:
        # A single call covers both cache settings.
        njit_ = njit(cache=bool(cache_njit))
    else:
        def njit_(func):
            # No-op decorator so the wrapper below is built the same way either way.
            return func

    @njit_
    def diffeq_nbrk(t, y, *args):
        # Allocate the output buffer that the cyrk-style function fills in place.
        dy = np.empty_like(y)
        diffeq(dy, t, y, *args)

        return dy

    return diffeq_nbrk