From f29c406971be0844ec361272bb49bec7e2b3da68 Mon Sep 17 00:00:00 2001 From: Andrew Tribick Date: Fri, 25 Sep 2020 21:52:33 +0200 Subject: [PATCH] Add type hints --- download_data.py | 28 ++++++++++++++-------------- make_stardb.py | 22 +++++++++++----------- parse_hip.py | 16 ++++++++-------- parse_tyc.py | 27 ++++++++++++++++----------- spparse.py | 2 +- 5 files changed, 50 insertions(+), 45 deletions(-) diff --git a/download_data.py b/download_data.py index 30e7b62..e75109d 100755 --- a/download_data.py +++ b/download_data.py @@ -38,7 +38,7 @@ from astroquery.utils.tap import Tap from astroquery.xmatch import XMatch -def yesno(prompt, default=False): +def yesno(prompt: str, default: bool=False) -> bool: """Prompt the user for yes/no input.""" if default: new_prompt = f'{prompt} (Y/n): ' @@ -54,7 +54,7 @@ def yesno(prompt, default=False): elif answer == 'n' or answer == 'N': return False -def proceed_checkfile(filename): +def proceed_checkfile(filename: str) -> bool: """Check if a file exists, if so prompt the user if they want to replace it.""" if os.path.exists(filename): if yesno(f'{filename} already exists, replace?'): @@ -64,7 +64,7 @@ def proceed_checkfile(filename): return False return True -def download_file(outfile_name, url): +def download_file(outfile_name: str, url: str) -> bool: """Download a file using requests.""" if not proceed_checkfile(outfile_name): return @@ -79,7 +79,7 @@ def download_file(outfile_name, url): # --- GAIA DATA DOWNLOAD --- -def download_gaia_data(colname, xindex_table, outfile_name): +def download_gaia_data(colname: str, xindex_table: str, outfile_name: str) -> None: """Query and download Gaia data.""" query = f"""SELECT x.source_id, x.original_ext_source_id AS {colname}, @@ -104,7 +104,7 @@ def download_gaia_data(colname, xindex_table, outfile_name): CONESEARCH_URL = \ 'https://www.cosmos.esa.int/documents/29201/1769576/Hipparcos2GaiaDR2coneSearch.zip' -def download_gaia_hip(username): +def download_gaia_hip(username: str) -> None: """Download HIP data from the Gaia archive.""" hip_file = os.path.join('gaia', 'gaiadr2_hip-result.csv') if not proceed_checkfile(hip_file): @@ -149,9 +149,9 @@ def download_gaia_hip(username): finally: Gaia.delete_user_table('hipgpma') -def get_missing_tyc_ids(tyc_file, ascc_file): +def get_missing_tyc_ids(tyc_file: str, ascc_file: str) -> Table: """Finds the ASCC TYC ids that are not present in Gaia cross-match.""" - def load_tyc(filename): + def load_tyc(filename: str) -> Table: with open(filename, 'r') as f: header = f.readline().split(',') col_idx = header.index('tyc2_id') @@ -172,7 +172,7 @@ def load_tyc(filename): return Table([tyc1, tyc2, tyc3], names=['TYC1','TYC2','TYC3'], dtype=('i4', 'i4', 'i4')) - def load_ascc(filename): + def load_ascc(filename: str) -> Table: data = None with tarfile.open(filename, 'r:gz') as tf: @@ -207,7 +207,7 @@ def load_ascc(filename): return Table([[f"TYC {t['TYC1']}-{t['TYC2']}-{t['TYC3']}" for t in t_missing]], names=['id']) -def download_gaia_tyc(username): +def download_gaia_tyc(username: str) -> None: """Download TYC data from the Gaia archive.""" tyc_file = os.path.join('gaia', 'gaiadr2_tyc-result.csv') @@ -258,7 +258,7 @@ def download_gaia_tyc(username): finally: Gaia.delete_user_table('tyc_missing') -def download_gaia(): +def download_gaia() -> None: """Download data from the Gaia archive.""" with contextlib.suppress(FileExistsError): os.mkdir('gaia') @@ -283,7 +283,7 @@ def download_gaia(): # --- SAO XMATCH DOWNLOAD --- -def download_xmatch(cat1, cat2, outfile_name): +def download_xmatch(cat1: str, cat2: str, outfile_name: str) -> None: """Download a cross-match from VizieR.""" if not proceed_checkfile(outfile_name): return @@ -294,7 +294,7 @@ def download_xmatch(cat1, cat2, outfile_name): io_ascii.write(result, outfile_name, format='csv') -def download_sao_xmatch(): +def download_sao_xmatch() -> None: """Download cross-matches to the SAO catalogue.""" with contextlib.suppress(FileExistsError): os.mkdir('xmatch') @@ -305,11 +305,11 @@ def download_sao_xmatch(): ] for cat1, cat2, filename in cross_matches: - print('Downloading '+cat1+'-'+cat2+' crossmatch') + print(f'Downloading {cat1}-{cat2} crossmatch') download_xmatch(cat1, cat2, os.path.join('xmatch', filename)) # --- VIZIER DOWNLOAD --- -def download_vizier(): +def download_vizier() -> None: """Download catalogue archive files from VizieR.""" with contextlib.suppress(FileExistsError): os.mkdir('vizier') diff --git a/make_stardb.py b/make_stardb.py index db8a530..210c3d2 100755 --- a/make_stardb.py +++ b/make_stardb.py @@ -30,7 +30,7 @@ import astropy.io.ascii as io_ascii import astropy.units as u -from astropy.table import MaskedColumn, join, unique, vstack +from astropy.table import MaskedColumn, Table, join, unique, vstack from astropy.units import UnitsWarning from parse_hip import process_hip @@ -60,7 +60,7 @@ CEL_SPECS = parse_spectrum_vec(['OBAFGKM'[i//10]+str(i%10) for i in range(3, 70)]) -def load_ubvri(): +def load_ubvri() -> Table: """Load UBVRI Teff calibration from VizieR archive.""" print('Loading UBVRI calibration') with tarfile.open(os.path.join('vizier', 'ubvriteff.tar.gz'), 'r:gz') as tf: @@ -76,7 +76,7 @@ def load_ubvri(): warnings.simplefilter('ignore', UnitsWarning) return reader.read(f) -def parse_spectra(data): +def parse_spectra(data: Table) -> Table: """Parse the spectral types into the celestia.Sci format.""" print('Parsing spectral types') data['SpType'] = data['SpType'].filled('') @@ -84,7 +84,7 @@ def parse_spectra(data): sptypes['CelSpec'] = parse_spectrum_vec(sptypes['SpType']) return join(data, sptypes) -def estimate_magnitudes(data): +def estimate_magnitudes(data: Table) -> None: """Estimates magnitudes and color indices from G magnitude and BP-RP. Formula used is from Evans et al. (2018) "Gaia Data Release 2: Photometric @@ -146,7 +146,7 @@ def estimate_magnitudes(data): data.remove_columns(['Bmag', 'e_Bmag', 'e_Vmag', 'Jmag', 'e_Jmag', 'Hmag', 'e_Hmag', 'Kmag', 'e_Kmag']) -def estimate_temperatures(data): +def estimate_temperatures(data: Table) -> None: """Estimate the temperature of stars.""" ubvri_data = load_ubvri() print('Estimating temperatures from color indices') @@ -177,7 +177,7 @@ def estimate_temperatures(data): data['teff_est'] = teffs / weights data['teff_est'].unit = u.K -def estimate_spectra(data): +def estimate_spectra(data: Table) -> Table: """Estimate the spectral type of stars.""" no_teff = data[data['teff_val'].mask] # temporarily disable no-member error in pylint, as it cannot see the reduce method @@ -199,7 +199,7 @@ def estimate_spectra(data): data['CelSpec'] = CEL_SPECS[np.digitize(data['teff_val'], TEFF_BINS)] return data -def merge_all(): +def merge_all() -> Table: """Merges the HIP and TYC data.""" hip_data = process_hip() tyc_data = join(process_tyc(), @@ -218,7 +218,7 @@ def merge_all(): [0, COS_OBLIQUITY, SIN_OBLIQUITY], [0, -SIN_OBLIQUITY, COS_OBLIQUITY]]) -def process_data(): +def process_data() -> Table: """Processes the missing data values.""" data = merge_all() data = data[np.logical_not(data['dist_use'].mask)] @@ -263,7 +263,7 @@ def process_data(): return data -def write_starsdat(data, outfile): +def write_starsdat(data: Table, outfile: str) -> None: """Write the stars.dat file.""" print('Writing stars.dat') with open(outfile, 'wb') as f: @@ -274,7 +274,7 @@ def write_starsdat(data, outfile): data['Vmag_abs'], data['CelSpec']): f.write(fmt.pack(hip, x, y, z, int(round(vmag_abs*256)), celspec)) -def write_xindex(data, field, outfile): +def write_xindex(data: Table, field: str, outfile: str) -> None: """Write a cross-index file.""" print('Writing '+field+' cross-index') print(' Extracting cross-index data') @@ -288,7 +288,7 @@ def write_xindex(data, field, outfile): for hip, cat in zip(data['HIP'], data[field]): f.write(fmt.pack(cat, hip)) -def make_stardb(): +def make_stardb() -> None: """Make the Celestia star database files.""" data = process_data() diff --git a/parse_hip.py b/parse_hip.py index e8f5352..ddf7edf 100644 --- a/parse_hip.py +++ b/parse_hip.py @@ -27,11 +27,11 @@ import astropy.units as u from astropy.coordinates import ICRS, SkyCoord -from astropy.table import join, unique +from astropy.table import Table, join, unique from astropy.time import Time from astropy.units import UnitsWarning -def load_gaia_hip(): +def load_gaia_hip() -> Table: """Load the Gaia DR2 HIP sources.""" print('Loading Gaia DR2 sources for HIP') col_names = ['source_id', 'hip_id', 'ra', 'dec', 'phot_g_mean_mag', 'bp_rp', @@ -50,7 +50,7 @@ def load_gaia_hip(): return gaia -def load_xhip(): +def load_xhip() -> Table: """Load the XHIP catalogue from the VizieR archive.""" print('Loading XHIP') with tarfile.open(os.path.join('vizier', 'xhip.tar.gz'), 'r:gz') as tf: @@ -107,7 +107,7 @@ def load_xhip(): return join(hip_data, biblio_data, join_type='left', keys='HIP') -def load_sao(): +def load_sao() -> Table: """Load the SAO-HIP cross match.""" print('Loading SAO-HIP cross match') data = io_ascii.read(os.path.join('xmatch', 'sao_hip_xmatch.csv'), @@ -123,7 +123,7 @@ def load_sao(): data.add_index('HIP') return data -def compute_distances(hip_data, length_kpc=1.35): +def compute_distances(hip_data: Table, length_kpc: float=1.35) -> None: """Compute the distance using an exponentially-decreasing prior. The method is described in: @@ -162,7 +162,7 @@ def compute_distances(hip_data, length_kpc=1.35): HIP_TIME = Time('J1991.25') GAIA_TIME = Time('J2015.5') -def update_coordinates(hip_data): +def update_coordinates(hip_data: Table) -> None: """Update the coordinates from J1991.25 to J2015.5 to match Gaia.""" print('Updating coordinates to J2015.5') coords = SkyCoord(frame=ICRS, @@ -179,7 +179,7 @@ def update_coordinates(hip_data): hip_data['dec'] = coords.dec / u.deg hip_data['dec'].unit = u.deg -def process_xhip(): +def process_xhip() -> Table: """Processes the XHIP data.""" xhip = load_xhip() compute_distances(xhip) @@ -187,7 +187,7 @@ def process_xhip(): xhip.remove_columns(['RAdeg', 'DEdeg', 'Plx', 'e_Plx', 'pmRA', 'pmDE', 'RV', 'Dist', 'e_Dist']) return xhip -def process_hip(): +def process_hip() -> Table: """Process the Gaia and HIP data.""" data = join(load_gaia_hip(), process_xhip(), diff --git a/parse_tyc.py b/parse_tyc.py index 1351ad6..70baa04 100644 --- a/parse_tyc.py +++ b/parse_tyc.py @@ -22,26 +22,31 @@ import re import tarfile +from tarfile import TarFile, TarInfo +from typing import Tuple + import numpy as np import astropy.io.ascii as io_ascii import astropy.units as u from astropy.table import MaskedColumn, Table, join, unique, vstack -def parse_tyc_string(data, src_column, dest_column='TYC'): +def parse_tyc_string(data: Table, src_column: str, dest_column: str='TYC') -> None: """Parse a TYC string into a synthetic HIP identifier.""" tycs = np.array(np.char.split(data[src_column], '-').tolist()).astype(np.int64) data[dest_column] = tycs[:, 0] + tycs[:, 1]*10000 + tycs[:, 2]*1000000000 data.remove_column(src_column) -def parse_tyc_cols(data, src_columns=('TYC1', 'TYC2', 'TYC3'), dest_column='TYC'): +def parse_tyc_cols(data: Table, + src_columns: Tuple[str, str, str]=('TYC1', 'TYC2', 'TYC3'), + dest_column: str='TYC') -> None: """Convert TYC identifier components into a synthetic HIP identifier.""" data[dest_column] = (data[src_columns[0]] + data[src_columns[1]]*10000 + data[src_columns[2]]*1000000000) data.remove_columns(src_columns) -def load_gaia_tyc(): +def load_gaia_tyc() -> Table: """Load the Gaia DR2 TYC2 sources.""" print('Loading Gaia DR2 sources for TYC2') col_names = ['source_id', 'tyc2_id', 'ra', 'dec', 'phot_g_mean_mag', 'bp_rp', @@ -64,7 +69,7 @@ def load_gaia_tyc(): return gaia -def load_tyc_spec(): +def load_tyc_spec() -> Table: """Load the TYC2 spectral type catalogue.""" print('Loading TYC2 spectral types') with tarfile.open(os.path.join('vizier', 'tyc2spec.tar.gz')) as tf: @@ -81,9 +86,9 @@ def load_tyc_spec(): data.add_index('TYC') return data -def load_ascc(): +def load_ascc() -> Table: """Load ASCC from VizieR archive.""" - def load_section(tf, info): + def load_section(tf: TarFile, info: TarInfo) -> Table: with tf.extractfile('./ReadMe') as readme: col_names = ['Bmag', 'Vmag', 'e_Bmag', 'e_Vmag', 'd3', 'TYC1', 'TYC2', 'TYC3', 'HD', 'Jmag', 'e_Jmag', 'Hmag', 'e_Hmag', 'Kmag', 'e_Kmag'] @@ -107,7 +112,7 @@ def load_section(tf, info): return section - def is_data(info): + def is_data(info: TarInfo) -> bool: sections = os.path.split(info.name) return (len(sections) == 2 and sections[0] == '.' and @@ -129,7 +134,7 @@ def is_data(info): data.add_index('TYC') return data -def load_tyc_teff(): +def load_tyc_teff() -> Table: """Load the Tycho-2 effective temperatures.""" print('Loading TYC2 effective temperatures') with tarfile.open(os.path.join('vizier', 'tyc2teff.tar.gz'), 'r:gz') as tf: @@ -195,7 +200,7 @@ def load_tyc_teff(): data.add_index('TYC') return unique(data, keys=['TYC']) -def load_sao(): +def load_sao() -> Table: """Load the SAO-TYC2 cross match.""" print('Loading SAO-TYC2 cross match') data = io_ascii.read(os.path.join('xmatch', 'sao_tyc2_xmatch.csv'), @@ -213,7 +218,7 @@ def load_sao(): data.add_index('TYC') return data -def merge_tables(): +def merge_tables() -> Table: """Merges the tables.""" data = join(load_gaia_tyc(), load_tyc_spec(), keys=['TYC'], join_type='left') data = join(data, load_ascc(), keys=['TYC'], join_type='left', metadata_conflicts='silent') @@ -232,7 +237,7 @@ def merge_tables(): data = join(data, load_sao(), keys=['TYC'], join_type='left') return data -def process_tyc(): +def process_tyc() -> Table: """Processes the TYC data.""" data = merge_tables() data.rename_column('r_est', 'dist_use') diff --git a/spparse.py b/spparse.py index 2503673..f5e8324 100644 --- a/spparse.py +++ b/spparse.py @@ -380,7 +380,7 @@ def visit_wdstar(self, node, children): VISITOR = SpecVisitor() MULTISEPARATOR = re.compile(r'\+\ *(?:\.{2,}|(?:\(?(?:sd|d|g|c|k|h|m|g|He)?[OBAFGKM]|W[DNOCR]|wd))') -def parse_spectrum(sptype): +def parse_spectrum(sptype: str) -> int: """Parse a spectral type string into a Celestia spectral type.""" # resolve ambiguity in grammar: B 0-Ia could be interpreted as (B 0-) Ia or B (0-Ia)