Skip to content

Commit 0db86ac

Browse files
committed
fix bug in mixuture copula parameter init
1 parent f1c6658 commit 0db86ac

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

starvine/bvcopula/copula/mixture_copula.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, copula_a, wt_a, copula_b, wt_b):
2020
self._name = copula_a.name + '-' + copula_b.name
2121
self._thetaBounds = tuple(list(copula_a.thetaBounds) + list(copula_b.thetaBounds) + [(0.,1.), (0.,1.)])
2222
self._theta0 = tuple(list(copula_a.theta0) + list(copula_b.theta0) + [wt_a, wt_b])
23-
self.fittedParams = list(copula_a.theta0) + list(copula_b.theta0) + [wt_a, wt_b]
23+
self.fittedParams = list(copula_a.fittedParams) + list(copula_b.fittedParams) + [wt_a, wt_b]
2424
self._copula_a = copula_a
2525
self._copula_b = copula_b
2626
self._n_params_a = len(copula_a.thetaBounds)

starvine/bvcopula/tests/test_mixture_copula.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# COPULA IMPORTS
77
from starvine.bvcopula.copula.mixture_copula import MixtureCopula as mc
88
import numpy as np
9-
import matplotlib.pyplot as plt
9+
from starvine.bvcopula import bv_plot
1010
import os
1111
pwd_ = os.path.dirname(os.path.abspath(__file__))
1212
from starvine.bvcopula.copula.gumbel_copula import GumbelCopula
@@ -17,14 +17,17 @@
1717

1818
class TestMixtureCopula(unittest.TestCase):
1919
def setUp(self):
20-
self._mix_copula = mc(GumbelCopula(1), 0.5,
21-
GumbelCopula(2), 0.5)
20+
self._mix_copula = mc(GumbelCopula(2, [2.1]), 0.5,
21+
GumbelCopula(3, [2.7]), 0.5)
2222

2323
def testMixCoplulaPdf(self):
24-
u = np.linspace(1.0e-8, 1.0-1e-8, 50)
25-
v = np.linspace(1.0e-8, 1.0-1e-8, 50)
26-
c_pdf = self._mix_copula.pdf(u, v)
24+
u = np.linspace(6.0e-2, 1.0-6e-2, 50)
25+
v = np.linspace(6.0e-2, 1.0-6e-2, 50)
26+
uu, vv = np.meshgrid(u, v)
27+
c_pdf = self._mix_copula.pdf(uu.flatten(), vv.flatten())
2728
self.assertTrue(np.all(c_pdf >= 0))
29+
# plot mixture pdf
30+
bv_plot.bvContourf(uu.flatten(), vv.flatten(), c_pdf, savefig="mix.png")
2831

2932
def testMixCoplulaCdf(self):
3033
u = np.linspace(1.0e-8, 1.0-1e-8, 50)

0 commit comments

Comments
 (0)