Skip to content

Commit

Permalink
Merge pull request #82 from sblunt/driver_read_input_fix
Browse files Browse the repository at this point in the history
Driver & read_input fix
  • Loading branch information
Sarah Blunt authored Dec 4, 2018
2 parents ba4396c + 7bac1a6 commit 68567f3
Show file tree
Hide file tree
Showing 11 changed files with 285 additions and 111 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ before_install:

install:
- conda install --yes python=$TRAVIS_PYTHON_VERSION pip numpy scipy cython matplotlib
- pip install pytest pytest-cov coverage coveralls corner
- pip install pytest pytest-cov coverage coveralls corner deprecation
- pip install . --upgrade
- python setup.py build_ext -i

Expand Down
8 changes: 4 additions & 4 deletions docs/formatting_inputs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Formatting Input
++++++++++++++++

``orbitize.read_input.read_formatted_file()`` handles reading
in data. Check out the documentation for this method
`here <read_input.html>`_ for details about units, acceptable formats,
and more.
``orbitize.read_input.read_file()`` handles reading
in data. Check out the documentation for this method
`here <read_input.html>`_ for details about units, acceptable formats,
and more.
28 changes: 19 additions & 9 deletions orbitize/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import orbitize.sampler

"""
This module reads input and constructs ``orbitize`` objects
This module reads input and constructs ``orbitize`` objects
in a standardized way.
"""

Expand All @@ -12,32 +12,42 @@ class Driver(object):
Runs through ``orbitize`` methods in a standardized way.
Args:
filename (str): relative path to data file. See ``orbitize.read_input``
sampler_str (str): algorithm to use for orbit computation. "MCMC" for
input_data: Either a relative path to data file or astropy.table.Table object
in the orbitize format. See ``orbitize.read_input``
sampler_str (str): algorithm to use for orbit computation. "MCMC" for
Markov Chain Monte Carlo, "OFTI" for Orbits for the Impatient
num_secondary_bodies (int): number of secondary bodies in the system.
num_secondary_bodies (int): number of secondary bodies in the system.
Should be at least 1.
system_mass (float): mean total mass of the system [M_sol]
plx (float): mean parallax of the system [mas]
mass_err (float, optional): uncertainty on ``system_mass`` [M_sol]
plx_err (float, optional): uncertainty on ``plx`` [mas]
lnlike (str, optional): name of function in ``orbitize.lnlike`` that will
be used to compute likelihood. (default="chi2_lnlike")
mcmc_kwargs (dict, optional): ``num_temps``, ``num_walkers``, and ``num_threads``
mcmc_kwargs (dict, optional): ``num_temps``, ``num_walkers``, and ``num_threads``
kwargs for ``orbitize.sampler.MCMC``
Written: Sarah Blunt, 2018
"""
def __init__(self, filename, sampler_str,
num_secondary_bodies, system_mass, plx,
def __init__(self, input_data, sampler_str,
num_secondary_bodies, system_mass, plx,
mass_err=0, plx_err=0, lnlike='chi2_lnlike', mcmc_kwargs=None):

# Read in data
data_table = orbitize.read_input.read_formatted_file(filename)
# Try to interpret input as a filename first
try:
data_table = orbitize.read_input.read_file(input_data)
except:
try:
# Check if input might be an orbitize style astropy.table.Table
if 'quant_type' in input_data.columns:
data_table = input_data.copy()
except:
raise Exception('Invalid value of input_data for Driver')

# Initialize System object which stores data & sets priors
self.system = orbitize.system.System(
num_secondary_bodies, data_table, system_mass,
num_secondary_bodies, data_table, system_mass,
plx, mass_err=mass_err, plx_err=plx_err
)

Expand Down
190 changes: 110 additions & 80 deletions orbitize/read_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,34 @@
Module to read user input from files and create standardized input for orbitize
"""

import deprecation
import numpy as np
import orbitize
from astropy.table import Table
from astropy.io.ascii import read, write

