Skip to content

Commit

Permalink
Add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtribick committed Sep 25, 2020
1 parent b985fc6 commit f29c406
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 45 deletions.
28 changes: 14 additions & 14 deletions download_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from astroquery.utils.tap import Tap
from astroquery.xmatch import XMatch

def yesno(prompt, default=False):
def yesno(prompt: str, default: bool=False) -> bool:
"""Prompt the user for yes/no input."""
if default:
new_prompt = f'{prompt} (Y/n): '
Expand All @@ -54,7 +54,7 @@ def yesno(prompt, default=False):
elif answer == 'n' or answer == 'N':
return False

def proceed_checkfile(filename):
def proceed_checkfile(filename: str) -> bool:
"""Check if a file exists, if so prompt the user if they want to replace it."""
if os.path.exists(filename):
if yesno(f'{filename} already exists, replace?'):
Expand All @@ -64,7 +64,7 @@ def proceed_checkfile(filename):
return False
return True

def download_file(outfile_name, url):
def download_file(outfile_name: str, url: str) -> bool:
"""Download a file using requests."""
if not proceed_checkfile(outfile_name):
return
Expand All @@ -79,7 +79,7 @@ def download_file(outfile_name, url):

# --- GAIA DATA DOWNLOAD ---

def download_gaia_data(colname, xindex_table, outfile_name):
def download_gaia_data(colname: str, xindex_table: str, outfile_name: str) -> None:
"""Query and download Gaia data."""
query = f"""SELECT
x.source_id, x.original_ext_source_id AS {colname},
Expand All @@ -104,7 +104,7 @@ def download_gaia_data(colname, xindex_table, outfile_name):
CONESEARCH_URL = \
'https://www.cosmos.esa.int/documents/29201/1769576/Hipparcos2GaiaDR2coneSearch.zip'

def download_gaia_hip(username):
def download_gaia_hip(username: str) -> None:
"""Download HIP data from the Gaia archive."""
hip_file = os.path.join('gaia', 'gaiadr2_hip-result.csv')
if not proceed_checkfile(hip_file):
Expand Down Expand Up @@ -149,9 +149,9 @@ def download_gaia_hip(username):
finally:
Gaia.delete_user_table('hipgpma')

def get_missing_tyc_ids(tyc_file, ascc_file):
def get_missing_tyc_ids(tyc_file: str, ascc_file: str) -> Table:
"""Finds the ASCC TYC ids that are not present in Gaia cross-match."""
def load_tyc(filename):
def load_tyc(filename: str) -> Table:
with open(filename, 'r') as f:
header = f.readline().split(',')
col_idx = header.index('tyc2_id')
Expand All @@ -172,7 +172,7 @@ def load_tyc(filename):
return Table([tyc1, tyc2, tyc3], names=['TYC1','TYC2','TYC3'], dtype=('i4', 'i4', 'i4'))


def load_ascc(filename):
def load_ascc(filename: str) -> Table:
data = None
with tarfile.open(filename, 'r:gz') as tf:

Expand Down Expand Up @@ -207,7 +207,7 @@ def load_ascc(filename):

return Table([[f"TYC {t['TYC1']}-{t['TYC2']}-{t['TYC3']}" for t in t_missing]], names=['id'])

def download_gaia_tyc(username):
def download_gaia_tyc(username: str) -> None:
"""Download TYC data from the Gaia archive."""

tyc_file = os.path.join('gaia', 'gaiadr2_tyc-result.csv')
Expand Down Expand Up @@ -258,7 +258,7 @@ def download_gaia_tyc(username):
finally:
Gaia.delete_user_table('tyc_missing')

def download_gaia():
def download_gaia() -> None:
"""Download data from the Gaia archive."""
with contextlib.suppress(FileExistsError):
os.mkdir('gaia')
Expand All @@ -283,7 +283,7 @@ def download_gaia():

# --- SAO XMATCH DOWNLOAD ---

def download_xmatch(cat1, cat2, outfile_name):
def download_xmatch(cat1: str, cat2: str, outfile_name: str) -> None:
"""Download a cross-match from VizieR."""
if not proceed_checkfile(outfile_name):
return
Expand All @@ -294,7 +294,7 @@ def download_xmatch(cat1, cat2, outfile_name):

io_ascii.write(result, outfile_name, format='csv')

def download_sao_xmatch():
def download_sao_xmatch() -> None:
"""Download cross-matches to the SAO catalogue."""
with contextlib.suppress(FileExistsError):
os.mkdir('xmatch')
Expand All @@ -305,11 +305,11 @@ def download_sao_xmatch():
]

for cat1, cat2, filename in cross_matches:
print('Downloading '+cat1+'-'+cat2+' crossmatch')
print(f'Downloading {cat1}-{cat2} crossmatch')
download_xmatch(cat1, cat2, os.path.join('xmatch', filename))

# --- VIZIER DOWNLOAD ---
def download_vizier():
def download_vizier() -> None:
"""Download catalogue archive files from VizieR."""
with contextlib.suppress(FileExistsError):
os.mkdir('vizier')
Expand Down
22 changes: 11 additions & 11 deletions make_stardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import astropy.io.ascii as io_ascii
import astropy.units as u

from astropy.table import MaskedColumn, join, unique, vstack
from astropy.table import MaskedColumn, Table, join, unique, vstack
from astropy.units import UnitsWarning

from parse_hip import process_hip
Expand Down Expand Up @@ -60,7 +60,7 @@

CEL_SPECS = parse_spectrum_vec(['OBAFGKM'[i//10]+str(i%10) for i in range(3, 70)])

def load_ubvri():
def load_ubvri() -> Table:
"""Load UBVRI Teff calibration from VizieR archive."""
print('Loading UBVRI calibration')
with tarfile.open(os.path.join('vizier', 'ubvriteff.tar.gz'), 'r:gz') as tf:
Expand All @@ -76,15 +76,15 @@ def load_ubvri():
warnings.simplefilter('ignore', UnitsWarning)
return reader.read(f)

def parse_spectra(data):
def parse_spectra(data: Table) -> Table:
"""Parse the spectral types into the celestia.Sci format."""
print('Parsing spectral types')
data['SpType'] = data['SpType'].filled('')
sptypes = unique(data['SpType',])
sptypes['CelSpec'] = parse_spectrum_vec(sptypes['SpType'])
return join(data, sptypes)

def estimate_magnitudes(data):
def estimate_magnitudes(data: Table) -> None:
"""Estimates magnitudes and color indices from G magnitude and BP-RP.
Formula used is from Evans et al. (2018) "Gaia Data Release 2: Photometric
Expand Down Expand Up @@ -146,7 +146,7 @@ def estimate_magnitudes(data):
data.remove_columns(['Bmag', 'e_Bmag', 'e_Vmag', 'Jmag', 'e_Jmag', 'Hmag', 'e_Hmag',
'Kmag', 'e_Kmag'])

