Skip to content

Commit 8583586

Browse files
authored
Merge pull request #118 from Kai-Striega/enh/nper/broadcast
ENH: nper: broadcast rework with Cython
2 parents d60bf6d + 6b6f7b5 commit 8583586

File tree

3 files changed

+95
-29
lines changed

3 files changed

+95
-29
lines changed

numpy_financial/_cfinancial.pyx

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,59 @@
1-
from libc.math cimport NAN
1+
from libc.math cimport NAN, INFINITY, log
22
cimport cython
33

4+
5+
cdef double nper_inner_loop(
6+
const double rate_,
7+
const double pmt_,
8+
const double pv_,
9+
const double fv_,
10+
const double when_
11+
) nogil:
12+
if rate_ == 0.0 and pmt_ == 0.0:
13+
return INFINITY
14+
15+
if rate_ == 0.0:
16+
return -(fv_ + pv_) / pmt_
17+
18+
if rate_ <= -1.0:
19+
return NAN
20+
21+
z = pmt_ * (1.0 + rate_ * when_) / rate_
22+
return log((-fv_ + z) / (pv_ + z)) / log(1.0 + rate_)
23+
24+
25+
@cython.boundscheck(False)
26+
@cython.wraparound(False)
27+
def nper(
28+
const double[::1] rates,
29+
const double[::1] pmts,
30+
const double[::1] pvs,
31+
const double[::1] fvs,
32+
const double[::1] whens,
33+
double[:, :, :, :, ::1] out):
34+
35+
cdef:
36+
Py_ssize_t rate_, pmt_, pv_, fv_, when_
37+
38+
for rate_ in range(rates.shape[0]):
39+
for pmt_ in range(pmts.shape[0]):
40+
for pv_ in range(pvs.shape[0]):
41+
for fv_ in range(fvs.shape[0]):
42+
for when_ in range(whens.shape[0]):
43+
# We can have several ``ZeroDivisionErrors``s here
44+
# At the moment we want to replicate the existing function as
45+
# closely as possible however we should return financially
46+
# sensible results here.
47+
try:
48+
res = nper_inner_loop(
49+
rates[rate_], pmts[pmt_], pvs[pv_], fvs[fv_], whens[when_]
50+
)
51+
except ZeroDivisionError:
52+
res = NAN
53+
54+
out[rate_, pmt_, pv_, fv_, when_] = res
55+
56+
457
@cython.boundscheck(False)
558
@cython.cdivision(True)
659
def npv(const double[::1] rates, const double[:, ::1] values, double[:, ::1] out):

numpy_financial/_financial.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -306,35 +306,33 @@ def nper(rate, pmt, pv, fv=0, when='end'):
306306
The same analysis could be done with several different interest rates
307307
and/or payments and/or total amounts to produce an entire table.
308308
309-
>>> npf.nper(*(np.ogrid[0.07/12: 0.08/12: 0.01/12,
310-
... -150 : -99 : 50 ,
311-
... 8000 : 9001 : 1000]))
312-
array([[[ 64.07334877, 74.06368256],
313-
[108.07548412, 127.99022654]],
309+
>>> rates = [0.05, 0.06, 0.07]
310+
>>> payments = [100, 200, 300]
311+
>>> amounts = [7_000, 8_000, 9_000]
312+
>>> npf.nper(rates, payments, amounts).round(3)
313+
array([[[-30.827, -32.987, -34.94 ],
314+
[-20.734, -22.517, -24.158],
315+
[-15.847, -17.366, -18.78 ]],
314316
<BLANKLINE>
315-
[[ 66.12443902, 76.87897353],
316-
[114.70165583, 137.90124779]]])
317+
[[-28.294, -30.168, -31.857],
318+
[-19.417, -21.002, -22.453],
319+
[-15.025, -16.398, -17.67 ]],
320+
<BLANKLINE>
321+
[[-26.234, -27.891, -29.381],
322+
[-18.303, -19.731, -21.034],
323+
[-14.311, -15.566, -16.722]]])
317324
"""
318325
when = _convert_when(when)
319-
rate, pmt, pv, fv, when = np.broadcast_arrays(rate, pmt, pv, fv, when)
320-
nper_array = np.empty_like(rate, dtype=np.float64)
321-
322-
zero = rate == 0
323-
nonzero = ~zero
324-
325-
with np.errstate(divide='ignore'):
326-
# Infinite numbers of payments are okay, so ignore the
327-
# potential divide by zero.
328-
nper_array[zero] = -(fv[zero] + pv[zero]) / pmt[zero]
329-
330-
nonzero_rate = rate[nonzero]
331-
z = pmt[nonzero] * (1 + nonzero_rate * when[nonzero]) / nonzero_rate
332-
nper_array[nonzero] = (
333-
np.log((-fv[nonzero] + z) / (pv[nonzero] + z))
334-
/ np.log(1 + nonzero_rate)
335-
)
336-
337-
return nper_array
326+
rates = np.atleast_1d(rate).astype(np.float64)
327+
pmts = np.atleast_1d(pmt).astype(np.float64)
328+
pvs = np.atleast_1d(pv).astype(np.float64)
329+
fvs = np.atleast_1d(fv).astype(np.float64)
330+
whens = np.atleast_1d(when).astype(np.float64)
331+
332+
out_shape = _get_output_array_shape(rates, pmts, pvs, fvs, whens)
333+
out = np.empty(out_shape)
334+
_cfinancial.nper(rates, pmts, pvs, fvs, whens, out)
335+
return _ufunc_like(out)
338336

339337

340338
def _value_like(arr, value):

numpy_financial/tests/test_financial.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,17 @@ def uint_dtype():
4545
cashflow_list_strategy,
4646
)
4747

48-
short_scalar_array = npst.arrays(
48+
short_scalar_array_strategy = npst.arrays(
4949
dtype=real_scalar_dtypes,
5050
shape=npst.array_shapes(min_dims=0, max_dims=1, min_side=0, max_side=5),
5151
)
5252

5353

54+
when_strategy = st.sampled_from(
55+
['end', 'begin', 'e', 'b', 0, 1, 'beginning', 'start', 'finish']
56+
)
57+
58+
5459
def assert_decimal_close(actual, expected, tol=Decimal("1e-7")):
5560
# Check if both actual and expected are iterable (like arrays)
5661
if hasattr(actual, "__iter__") and hasattr(expected, "__iter__"):
@@ -280,7 +285,7 @@ def test_npv(self):
280285
rtol=1e-2,
281286
)
282287

283-
@given(rates=short_scalar_array, values=cashflow_array_strategy)
288+
@given(rates=short_scalar_array_strategy, values=cashflow_array_strategy)
284289
@settings(deadline=None)
285290
def test_fuzz(self, rates, values):
286291
npf.npv(rates, values)
@@ -426,6 +431,16 @@ def test_broadcast(self):
426431
npf.nper(0.075, -2000, 0, 100000.0, [0, 1]), [21.5449442, 20.76156441], 4
427432
)
428433

434+
@given(
435+
rates=short_scalar_array_strategy,
436+
payments=short_scalar_array_strategy,
437+
present_values=short_scalar_array_strategy,
438+
future_values=short_scalar_array_strategy,
439+
whens=when_strategy,
440+
)
441+
def test_fuzz(self, rates, payments, present_values, future_values, whens):
442+
npf.nper(rates, payments, present_values, future_values, whens)
443+
429444

430445
class TestPpmt:
431446
def test_float(self):

0 commit comments

Comments
 (0)