def read_formatted_file(filename):
""" Reads data from any file
def read_file(filename):
""" Reads data from any file for use in orbitize
readable by ``astropy.io.ascii.read()``, including csv format.
See the `astropy docs <http://docs.astropy.org/en/stable/io/ascii/index.html#id1>`_.
Here is an example of an orbitize-readable .csv input file::
There are two ways to provide input data to orbitize.
The first way is to provide astrometric measurements, shown with the following example.
Example of an orbitize-readable .csv input file::
epoch,object,raoff,raoff_err,decoff,decoff_err,sep,sep_err,pa,pa_err,rv,rv_err
1234,1,0.010,0.005,0.50,0.05,,,,,,
1235,1,,,,,1.0,0.005,89.0,0.1,,
1236,1,,,,,1.0,0.005,89.3,0.3,,
1237,0,,,,,,,,,10,0.1
Each row must have ``epoch`` (in MJD=JD-2400000.5) and ``object``.
Objects are numbered with integers, where the primary/central object is ``0``.
If you have, for example, one RV measurement of a star and three astrometric
measurements of an orbiting planet, you should put ``0`` in the ``object`` column
for the RV point, and ``1`` in the columns for the astrometric measurements.
Each row must have ``epoch`` (in MJD=JD-2400000.5) and ``object``.
Objects are numbered with integers, where the primary/central object is ``0``.
If you have, for example, one RV measurement of a star and three astrometric
measurements of an orbiting planet, you should put ``0`` in the ``object`` column
for the RV point, and ``1`` in the columns for the astrometric measurements.
Each line must also have at least one of the following sets of valid measurements:
Expand All @@ -34,15 +40,23 @@ def read_formatted_file(filename):
.. Note:: Columns with no data can be omitted (e.g. if only separation and PA
are given, the raoff, deoff, and rv columns can be excluded).
If more than one valid set is given (e.g. RV measurement and astrometric measurement
taken at the same epoch), ``read_formatted_file()`` will generate a separate output
row for each valid set.
If more than one valid set is given (e.g. RV measurement and astrometric measurement
taken at the same epoch), ``read_file()`` will generate a separate output row for
each valid set.
.. Warning:: For now, ``orbitize`` only accepts astrometric measurements for one
secondary body. In a future release, it will also handle astrometric measurements for
multiple secondaries, RV measurements of the primary and secondar(ies), and astrometric
measurements of the primary. Stay tuned!
Alternatively, you can also supply a data file with the columns already corresponding to
the orbitize format (see the example in description of what this method returns). This may
be useful if you are wanting to use the output of the `write_orbitize_input` method.
.. Note:: When providing data with columns in the orbitize format, there should be no
empty cells. As in the example below, when quant2 is not applicable, the cell should
contain nan.
Args:
filename (str): Input file name
Expand All @@ -65,7 +79,6 @@ def read_formatted_file(filename):
Written: Henry Ngo, 2018
"""

# initialize output table
output_table = Table(names=('epoch','object','quant1','quant1_err','quant2','quant2_err','quant_type'),
dtype=(float,int,float,float,float,float,'S5'))
Expand All @@ -74,9 +87,12 @@ def read_formatted_file(filename):
input_table = read(filename)
num_measurements = len(input_table)

# Decide if input was given in the orbitize style
orbitize_style = 'quant_type' in input_table.columns

# validate input
# if input_table is Masked, then figure out which entries are masked
# otherwise, just check that we have the required columns
# otherwise, just check that we have the required columns based on orbitize_style flag
if input_table.masked:
if 'epoch' in input_table.columns:
have_epoch = ~input_table['epoch'].mask
Expand All @@ -90,56 +106,59 @@ def read_formatted_file(filename):
raise Exception("Invalid input format: missing some object entries")
else:
raise Exception("Input table MUST have object id!")
if 'raoff' in input_table.columns:
have_ra = ~input_table['raoff'].mask
else:
have_ra = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'decoff' in input_table.columns:
have_dec = ~input_table['decoff'].mask
else:
have_dec = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'sep' in input_table.columns:
have_sep = ~input_table['sep'].mask
else:
have_sep = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'pa' in input_table.columns:
have_pa = ~input_table['pa'].mask
else:
have_pa = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'rv' in input_table.columns:
have_rv = ~input_table['rv'].mask
else:
have_rv = np.zeros(num_measurements, dtype=bool) # zeros are False
if orbitize_style: # proper orbitize style should NEVER have masked entries (nan required)
raise Exception("Input table in orbitize style may NOT have empty cells")
else: # Check for these things when not orbitize style
if 'raoff' in input_table.columns:
have_ra = ~input_table['raoff'].mask
else:
have_ra = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'decoff' in input_table.columns:
have_dec = ~input_table['decoff'].mask
else:
have_dec = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'sep' in input_table.columns:
have_sep = ~input_table['sep'].mask
else:
have_sep = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'pa' in input_table.columns:
have_pa = ~input_table['pa'].mask
else:
have_pa = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'rv' in input_table.columns:
have_rv = ~input_table['rv'].mask
else:
have_rv = np.zeros(num_measurements, dtype=bool) # zeros are False
else: # no masked entries, just check for required columns
if 'epoch' not in input_table.columns:
raise Exception("Input table MUST have epoch!")
if 'object' not in input_table.columns:
raise Exception("Input table MUST have object id!")
if 'raoff' in input_table.columns:
have_ra = np.ones(num_measurements, dtype=bool) # ones are False
else:
have_ra = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'decoff' in input_table.columns:
have_dec = np.ones(num_measurements, dtype=bool) # ones are False
else:
have_dec = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'sep' in input_table.columns:
have_sep = np.ones(num_measurements, dtype=bool) # ones are False
else:
have_sep = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'pa' in input_table.columns:
have_pa = np.ones(num_measurements, dtype=bool) # ones are False
else:
have_pa = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'rv' in input_table.columns:
have_rv = np.ones(num_measurements, dtype=bool) # ones are False
else:
have_rv = np.zeros(num_measurements, dtype=bool) # zeros are False
if not orbitize_style: # Set these flags only when not already in orbitize style
if 'raoff' in input_table.columns:
have_ra = np.ones(num_measurements, dtype=bool) # ones are False
else:
have_ra = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'decoff' in input_table.columns:
have_dec = np.ones(num_measurements, dtype=bool) # ones are False
else:
have_dec = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'sep' in input_table.columns:
have_sep = np.ones(num_measurements, dtype=bool) # ones are False
else:
have_sep = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'pa' in input_table.columns:
have_pa = np.ones(num_measurements, dtype=bool) # ones are False
else:
have_pa = np.zeros(num_measurements, dtype=bool) # zeros are False
if 'rv' in input_table.columns:
have_rv = np.ones(num_measurements, dtype=bool) # ones are False
else:
have_rv = np.zeros(num_measurements, dtype=bool) # zeros are False

# loop through each row and format table
index=0
for row in input_table:

# check epoch format and put in MJD
if row['epoch'] > 2400000.5: # assume this is in JD
MJD = row['epoch'] - 2400000.5
Expand All @@ -151,24 +170,46 @@ def read_formatted_file(filename):
raise Exception("Invalid object ID. Object IDs must be integers.")

# determine input quantity type (RA/DEC, SEP/PA, or RV)
if have_ra[index] and have_dec[index]:
output_table.add_row([MJD, row['object'], row['raoff'], row['raoff_err'], row['decoff'], row['decoff_err'], "radec"])
elif have_sep[index] and have_pa[index]:
output_table.add_row([MJD, row['object'], row['sep'], row['sep_err'], row['pa'], row['pa_err'], "seppa"])
if have_rv[index]:
output_table.add_row([MJD, row['object'], row['rv'], row['rv_err'], None, None, "rv"])
if orbitize_style:
if row['quant_type'] == 'rv': # special format for rv rows
output_table.add_row([MJD, row['object'], row['quant1'], row['quant1_err'], None, None, row['quant_type']])
elif row['quant_type'] == 'radec' or row['quant_type'] == 'seppa': # other allowed formats
output_table.add_row([MJD, row['object'], row['quant1'], row['quant1_err'], row['quant2'], row['quant2_err'], row['quant_type']])
else: # catch wrong formats
raise Exception("Invalid 'quant_type'. Valid values are 'radec', 'seppa' or 'rv'")
else: # When not in orbitize style
if have_ra[index] and have_dec[index]:
output_table.add_row([MJD, row['object'], row['raoff'], row['raoff_err'], row['decoff'], row['decoff_err'], "radec"])
elif have_sep[index] and have_pa[index]:
output_table.add_row([MJD, row['object'], row['sep'], row['sep_err'], row['pa'], row['pa_err'], "seppa"])
if have_rv[index]:
output_table.add_row([MJD, row['object'], row['rv'], row['rv_err'], None, None, "rv"])

index=index+1

return output_table

@deprecation.deprecated(deprecated_in="1.0.2", removed_in="2.0",
current_version=orbitize.__version__,
details="Use read_file() instead. v1.0.2 replaces read_formatted_file and read_orbitize_input with read_file(). For now, this will be a wrapper for read_file and will be removed in the v2.0 release.")
def read_formatted_file(filename):
"""
Version 1.0.2 replaces this function with `read_file`.
Currently exists as a wrapper for `read_file` and will be removed in v2.0
Written: Henry Ngo, 2018
"""

return read_file(filename)

def write_orbitize_input(table,output_filename,file_type='csv'):
""" Writes orbitize-readable input as an ASCII file
Args:
table (astropy.Table): Table containing orbitize-readable input for given
object, as generated by the read functions in this module.
output_filename (str): csv file to write to
file_type (str): Any valid write format for astropy.io.ascii. See the
file_type (str): Any valid write format for astropy.io.ascii. See the
`astropy docs <http://docs.astropy.org/en/stable/io/ascii/index.html#id1>`_.
Defaults to csv.
Expand All @@ -185,25 +226,14 @@ def write_orbitize_input(table,output_filename,file_type='csv'):
# write file
write(table,output=output_filename,format=file_type)