def estimate_temperatures(data):
def estimate_temperatures(data: Table) -> None:
"""Estimate the temperature of stars."""
ubvri_data = load_ubvri()
print('Estimating temperatures from color indices')
Expand Down Expand Up @@ -177,7 +177,7 @@ def estimate_temperatures(data):
data['teff_est'] = teffs / weights
data['teff_est'].unit = u.K

def estimate_spectra(data):
def estimate_spectra(data: Table) -> Table:
"""Estimate the spectral type of stars."""
no_teff = data[data['teff_val'].mask]
# temporarily disable no-member error in pylint, as it cannot see the reduce method
Expand All @@ -199,7 +199,7 @@ def estimate_spectra(data):
data['CelSpec'] = CEL_SPECS[np.digitize(data['teff_val'], TEFF_BINS)]
return data

def merge_all():
def merge_all() -> Table:
"""Merges the HIP and TYC data."""
hip_data = process_hip()
tyc_data = join(process_tyc(),
Expand All @@ -218,7 +218,7 @@ def merge_all():
[0, COS_OBLIQUITY, SIN_OBLIQUITY],
[0, -SIN_OBLIQUITY, COS_OBLIQUITY]])

def process_data():
def process_data() -> Table:
"""Processes the missing data values."""
data = merge_all()
data = data[np.logical_not(data['dist_use'].mask)]
Expand Down Expand Up @@ -263,7 +263,7 @@ def process_data():

return data

def write_starsdat(data, outfile):
def write_starsdat(data: Table, outfile: str) -> None:
"""Write the stars.dat file."""
print('Writing stars.dat')
with open(outfile, 'wb') as f:
Expand All @@ -274,7 +274,7 @@ def write_starsdat(data, outfile):
data['Vmag_abs'], data['CelSpec']):
f.write(fmt.pack(hip, x, y, z, int(round(vmag_abs*256)), celspec))

