Skip to content

Commit 1b7aec3

Browse files
authored
Merge pull request #31 from gaia-dpci/cholesky-addition
Cholesky addition
2 parents f658c16 + ad83c82 commit 1b7aec3

File tree

148 files changed

+3902
-1332
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

148 files changed

+3902
-1332
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ jobs:
3939
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
4040
- name: Test with pytest
4141
run: |
42-
pytest
42+
pytest -s

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,8 @@ __pycache__/
1111
build/
1212
dist/
1313
docs/_build/
14+
htmlcov/
15+
junit/
1416
tests_output_files/
17+
coverage.xml
18+
query_tests.py

MANIFEST.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
include LICENCE
22
include requirements.txt
3-
include gaiaxpy/headers/headers_dict.txt
3+
include gaiaxpy/output/ecsv_headers/headers_dict.txt
44
graft gaiaxpy/config
55
global-exclude __pycache__

docs/source/releasenotes.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ Version 1.1.3
1414
-------------
1515
Released on 2022/06/16.
1616

17-
* Fix legend bug in plotter.
18-
* Restrict pandas version, >= 1.0.0.
17+
* Fixed legend bug in plotter.
18+
* Restricted pandas version, >= 1.0.0.
1919

2020
Version 1.1.2
2121
-------------

gaiaxpy/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from .calibrator.calibrator import calibrate
2+
from .cholesky.cholesky import get_inverse_covariance_matrix, get_chi2
23
from .converter.converter import convert
3-
from .core import pwl_to_wl, wl_to_pwl, pwl_range, wl_range
4-
from .error_correction import apply_error_correction
4+
from .core.dispersion_function import pwl_to_wl, wl_to_pwl, pwl_range, wl_range
5+
from .error_correction.error_correction import apply_error_correction
56
from .generator.generator import generate
67
from .generator.photometric_system import PhotometricSystem
78
from .plotter.plot_spectra import plot_spectra
89

9-
__version__ = '1.1.4'
10+
__version__ = '1.2.0'

gaiaxpy/calibrator/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +0,0 @@
1-
from . import calibrator
2-
from .calibrator import _calibrate, _create_merge, _create_spectrum
3-
4-
from . import external_instrument_model
5-
from .external_instrument_model import ExternalInstrumentModel

gaiaxpy/calibrator/calibrator.py

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,27 @@
88
import pandas as pd
99
from configparser import ConfigParser
1010
from pathlib import Path
11+
from tqdm import tqdm
1112
from os.path import join
1213
from .external_instrument_model import ExternalInstrumentModel
13-
from gaiaxpy.config import config_path
14-
from gaiaxpy.core.satellite import BANDS
15-
from gaiaxpy.core import _get_spectra_type, _load_xpmerge_from_csv, \
16-
_load_xpsampling_from_csv, _progress_tracker, \
17-
_validate_arguments, _validate_wl_sampling, satellite
18-
from gaiaxpy.input_reader import InputReader
19-
from gaiaxpy.output import SampledSpectraData
20-
from gaiaxpy.spectrum import _get_covariance_matrix, AbsoluteSampledSpectrum, \
21-
SampledBasisFunctions, XpContinuousSpectrum
14+
from gaiaxpy.config.paths import config_path
15+
from gaiaxpy.core.config import _load_xpmerge_from_csv, _load_xpsampling_from_csv
16+
from gaiaxpy.core.generic_functions import cast_output, _get_spectra_type, \
17+
_validate_arguments, \
18+
_validate_wl_sampling
19+
from gaiaxpy.core.satellite import BANDS, BP_WL, RP_WL
20+
from gaiaxpy.core.generic_variables import pbar_colour, pbar_units
21+
from gaiaxpy.input_reader.input_reader import InputReader
22+
from gaiaxpy.output.sampled_spectra_data import SampledSpectraData
23+
from gaiaxpy.spectrum.absolute_sampled_spectrum import AbsoluteSampledSpectrum
24+
from gaiaxpy.spectrum.sampled_basis_functions import SampledBasisFunctions
25+
from gaiaxpy.spectrum.utils import _get_covariance_matrix
26+
from gaiaxpy.spectrum.xp_continuous_spectrum import XpContinuousSpectrum
2227

2328
config_parser = ConfigParser()
2429
config_parser.read(join(config_path, 'config.ini'))
25-
30+
tqdm.pandas(desc='Processing data', unit=pbar_units['calibrator'], leave=False, \
31+
colour=pbar_colour) # Activate tqdm for pandas
2632