@deprecation.deprecated(deprecated_in="1.0.2", removed_in="2.0",
current_version=orbitize.__version__,
details="Use read_file() instead. v1.0.2 replaces read_orbitize_input and read_formatted_file with read_file(). For now, this will be a wrapper for read_file and will be removed in the v2.0 release.")
def read_orbitize_input(filename):
""" Reads orbitize-readable input from a correctly formatted ASCII file
Args:
filename (str): Name of file to read. It should have columns
indicated in the table below.
Returns:
astropy.Table: Table containing orbitize-readable input for given
object. Columns returned are::
epoch, object, quant1, quant1_err, quant2, quant2_err, quant_type
where ``quant_type`` is one of "radec", "seppa", or "rv".
If ``quant_type`` is "radec" or "seppa", the units of quant are mas and degrees,
if ``quant_type`` is "rv", the units of quant are km/s
"""
Version 1.0.2 replaces this function with `read_file`.
Currently exists as a wrapper for `read_file` and will be removed in v2.0
(written) Henry Ngo, 2018
Written: Henry Ngo, 2018
"""
output_table=read(filename)
return output_table
return read_file(filename)
8 changes: 3 additions & 5 deletions orbitize/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ class System(object):
Args:
num_secondary_bodies (int): number of secondary bodies in the system.
Should be at least 1.
data_table (astropy.table.Table): output from either
``orbitize.read_input.read_formatted_file()`` or
``orbitize.read_input.read_orbitize_input()``
data_table (astropy.table.Table): output from ``orbitize.read_input.read_file()``
system_mass (float): mean total mass of the system, in M_sol
plx (float): mean parallax of the system, in mas
mass_err (float, optional): uncertainty on ``system_mass``, in M_sol
Expand All @@ -33,7 +31,7 @@ class System(object):
argument of periastron 1, position angle of nodes 1,
epoch of periastron passage 1,
[semimajor axis 2, eccentricity 2, etc.],
[total mass, parallax]
[parallax, total_mass]
where 1 corresponds to the first orbiting object, 2 corresponds
to the second, etc.
Expand Down Expand Up @@ -131,7 +129,7 @@ def __init__(self, num_secondary_bodies, data_table, system_mass,
self.sys_priors.append(priors.GaussianPrior(system_mass, mass_err))
else:
self.sys_priors.append(system_mass)

#add labels dictionary for parameter indexing
self.param_idx = dict(zip(self.labels, np.arange(len(self.labels))))

Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ emcee
ptemcee
matplotlib
corner
h5py
h5py
deprecation
Loading

0 comments on commit 68567f3

Please sign in to comment.