def write_xindex(data, field, outfile):
def write_xindex(data: Table, field: str, outfile: str) -> None:
"""Write a cross-index file."""
print('Writing '+field+' cross-index')
print(' Extracting cross-index data')
Expand All @@ -288,7 +288,7 @@ def write_xindex(data, field, outfile):
for hip, cat in zip(data['HIP'], data[field]):
f.write(fmt.pack(cat, hip))

def make_stardb():
def make_stardb() -> None:
"""Make the Celestia star database files."""
data = process_data()

Expand Down
16 changes: 8 additions & 8 deletions parse_hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
import astropy.units as u

from astropy.coordinates import ICRS, SkyCoord
from astropy.table import join, unique
from astropy.table import Table, join, unique
from astropy.time import Time
from astropy.units import UnitsWarning

def load_gaia_hip():
def load_gaia_hip() -> Table:
"""Load the Gaia DR2 HIP sources."""
print('Loading Gaia DR2 sources for HIP')
col_names = ['source_id', 'hip_id', 'ra', 'dec', 'phot_g_mean_mag', 'bp_rp',
Expand All @@ -50,7 +50,7 @@ def load_gaia_hip():

return gaia

def load_xhip():
def load_xhip() -> Table:
"""Load the XHIP catalogue from the VizieR archive."""
print('Loading XHIP')
with tarfile.open(os.path.join('vizier', 'xhip.tar.gz'), 'r:gz') as tf:
Expand Down Expand Up @@ -107,7 +107,7 @@ def load_xhip():

return join(hip_data, biblio_data, join_type='left', keys='HIP')

def load_sao():
def load_sao() -> Table:
"""Load the SAO-HIP cross match."""
print('Loading SAO-HIP cross match')
data = io_ascii.read(os.path.join('xmatch', 'sao_hip_xmatch.csv'),
Expand All @@ -123,7 +123,7 @@ def load_sao():
data.add_index('HIP')
return data

def compute_distances(hip_data, length_kpc=1.35):
def compute_distances(hip_data: Table, length_kpc: float=1.35) -> None:
"""Compute the distance using an exponentially-decreasing prior.
The method is described in:
Expand Down Expand Up @@ -162,7 +162,7 @@ def compute_distances(hip_data, length_kpc=1.35):
HIP_TIME = Time('J1991.25')
GAIA_TIME = Time('J2015.5')

def update_coordinates(hip_data):
def update_coordinates(hip_data: Table) -> None:
"""Update the coordinates from J1991.25 to J2015.5 to match Gaia."""
print('Updating coordinates to J2015.5')
coords = SkyCoord(frame=ICRS,
Expand All @@ -179,15 +179,15 @@ def update_coordinates(hip_data):
hip_data['dec'] = coords.dec / u.deg
hip_data['dec'].unit = u.deg

def process_xhip():
def process_xhip() -> Table:
"""Processes the XHIP data."""
xhip = load_xhip()
compute_distances(xhip)
update_coordinates(xhip)
xhip.remove_columns(['RAdeg', 'DEdeg', 'Plx', 'e_Plx', 'pmRA', 'pmDE', 'RV', 'Dist', 'e_Dist'])
return xhip

def process_hip():
def process_hip() -> Table:
"""Process the Gaia and HIP data."""
data = join(load_gaia_hip(),
process_xhip(),
Expand Down
27 changes: 16 additions & 11 deletions parse_tyc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,31 @@
import re
import tarfile

from tarfile import TarFile, TarInfo
from typing import Tuple

import numpy as np
import astropy.io.ascii as io_ascii
import astropy.units as u

from astropy.table import MaskedColumn, Table, join, unique, vstack

def parse_tyc_string(data, src_column, dest_column='TYC'):
def parse_tyc_string(data: Table, src_column: str, dest_column: str='TYC') -> None:
"""Parse a TYC string into a synthetic HIP identifier."""
tycs = np.array(np.char.split(data[src_column], '-').tolist()).astype(np.int64)
data[dest_column] = tycs[:, 0] + tycs[:, 1]*10000 + tycs[:, 2]*1000000000
data.remove_column(src_column)

def parse_tyc_cols(data, src_columns=('TYC1', 'TYC2', 'TYC3'), dest_column='TYC'):
def parse_tyc_cols(data: Table,
src_columns: Tuple[str, str, str]=('TYC1', 'TYC2', 'TYC3'),
dest_column: str='TYC') -> None:
"""Convert TYC identifier components into a synthetic HIP identifier."""
data[dest_column] = (data[src_columns[0]]
+ data[src_columns[1]]*10000
+ data[src_columns[2]]*1000000000)
data.remove_columns(src_columns)

def load_gaia_tyc():
def load_gaia_tyc() -> Table:
"""Load the Gaia DR2 TYC2 sources."""
print('Loading Gaia DR2 sources for TYC2')
col_names = ['source_id', 'tyc2_id', 'ra', 'dec', 'phot_g_mean_mag', 'bp_rp',
Expand All @@ -64,7 +69,7 @@ def load_gaia_tyc():

return gaia

def load_tyc_spec():
def load_tyc_spec() -> Table:
"""Load the TYC2 spectral type catalogue."""
print('Loading TYC2 spectral types')
with tarfile.open(os.path.join('vizier', 'tyc2spec.tar.gz')) as tf:
Expand All @@ -81,9 +86,9 @@ def load_tyc_spec():
data.add_index('TYC')
return data

def load_ascc():
def load_ascc() -> Table:
"""Load ASCC from VizieR archive."""
def load_section(tf, info):
def load_section(tf: TarFile, info: TarInfo) -> Table:
with tf.extractfile('./ReadMe') as readme:
col_names = ['Bmag', 'Vmag', 'e_Bmag', 'e_Vmag', 'd3', 'TYC1', 'TYC2', 'TYC3', 'HD',
'Jmag', 'e_Jmag', 'Hmag', 'e_Hmag', 'Kmag', 'e_Kmag']
Expand All @@ -107,7 +112,7 @@ def load_section(tf, info):

return section

def is_data(info):
def is_data(info: TarInfo) -> bool:
sections = os.path.split(info.name)
return (len(sections) == 2 and
sections[0] == '.' and
Expand All @@ -129,7 +134,7 @@ def is_data(info):
data.add_index('TYC')
return data

def load_tyc_teff():
def load_tyc_teff() -> Table:
"""Load the Tycho-2 effective temperatures."""
print('Loading TYC2 effective temperatures')
with tarfile.open(os.path.join('vizier', 'tyc2teff.tar.gz'), 'r:gz') as tf:
Expand Down Expand Up @@ -195,7 +200,7 @@ def load_tyc_teff():
data.add_index('TYC')
return unique(data, keys=['TYC'])

def load_sao():
def load_sao() -> Table:
"""Load the SAO-TYC2 cross match."""
print('Loading SAO-TYC2 cross match')
data = io_ascii.read(os.path.join('xmatch', 'sao_tyc2_xmatch.csv'),
Expand All @@ -213,7 +218,7 @@ def load_sao():
data.add_index('TYC')
return data

def merge_tables():
def merge_tables() -> Table:
"""Merges the tables."""
data = join(load_gaia_tyc(), load_tyc_spec(), keys=['TYC'], join_type='left')
data = join(data, load_ascc(), keys=['TYC'], join_type='left', metadata_conflicts='silent')
Expand All @@ -232,7 +237,7 @@ def merge_tables():
data = join(data, load_sao(), keys=['TYC'], join_type='left')
return data

def process_tyc():
def process_tyc() -> Table:
"""Processes the TYC data."""
data = merge_tables()
data.rename_column('r_est', 'dist_use')
Expand Down
2 changes: 1 addition & 1 deletion spparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def visit_wdstar(self, node, children):
VISITOR = SpecVisitor()
MULTISEPARATOR = re.compile(r'\+\ *(?:\.{2,}|(?:\(?(?:sd|d|g|c|k|h|m|g|He)?[OBAFGKM]|W[DNOCR]|wd))')

def parse_spectrum(sptype):
def parse_spectrum(sptype: str) -> int:
"""Parse a spectral type string into a Celestia spectral type."""

# resolve ambiguity in grammar: B 0-Ia could be interpreted as (B 0-) Ia or B (0-Ia)
Expand Down

0 comments on commit f29c406

Please sign in to comment.