2733
def calibrate(
2834
input_object,
@@ -107,18 +113,16 @@ def _calibrate(
107113
"""
108114
_validate_wl_sampling(sampling)
109115
_validate_arguments(_calibrate.__defaults__[3], output_file, save_file)
110-
parsed_input_data, extension = InputReader(input_object, _calibrate, username, password)._read()
116+
parsed_input_data, extension = InputReader(input_object, _calibrate, \
117+
username, password)._read()
111118
label = 'calibrator'
112-
113-
xp_design_matrices, xp_merge = _generate_xp_matrices_and_merge(label, sampling, bp_model, rp_model)
119+
xp_design_matrices, xp_merge = _generate_xp_matrices_and_merge(label, \
120+
sampling, bp_model, rp_model)
114121
# Create sampled basis functions
115-
spectra_list = _create_spectra(parsed_input_data, truncation, xp_design_matrices, xp_merge)
116-
# Generate output
117-
spectra_df = pd.DataFrame.from_records([spectrum._spectrum_to_dict() for spectrum in spectra_list])
118-
spectra_type = _get_spectra_type(spectra_list)
119-
spectra_df.attrs['data_type'] = spectra_type
120-
positions = spectra_list[0]._get_positions()
122+
spectra_df, positions = _create_spectra(parsed_input_data, truncation, \
123+
xp_design_matrices, xp_merge)
121124
output_data = SampledSpectraData(spectra_df, positions)
125+
output_data.data = cast_output(output_data)
122126
# Save output
123127
Path(output_path).mkdir(parents=True, exist_ok=True)
124128
output_data.save(save_file, output_path, output_file, output_format, extension)
@@ -136,8 +140,8 @@ def _create_merge(xp, sampling):
136140
Returns:
137141
dict: A dictionary containing a BP and an RP array with weights.
138142
"""
139-
wl_high = satellite.BP_WL.high
140-
wl_low = satellite.RP_WL.low
143+
wl_high = BP_WL.high
144+
wl_low = RP_WL.low
141145

142146
if xp == BANDS.bp:
143147
weight = np.array([1.0 if wl < wl_low else 0.0 if wl > wl_high else (
@@ -184,22 +188,15 @@ def _get_file_for_xp(xp, key, bp_model=bp_model, rp_model=rp_model):
184188

185189

186190
def _create_spectra(parsed_spectrum_file, truncation, design_matrices, merge):
187-
"""
188-
Internal wrapper function. Allows _create_spectrum to use the generic
189-
progress tracker.
190-
"""
191-
spectra_list = []
192191
nrows = len(parsed_spectrum_file)
193-
194-
@_progress_tracker
195-
def create_spectrum(row, *args):
196-
truncation, design_matrices, merge = args[:3]
197-
spectrum = _create_spectrum(
198-
row, truncation, design_matrices, merge)
199-
spectra_list.append(spectrum)
200-
for index, row in parsed_spectrum_file.iterrows():
201-
create_spectrum(row, truncation, design_matrices, merge, index, nrows)
202-
return spectra_list
192+
spectra_series = parsed_spectrum_file.progress_apply(lambda row: \
193+
_create_spectrum(row, truncation, design_matrices, merge), axis=1)
194+
positions = spectra_series.iloc[0]._get_positions()
195+
spectra_type = _get_spectra_type(spectra_series.iloc[0])
196+
spectra_series = spectra_series.map(lambda x: x._spectrum_to_dict())
197+
spectra_df = pd.DataFrame(spectra_series.tolist())
198+
spectra_df.attrs['data_type'] = spectra_type
199+
return spectra_df, positions
203200

204201

205202
def _create_spectrum(row, truncation, design_matrix, merge):
@@ -224,9 +221,6 @@ def _create_spectrum(row, truncation, design_matrix, merge):
224221
source_id = row['source_id']
225222
cont_dict = {}
226223
# Split both bands
227-
source_id = row['source_id']
228-
cont_dict = {}
229-
# Split both bands
230224
for band in BANDS:
231225
try:
232226
covariance_matrix = _get_covariance_matrix(row, band)

gaiaxpy/calibrator/external_instrument_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import numpy as np
99
from scipy import interpolate
10-
from gaiaxpy.file_parser import GenericParser
10+
from gaiaxpy.file_parser.parse_generic import GenericParser
1111

1212

1313
class ExternalInstrumentModel(object):

gaiaxpy/cholesky/__init__.py

Whitespace-only changes.

gaiaxpy/cholesky/cholesky.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import pandas as pd
2+
from numpy import diag, dot, identity
3+
from scipy.linalg import cholesky, solve_triangular
4+
from gaiaxpy.core.satellite import BANDS
5+
from gaiaxpy.input_reader.input_reader import InputReader
6+
7+
8+
def __get_dot_product(L_inv):
    """
    Multiply the transpose of a matrix by the matrix itself.

    Args:
        L_inv (ndarray or None): Matrix to be multiplied. Within this module
            it is the value produced by __get_inv_cholesky_decomp_lower,
            which is None when that computation raised a ValueError.

    Returns:
        ndarray or None: dot(L_inv.T, L_inv), or None if an AttributeError is
            raised while computing it (e.g. None has no transpose).
    """
    try:
        product = dot(L_inv.T, L_inv)
    except AttributeError:
        # The input offers no transpose (e.g. it is None): propagate None.
        product = None
    return product
13+
14+
15+
def __get_inv_cholesky_decomp_lower(xp_errors, xp_correlation_matrix):
    """
    Invert the lower Cholesky factor of a correlation matrix and scale it by
    the inverse of the errors.

    Args:
        xp_errors (ndarray): Errors whose reciprocals form the diagonal of
            the scaling matrix.
        xp_correlation_matrix (ndarray): Correlation matrix to decompose.

    Returns:
        ndarray or None: Product of the inverted lower triangular factor and
            the diagonal matrix of inverse errors, or None if any of the
            steps raises a ValueError.
    """
    try:
        lower_factor = cholesky(xp_correlation_matrix, lower=True)
        size = len(lower_factor)
        # Solving L X = I yields the inverse of the lower triangular factor.
        inverse_factor = solve_triangular(lower_factor, identity(size), lower=True)
        # Diagonal matrix holding the reciprocal of every error.
        inverse_errors = diag(1.0 / xp_errors)
        scaled = dot(inverse_factor, inverse_errors)
    except ValueError:
        return None
    return scaled
25+
26+
27+
def get_inverse_covariance_matrix(input_object, band=None):
    """
    Compute the inverse covariance matrix.

    Args:
        input_object (object): Path to the file containing the mean spectra
            as downloaded from the archive in their continuous representation,
            a list of sources ids (string or long), or a pandas DataFrame.
        band (str): Chosen band: 'bp' or 'rp'. If no band is passed, the
            function will compute the inverse covariance for both 'bp' and
            'rp'.

    Returns:
        DataFrame or ndarray: DataFrame with a 'source_id' column and one
            '<band>_inverse_covariance' column per processed band, returned
            whenever the input holds more than one source or no band is
            passed to the function.
            If the input holds exactly one source and a single band is
            selected, the matrix stored for that source is returned directly
            (an ndarray of shape (55, 55) according to the original
            documentation). An entry is None when the computation for that
            source and band could not be completed.
    """
    parsed_input_data, _ = InputReader(input_object, get_inverse_covariance_matrix)._read()
    if band is None:
        selected_bands = BANDS
        column_names = ['source_id', 'bp_inverse_covariance', 'rp_inverse_covariance']
    else:
        selected_bands = [band]
        column_names = ['source_id', f'{band}_inverse_covariance']
    # One lazy iterator of matrices per band; they are consumed by zip below.
    band_columns = []
    for current_band in selected_bands:
        errors = parsed_input_data[f'{current_band}_coefficient_errors']
        correlations = parsed_input_data[f'{current_band}_coefficient_correlations']
        inverse_factors = map(__get_inv_cholesky_decomp_lower, errors, correlations)
        band_columns.append(map(__get_dot_product, inverse_factors))
    all_columns = [parsed_input_data['source_id'], *band_columns]
    result = pd.DataFrame(zip(*all_columns), columns=column_names)
    # A single source in a single band is unwrapped to the bare matrix.
    single_matrix_requested = len(selected_bands) == 1 and len(result) == 1
    if single_matrix_requested:
        return result[f'{band}_inverse_covariance'].iloc[0]
    return result
68+
69+
def get_chi2(L_inv, residuals):
    """
    Compute the squared norm of the residuals multiplied by the transpose of
    the given matrix.

    Args:
        L_inv (ndarray): Matrix of shape (55, 55).
        residuals (ndarray): Vector of shape (55,).

    Returns:
        scalar: dot(x, x), where x = dot(L_inv.T, residuals).

    Raises:
        ValueError: If L_inv is not of shape (55, 55) or residuals is not of
            shape (55,).
    """
    # NOTE(review): the error message calls the first argument an inverse
    # covariance matrix while its name suggests the inverse Cholesky factor;
    # confirm which one callers are expected to pass, and whether the
    # transpose below is the intended convention for that input.
    matrix_shape_ok = L_inv.shape == (55, 55)
    if not matrix_shape_ok:
        raise ValueError('Inverse covariance matrix shape must be (55, 55).')
    residuals_shape_ok = residuals.shape == (55,)
    if not residuals_shape_ok:
        raise ValueError('Residuals shape must be (55,).')
    projected = dot(L_inv.T, residuals)
    return dot(projected.T, projected)

0 commit comments

Comments
 (0)