diff --git a/.travis.yml b/.travis.yml index 98113749..4b20b8b2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -25,7 +25,7 @@ before_install: install: - conda install --yes python=$TRAVIS_PYTHON_VERSION pip numpy scipy cython matplotlib - - pip install pytest pytest-cov coverage coveralls corner + - pip install pytest pytest-cov coverage coveralls corner deprecation - pip install . --upgrade - python setup.py build_ext -i diff --git a/docs/formatting_inputs.rst b/docs/formatting_inputs.rst index 95f6c6b6..da8923b4 100644 --- a/docs/formatting_inputs.rst +++ b/docs/formatting_inputs.rst @@ -3,7 +3,7 @@ Formatting Input ++++++++++++++++ -``orbitize.read_input.read_formatted_file()`` handles reading -in data. Check out the documentation for this method -`here `_ for details about units, acceptable formats, -and more. \ No newline at end of file +``orbitize.read_input.read_file()`` handles reading +in data. Check out the documentation for this method +`here `_ for details about units, acceptable formats, +and more. diff --git a/orbitize/driver.py b/orbitize/driver.py index 0e4c8097..d64e5156 100644 --- a/orbitize/driver.py +++ b/orbitize/driver.py @@ -3,7 +3,7 @@ import orbitize.sampler """ -This module reads input and constructs ``orbitize`` objects +This module reads input and constructs ``orbitize`` objects in a standardized way. """ @@ -12,10 +12,11 @@ class Driver(object): Runs through ``orbitize`` methods in a standardized way. Args: - filename (str): relative path to data file. See ``orbitize.read_input`` - sampler_str (str): algorithm to use for orbit computation. "MCMC" for + input_data: Either a relative path to data file or astropy.table.Table object + in the orbitize format. See ``orbitize.read_input`` + sampler_str (str): algorithm to use for orbit computation. "MCMC" for Markov Chain Monte Carlo, "OFTI" for Orbits for the Impatient - num_secondary_bodies (int): number of secondary bodies in the system. + num_secondary_bodies (int): number of secondary bodies in the system. Should be at least 1. system_mass (float): mean total mass of the system [M_sol] plx (float): mean parallax of the system [mas] @@ -23,21 +24,30 @@ class Driver(object): plx_err (float, optional): uncertainty on ``plx`` [mas] lnlike (str, optional): name of function in ``orbitize.lnlike`` that will be used to compute likelihood. (default="chi2_lnlike") - mcmc_kwargs (dict, optional): ``num_temps``, ``num_walkers``, and ``num_threads`` + mcmc_kwargs (dict, optional): ``num_temps``, ``num_walkers``, and ``num_threads`` kwargs for ``orbitize.sampler.MCMC`` Written: Sarah Blunt, 2018 """ - def __init__(self, filename, sampler_str, - num_secondary_bodies, system_mass, plx, + def __init__(self, input_data, sampler_str, + num_secondary_bodies, system_mass, plx, mass_err=0, plx_err=0, lnlike='chi2_lnlike', mcmc_kwargs=None): # Read in data - data_table = orbitize.read_input.read_formatted_file(filename) + # Try to interpret input as a filename first + try: + data_table = orbitize.read_input.read_file(input_data) + except: + try: + # Check if input might be an orbitize style astropy.table.Table + if 'quant_type' in input_data.columns: + data_table = input_data.copy() + except: + raise Exception('Invalid value of input_data for Driver') # Initialize System object which stores data & sets priors self.system = orbitize.system.System( - num_secondary_bodies, data_table, system_mass, + num_secondary_bodies, data_table, system_mass, plx, mass_err=mass_err, plx_err=plx_err ) diff --git a/orbitize/read_input.py b/orbitize/read_input.py index 4b06221f..8c339f6f 100644 --- a/orbitize/read_input.py +++ b/orbitize/read_input.py @@ -2,16 +2,22 @@ Module to read user input from files and create standardized input for orbitize """ +import deprecation import numpy as np +import orbitize from astropy.table import Table from astropy.io.ascii import read, write -def read_formatted_file(filename): - """ Reads data from any file +def read_file(filename): + """ Reads data from any file for use in orbitize readable by ``astropy.io.ascii.read()``, including csv format. See the `astropy docs `_. - Here is an example of an orbitize-readable .csv input file:: + There are two ways to provide input data to orbitize. + + The first way is to provide astrometric measurements, shown with the following example. + + Example of an orbitize-readable .csv input file:: epoch,object,raoff,raoff_err,decoff,decoff_err,sep,sep_err,pa,pa_err,rv,rv_err 1234,1,0.010,0.005,0.50,0.05,,,,,, @@ -19,11 +25,11 @@ def read_formatted_file(filename): 1236,1,,,,,1.0,0.005,89.3,0.3,, 1237,0,,,,,,,,,10,0.1 - Each row must have ``epoch`` (in MJD=JD-2400000.5) and ``object``. - Objects are numbered with integers, where the primary/central object is ``0``. - If you have, for example, one RV measurement of a star and three astrometric - measurements of an orbiting planet, you should put ``0`` in the ``object`` column - for the RV point, and ``1`` in the columns for the astrometric measurements. + Each row must have ``epoch`` (in MJD=JD-2400000.5) and ``object``. + Objects are numbered with integers, where the primary/central object is ``0``. + If you have, for example, one RV measurement of a star and three astrometric + measurements of an orbiting planet, you should put ``0`` in the ``object`` column + for the RV point, and ``1`` in the columns for the astrometric measurements. Each line must also have at least one of the following sets of valid measurements: @@ -34,15 +40,23 @@ def read_formatted_file(filename): .. Note:: Columns with no data can be omitted (e.g. if only separation and PA are given, the raoff, deoff, and rv columns can be excluded). - If more than one valid set is given (e.g. RV measurement and astrometric measurement - taken at the same epoch), ``read_formatted_file()`` will generate a separate output - row for each valid set. + If more than one valid set is given (e.g. RV measurement and astrometric measurement + taken at the same epoch), ``read_file()`` will generate a separate output row for + each valid set. .. Warning:: For now, ``orbitize`` only accepts astrometric measurements for one secondary body. In a future release, it will also handle astrometric measurements for multiple secondaries, RV measurements of the primary and secondar(ies), and astrometric measurements of the primary. Stay tuned! + Alternatively, you can also supply a data file with the columns already corresponding to + the orbitize format (see the example in description of what this method returns). This may + be useful if you are wanting to use the output of the `write_orbitize_input` method. + + .. Note:: When providing data with columns in the orbitize format, there should be no + empty cells. As in the example below, when quant2 is not applicable, the cell should + contain nan. + Args: filename (str): Input file name @@ -65,7 +79,6 @@ def read_formatted_file(filename): Written: Henry Ngo, 2018 """ - # initialize output table output_table = Table(names=('epoch','object','quant1','quant1_err','quant2','quant2_err','quant_type'), dtype=(float,int,float,float,float,float,'S5')) @@ -74,9 +87,12 @@ def read_formatted_file(filename): input_table = read(filename) num_measurements = len(input_table) + # Decide if input was given in the orbitize style + orbitize_style = 'quant_type' in input_table.columns + # validate input # if input_table is Masked, then figure out which entries are masked - # otherwise, just check that we have the required columns + # otherwise, just check that we have the required columns based on orbitize_style flag if input_table.masked: if 'epoch' in input_table.columns: have_epoch = ~input_table['epoch'].mask @@ -90,56 +106,59 @@ def read_formatted_file(filename): raise Exception("Invalid input format: missing some object entries") else: raise Exception("Input table MUST have object id!") - if 'raoff' in input_table.columns: - have_ra = ~input_table['raoff'].mask - else: - have_ra = np.zeros(num_measurements, dtype=bool) # zeros are False - if 'decoff' in input_table.columns: - have_dec = ~input_table['decoff'].mask - else: - have_dec = np.zeros(num_measurements, dtype=bool) # zeros are False - if 'sep' in input_table.columns: - have_sep = ~input_table['sep'].mask - else: - have_sep = np.zeros(num_measurements, dtype=bool) # zeros are False - if 'pa' in input_table.columns: - have_pa = ~input_table['pa'].mask - else: - have_pa = np.zeros(num_measurements, dtype=bool) # zeros are False - if 'rv' in input_table.columns: - have_rv = ~input_table['rv'].mask - else: - have_rv = np.zeros(num_measurements, dtype=bool) # zeros are False + if orbitize_style: # proper orbitize style should NEVER have masked entries (nan required) + raise Exception("Input table in orbitize style may NOT have empty cells") + else: # Check for these things when not orbitize style + if 'raoff' in input_table.columns: + have_ra = ~input_table['raoff'].mask + else: + have_ra = np.zeros(num_measurements, dtype=bool) # zeros are False + if 'decoff' in input_table.columns: + have_dec = ~input_table['decoff'].mask + else: + have_dec = np.zeros(num_measurements, dtype=bool) # zeros are False + if 'sep' in input_table.columns: + have_sep = ~input_table['sep'].mask + else: + have_sep = np.zeros(num_measurements, dtype=bool) # zeros are False + if 'pa' in input_table.columns: + have_pa = ~input_table['pa'].mask + else: + have_pa = np.zeros(num_measurements, dtype=bool) # zeros are False + if 'rv' in input_table.columns: + have_rv = ~input_table['rv'].mask + else: + have_rv = np.zeros(num_measurements, dtype=bool) # zeros are False else: # no masked entries, just check for required columns if 'epoch' not in input_table.columns: raise Exception("Input table MUST have epoch!") if 'object' not in input_table.columns: raise Exception("Input table MUST have object id!") - if 'raoff' in input_table.columns: - have_ra = np.ones(num_measurements, dtype=bool) # ones are False - else: - have_ra = np.zeros(num_measurements, dtype=bool) # zeros are False - if 'decoff' in input_table.columns: - have_dec = np.ones(num_measurements, dtype=bool) # ones are False - else: - have_dec = np.zeros(num_measurements, dtype=bool) # zeros are False - if 'sep' in input_table.columns: - have_sep = np.ones(num_measurements, dtype=bool) # ones are False - else: - have_sep = np.zeros(num_measurements, dtype=bool) # zeros are False - if 'pa' in input_table.columns: - have_pa = np.ones(num_measurements, dtype=bool) # ones are False - else: - have_pa = np.zeros(num_measurements, dtype=bool) # zeros are False - if 'rv' in input_table.columns: - have_rv = np.ones(num_measurements, dtype=bool) # ones are False - else: - have_rv = np.zeros(num_measurements, dtype=bool) # zeros are False + if not orbitize_style: # Set these flags only when not already in orbitize style + if 'raoff' in input_table.columns: + have_ra = np.ones(num_measurements, dtype=bool) # ones are False + else: + have_ra = np.zeros(num_measurements, dtype=bool) # zeros are False + if 'decoff' in input_table.columns: + have_dec = np.ones(num_measurements, dtype=bool) # ones are False + else: + have_dec = np.zeros(num_measurements, dtype=bool) # zeros are False + if 'sep' in input_table.columns: + have_sep = np.ones(num_measurements, dtype=bool) # ones are False + else: + have_sep = np.zeros(num_measurements, dtype=bool) # zeros are False + if 'pa' in input_table.columns: + have_pa = np.ones(num_measurements, dtype=bool) # ones are False + else: + have_pa = np.zeros(num_measurements, dtype=bool) # zeros are False + if 'rv' in input_table.columns: + have_rv = np.ones(num_measurements, dtype=bool) # ones are False + else: + have_rv = np.zeros(num_measurements, dtype=bool) # zeros are False # loop through each row and format table index=0 for row in input_table: - # check epoch format and put in MJD if row['epoch'] > 2400000.5: # assume this is in JD MJD = row['epoch'] - 2400000.5 @@ -151,16 +170,38 @@ def read_formatted_file(filename): raise Exception("Invalid object ID. Object IDs must be integers.") # determine input quantity type (RA/DEC, SEP/PA, or RV) - if have_ra[index] and have_dec[index]: - output_table.add_row([MJD, row['object'], row['raoff'], row['raoff_err'], row['decoff'], row['decoff_err'], "radec"]) - elif have_sep[index] and have_pa[index]: - output_table.add_row([MJD, row['object'], row['sep'], row['sep_err'], row['pa'], row['pa_err'], "seppa"]) - if have_rv[index]: - output_table.add_row([MJD, row['object'], row['rv'], row['rv_err'], None, None, "rv"]) + if orbitize_style: + if row['quant_type'] == 'rv': # special format for rv rows + output_table.add_row([MJD, row['object'], row['quant1'], row['quant1_err'], None, None, row['quant_type']]) + elif row['quant_type'] == 'radec' or row['quant_type'] == 'seppa': # other allowed formats + output_table.add_row([MJD, row['object'], row['quant1'], row['quant1_err'], row['quant2'], row['quant2_err'], row['quant_type']]) + else: # catch wrong formats + raise Exception("Invalid 'quant_type'. Valid values are 'radec', 'seppa' or 'rv'") + else: # When not in orbitize style + if have_ra[index] and have_dec[index]: + output_table.add_row([MJD, row['object'], row['raoff'], row['raoff_err'], row['decoff'], row['decoff_err'], "radec"]) + elif have_sep[index] and have_pa[index]: + output_table.add_row([MJD, row['object'], row['sep'], row['sep_err'], row['pa'], row['pa_err'], "seppa"]) + if have_rv[index]: + output_table.add_row([MJD, row['object'], row['rv'], row['rv_err'], None, None, "rv"]) + index=index+1 return output_table +@deprecation.deprecated(deprecated_in="1.0.2", removed_in="2.0", + current_version=orbitize.__version__, + details="Use read_file() instead. v1.0.2 replaces read_formatted_file and read_orbitize_input with read_file(). For now, this will be a wrapper for read_file and will be removed in the v2.0 release.") +def read_formatted_file(filename): + """ + Version 1.0.2 replaces this function with `read_file`. + Currently exists as a wrapper for `read_file` and will be removed in v2.0 + + Written: Henry Ngo, 2018 + """ + + return read_file(filename) + def write_orbitize_input(table,output_filename,file_type='csv'): """ Writes orbitize-readable input as an ASCII file @@ -168,7 +209,7 @@ def write_orbitize_input(table,output_filename,file_type='csv'): table (astropy.Table): Table containing orbitize-readable input for given object, as generated by the read functions in this module. output_filename (str): csv file to write to - file_type (str): Any valid write format for astropy.io.ascii. See the + file_type (str): Any valid write format for astropy.io.ascii. See the `astropy docs `_. Defaults to csv. @@ -185,25 +226,14 @@ def write_orbitize_input(table,output_filename,file_type='csv'): # write file write(table,output=output_filename,format=file_type) +@deprecation.deprecated(deprecated_in="1.0.2", removed_in="2.0", + current_version=orbitize.__version__, + details="Use read_file() instead. v1.0.2 replaces read_orbitize_input and read_formatted_file with read_file(). For now, this will be a wrapper for read_file and will be removed in the v2.0 release.") def read_orbitize_input(filename): - """ Reads orbitize-readable input from a correctly formatted ASCII file - - Args: - filename (str): Name of file to read. It should have columns - indicated in the table below. - - Returns: - astropy.Table: Table containing orbitize-readable input for given - object. Columns returned are:: - - epoch, object, quant1, quant1_err, quant2, quant2_err, quant_type - - where ``quant_type`` is one of "radec", "seppa", or "rv". - - If ``quant_type`` is "radec" or "seppa", the units of quant are mas and degrees, - if ``quant_type`` is "rv", the units of quant are km/s + """ + Version 1.0.2 replaces this function with `read_file`. + Currently exists as a wrapper for `read_file` and will be removed in v2.0 - (written) Henry Ngo, 2018 + Written: Henry Ngo, 2018 """ - output_table=read(filename) - return output_table + return read_file(filename) diff --git a/orbitize/system.py b/orbitize/system.py index b0b94ab1..455a2ddf 100644 --- a/orbitize/system.py +++ b/orbitize/system.py @@ -10,9 +10,7 @@ class System(object): Args: num_secondary_bodies (int): number of secondary bodies in the system. Should be at least 1. - data_table (astropy.table.Table): output from either - ``orbitize.read_input.read_formatted_file()`` or - ``orbitize.read_input.read_orbitize_input()`` + data_table (astropy.table.Table): output from ``orbitize.read_input.read_file()`` system_mass (float): mean total mass of the system, in M_sol plx (float): mean parallax of the system, in mas mass_err (float, optional): uncertainty on ``system_mass``, in M_sol @@ -33,7 +31,7 @@ class System(object): argument of periastron 1, position angle of nodes 1, epoch of periastron passage 1, [semimajor axis 2, eccentricity 2, etc.], - [total mass, parallax] + [parallax, total_mass] where 1 corresponds to the first orbiting object, 2 corresponds to the second, etc. @@ -131,7 +129,7 @@ def __init__(self, num_secondary_bodies, data_table, system_mass, self.sys_priors.append(priors.GaussianPrior(system_mass, mass_err)) else: self.sys_priors.append(system_mass) - + #add labels dictionary for parameter indexing self.param_idx = dict(zip(self.labels, np.arange(len(self.labels)))) diff --git a/requirements.txt b/requirements.txt index a323fdbd..30ece468 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ emcee ptemcee matplotlib corner -h5py \ No newline at end of file +h5py +deprecation diff --git a/tests/test_api.py b/tests/test_api.py index d5896417..5816bd9c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -13,7 +13,7 @@ def test_compute_model(): """ testdir = os.path.dirname(os.path.abspath(__file__)) input_file = os.path.join(testdir, 'test_val.csv') - data_table = read_input.read_formatted_file(input_file) + data_table = read_input.read_file(input_file) data_table['object'] = 1 testSystem_parsing = system.System( 1, data_table, 10., 10. @@ -41,7 +41,7 @@ def test_systeminit(): """ testdir = os.path.dirname(os.path.abspath(__file__)) input_file = os.path.join(testdir, 'test_val.csv') - data_table = read_input.read_formatted_file(input_file) + data_table = read_input.read_file(input_file) # Manually set 'object' column of data table data_table['object'] = 1 diff --git a/tests/test_driver.py b/tests/test_driver.py new file mode 100644 index 00000000..2cb0ba8e --- /dev/null +++ b/tests/test_driver.py @@ -0,0 +1,92 @@ +""" +Test the different Driver class creation options +""" + +import pytest +import numpy as np +from orbitize import driver +from orbitize.read_input import read_file +import os + +def _compare_table(input_table): + """ + Tests input table to expected values, which are: + epoch object quant1 quant1_err quant2 quant2_err quant_type + float64 int float64 float64 float64 float64 str5 + ------- ------ ------- ---------- ------- ---------- ---------- + 1234.0 1 0.01 0.005 0.5 0.05 radec + 1235.0 1 1.0 0.005 89.0 0.1 seppa + 1236.0 1 1.0 0.005 89.3 0.3 seppa + 1237.0 0 10.0 0.1 nan nan rv + """ + rows_expected = 4 + epoch_expected = [1234, 1235, 1236, 1237] + object_expected = [1,1,1,0] + quant1_expected = [0.01, 1.0, 1.0, 10.0] + quant1_err_expected = [0.005, 0.005, 0.005, 0.1] + quant2_expected = [0.5, 89.0, 89.3, np.nan] + quant2_err_expected = [0.05, 0.1, 0.3, np.nan] + quant_type_expected = ['radec', 'seppa', 'seppa', 'rv'] + assert len(input_table) == rows_expected + for meas,truth in zip(input_table['epoch'],epoch_expected): + assert truth == pytest.approx(meas) + for meas,truth in zip(input_table['object'],object_expected): + assert truth == meas + for meas,truth in zip(input_table['quant1'],quant1_expected): + if np.isnan(truth): + assert np.isnan(meas) + else: + assert truth == pytest.approx(meas) + for meas,truth in zip(input_table['quant1_err'],quant1_err_expected): + if np.isnan(truth): + assert np.isnan(meas) + else: + assert truth == pytest.approx(meas) + for meas,truth in zip(input_table['quant2'],quant2_expected): + if np.isnan(truth): + assert np.isnan(meas) + else: + assert truth == pytest.approx(meas) + for meas,truth in zip(input_table['quant2_err'],quant2_err_expected): + if np.isnan(truth): + assert np.isnan(meas) + else: + assert truth == pytest.approx(meas) + for meas,truth in zip(input_table['quant_type'],quant_type_expected): + assert truth == meas + +def test_create_driver_from_filename(): + """ + Test creation of Driver object from filename as input + """ + testdir = os.path.dirname(os.path.abspath(__file__)) + input_file = os.path.join(testdir, 'test_val.csv') + myDriver = driver.Driver(input_file, # path to data file + 'MCMC', # name of algorith for orbit-fitting + 1, # number of secondary bodies in system + 1.0, # total system mass [M_sun] + 50.0, # total parallax of system [mas] + mass_err=0.1, # mass error [M_sun] + plx_err=0.1) # parallax error [mas] + _compare_table(myDriver.system.data_table) + + +def test_create_driver_from_table(): + """ + Test creation of Driver object from Table as input + """ + testdir = os.path.dirname(os.path.abspath(__file__)) + input_file = os.path.join(testdir, 'test_val.csv') + input_table = read_file(input_file) + myDriver = driver.Driver(input_table, # astropy.table Table of input + 'MCMC', # name of algorith for orbit-fitting + 1, # number of secondary bodies in system + 1.0, # total system mass [M_sun] + 50.0, # total parallax of system [mas] + mass_err=0.1, # mass error [M_sun] + plx_err=0.1) # parallax error [mas] + _compare_table(myDriver.system.data_table) + +if __name__ == '__main__': + test_create_driver_from_filename() + test_create_driver_from_table() diff --git a/tests/test_mcmc.py b/tests/test_mcmc.py index f0e96b10..1358994c 100644 --- a/tests/test_mcmc.py +++ b/tests/test_mcmc.py @@ -10,7 +10,7 @@ def test_pt_mcmc_runs(num_threads=1): # use the test_csv dir testdir = os.path.dirname(os.path.abspath(__file__)) input_file = os.path.join(testdir, 'test_val.csv') - data_table = read_input.read_formatted_file(input_file) + data_table = read_input.read_file(input_file) # Manually set 'object' column of data table data_table['object'] = 1 @@ -36,7 +36,7 @@ def test_ensemble_mcmc_runs(num_threads=1): # use the test_csv dir testdir = os.path.dirname(os.path.abspath(__file__)) input_file = os.path.join(testdir, 'test_val.csv') - data_table = read_input.read_formatted_file(input_file) + data_table = read_input.read_file(input_file) # Manually set 'object' column of data table data_table['object'] = 1 diff --git a/tests/test_read_input.py b/tests/test_read_input.py index 8b01ea3d..3848345e 100644 --- a/tests/test_read_input.py +++ b/tests/test_read_input.py @@ -1,7 +1,8 @@ import pytest +import deprecation import numpy as np import os -from orbitize.read_input import read_formatted_file, write_orbitize_input, read_orbitize_input +from orbitize.read_input import read_file, write_orbitize_input, read_formatted_file, read_orbitize_input def _compare_table(input_table): @@ -51,9 +52,24 @@ def _compare_table(input_table): for meas,truth in zip(input_table['quant_type'],quant_type_expected): assert truth == meas +def test_read_file(): + """ + Test the read_file function using the test_val.csv file and test_val_radec.csv + """ + testdir = os.path.dirname(os.path.abspath(__file__)) + # Check that main test input is read in with correct values + input_file = os.path.join(testdir, 'test_val.csv') + _compare_table(read_file(input_file)) + # Check that an input value with all valid entries and only ra/dec columns can be read + input_file_radec = os.path.join(testdir, 'test_val_radec.csv') + read_file(input_file_radec) + +@deprecation.fail_if_not_removed def test_read_formatted_file(): """ - Test the read_formatted_file function using the test_val.csv file and test_val_radec.csv + Tests the read_formatted_file function using the test_val.csv file and test_val_radec.csv + + This test exists with the fail_if_not_removed decorator as a reminder to remove in v2.0 """ testdir = os.path.dirname(os.path.abspath(__file__)) # Check that main test input is read in with correct values @@ -61,15 +77,40 @@ def test_read_formatted_file(): _compare_table(read_formatted_file(input_file)) # Check that an input value with all valid entries and only ra/dec columns can be read input_file_radec = os.path.join(testdir, 'test_val_radec.csv') - read_formatted_file(input_file_radec) + read_file(input_file_radec) + +def test_write_orbitize_input(): + """ + Test the write_orbitize_input and the read_file functions + """ + testdir = os.path.dirname(os.path.abspath(__file__)) + input_file = os.path.join(testdir, 'test_val.csv') + test_table = read_file(input_file) + output_file = os.path.join(testdir, 'temp_test_orbitize_input.csv') + # If temp output file already exists, delete it + if os.path.isfile(output_file): + os.remove(output_file) + try: # Catch these tests so that we remove temporary file + # Test that we were able to write the table + write_orbitize_input(test_table,output_file) + assert os.path.isfile(output_file) + # Test that we can read the table and check if it's correct + test_table_2 = read_file(output_file) + _compare_table(test_table_2) + finally: + # Remove temporary file + os.remove(output_file) -def test_write_read_orbitize_input(): +@deprecation.fail_if_not_removed +def test_write_orbitize_input_2(): """ Test the write_orbitize_input and the read_orbitize_input functions + + This test exists with the fail_if_not_removed decorator as a reminder to remove in v2.0 """ testdir = os.path.dirname(os.path.abspath(__file__)) input_file = os.path.join(testdir, 'test_val.csv') - test_table = read_formatted_file(input_file) + test_table = read_file(input_file) output_file = os.path.join(testdir, 'temp_test_orbitize_input.csv') # If temp output file already exists, delete it if os.path.isfile(output_file): @@ -86,5 +127,7 @@ def test_write_read_orbitize_input(): os.remove(output_file) if __name__ == "__main__": + test_read_file() test_read_formatted_file() - test_write_read_orbitize_input() + test_write_orbitize_input() + test_write_orbitize_input_2() diff --git a/tests/test_system.py b/tests/test_system.py index 8788ac41..dc65fb60 100644 --- a/tests/test_system.py +++ b/tests/test_system.py @@ -11,7 +11,7 @@ def test_add_and_clear_results(): num_secondary_bodies=1 testdir = os.path.dirname(os.path.abspath(__file__)) input_file = os.path.join(testdir, 'test_val.csv') - data_table=read_input.read_formatted_file(input_file) + data_table=read_input.read_file(input_file) system_mass=1.0 plx=10.0 mass_err=0.1