diff --git a/README.rst b/README.rst index c775a09..0240173 100644 --- a/README.rst +++ b/README.rst @@ -7,7 +7,7 @@ Introduction :globalemu: Robust Global 21-cm Signal Emulation :Author: Harry Thomas Jones Bevins -:Version: 1.3.1 +:Version: 1.4.0 :Homepage: https://github.com/htjb/globalemu :Documentation: https://globalemu.readthedocs.io/ @@ -100,11 +100,9 @@ Results are accessed via 'res.z' and 'res.signal'. The code can also be used to train a network on your own Global 21-cm signal or neutral fraction simulations using the built in ``globalemu`` pre-processing techniques. There is some flexibility on the required astrophysical input -parameters but the models are required to subscribe to the astrophysics free -baseline calculation detailed in the ``globalemu`` paper (see below for a reference). +parameters and the pre-processing steps, which are detailed in the documentation. More details about training your own network can be found in the documentation. - ``globalemu`` GUI ----------------- @@ -144,8 +142,9 @@ An image of the GUI is shown below. :alt: graphical user interface The GUI can also be used to investigate the physics of the neutral fraction -history by generating a configuration file for the released trained model and -setting the kwarg ``xHI=True`` in gui_config.config(). +history by generating a configuration file for the released trained model. +There is no need to specify that the configuration file is for a neutral +fraction emulator. Configuration files for the released models are provided. 
diff --git a/T_release/gui_configuration.csv b/T_release/gui_configuration.csv index 9e25d4d..5a0d90a 100644 --- a/T_release/gui_configuration.csv +++ b/T_release/gui_configuration.csv @@ -1,8 +1,8 @@ -names,mins,maxs,label_min,label_max,logs,xHI -$\log(f_*)$,-3.4579971262630043,-0.3010299956639812,-246.84562,32.171596,0,False -$\log(V_c)$,0.6232492903979004,1.8836614351536176,,,1, -$\log(f_X)$,-6.0,0.9977593286204041,,,2, -$\tau$,0.05550117,0.09999531,,,--, -$\alpha$,1.0,1.5,,,--, -$\nu_\mathrm{min}$,0.1,3.0,,,--, -$R_\mathrm{mfp}$,10.0,50.0,,,--, +names,mins,maxs,label_min,label_max,logs,ylabel +$\log(f_*)$,-3.4579971262630043,-0.3010299956639812,-246.84562,32.171596,0,$\delta T$ [mK] +$\log(V_c)$,0.6232492903979004,1.8836614351536176,,,1,$\delta T$ [mK] +$\log(f_X)$,-6.0,0.9977593286204041,,,2,$\delta T$ [mK] +$\tau$,0.05550117,0.09999531,,,--,$\delta T$ [mK] +$\alpha$,1.0,1.5,,,--,$\delta T$ [mK] +$\nu_\mathrm{min}$,0.1,3.0,,,--,$\delta T$ [mK] +$R_\mathrm{mfp}$,10.0,50.0,,,--,$\delta T$ [mK] diff --git a/T_release/preprocess_settings.pkl b/T_release/preprocess_settings.pkl new file mode 100644 index 0000000..8773bc6 Binary files /dev/null and b/T_release/preprocess_settings.pkl differ diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index af6b652..fff1372 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -9,7 +9,8 @@ of ``globalemu``. If you are just interested in evaluating the released models then take a look at the second part towards the bottom of the page. If you are intending to work with neutral fraction histories then the frame work for training and evaluating models is identical you just need to pass -the kwarg ``xHI=True`` to all of the ``globalemu`` functions. +the kwarg ``xHI=True`` to the pre-processing function, `process()`, +and model building function, `nn()`, discussed below. The tutorial can also be found as a Jupyter notebook `here `__. @@ -70,6 +71,12 @@ for the neural network. 
It also saves some files used for normalisation in the ``base_dir`` so that when evaluating the network the inputs and outputs can be properly dealt with. +By default the network subtracts and astrophysics free baseline from the models +and resamples the signals at a higher rate in regions of high variation across +the training data. Both of these pre-processing techniques are detailed in the +`globalemu` MNRAS preprint. Users can prevent this happening by passing the +kwargs `AFB=False` and `resampling=False` to `process()` if required. + Once pre-processing has been performed we can train our network with the ``nn()`` function in ``globalemu.network``. diff --git a/globalemu/downloads.py b/globalemu/downloads.py index da758b2..fa124e1 100644 --- a/globalemu/downloads.py +++ b/globalemu/downloads.py @@ -36,7 +36,8 @@ def model(self): files = [ 'model.h5', 'data_mins.txt', 'data_maxs.txt', 'samples.txt', - 'cdf.txt', 'z.txt', 'kwargs.txt', + 'cdf.txt', 'z.txt', 'kwargs.txt', 'preprocess_settings.pkl', + 'gui_configuration.csv', 'AFB_norm_factor.npy', 'labels_stds.npy', 'AFB.txt'] if self.xHI is False: @@ -49,7 +50,7 @@ def model(self): 'htjb/globalemu/master/xHI_release/' for i in range(len(files)): - if i > 6 and self.xHI is True: + if i > 8 and self.xHI is True: break r = requests.get(base_url + files[i]) open(base_dir + files[i], 'wb').write(r.content) diff --git a/globalemu/eval.py b/globalemu/eval.py index a88d42a..5358326 100644 --- a/globalemu/eval.py +++ b/globalemu/eval.py @@ -14,6 +14,7 @@ from tensorflow import keras from tensorflow.keras import backend as K import gc +import pickle class evaluate(): @@ -29,10 +30,6 @@ class evaluate(): **kwargs:** - xHI: **Bool / default: False** - | If True then ``globalemu`` will act as if it is evaluating a - neutral fraction history emulator. - base_dir: **string / default: 'model_dir/'** | The ``base_dir`` is where the trained model is saved. 
@@ -120,20 +117,22 @@ def __init__(self, **kwargs): for key, values in kwargs.items(): if key not in set( - ['xHI', 'base_dir', 'model', 'logs', 'gc', 'z']): + ['base_dir', 'model', 'logs', 'gc', 'z']): raise KeyError("Unexpected keyword argument in evaluate()") - self.xHI = kwargs.pop('xHI', False) - self.base_dir = kwargs.pop('base_dir', 'model_dir/') if type(self.base_dir) is not str: raise TypeError("'base_dir' must be a sting.") elif self.base_dir.endswith('/') is False: raise KeyError("'base_dir' must end with '/'.") + file = open(self.base_dir + "preprocess_settings.pkl", "rb") + self.preprocess_settings = pickle.load(file) + self.data_mins = np.loadtxt(self.base_dir + 'data_mins.txt') self.data_maxs = np.loadtxt(self.base_dir + 'data_maxs.txt') - self.cdf = np.loadtxt(self.base_dir + 'cdf.txt') + if self.preprocess_settings['resampling'] is True: + self.cdf = np.loadtxt(self.base_dir + 'cdf.txt') self.model = kwargs.pop('model', None) if self.model is None: @@ -146,14 +145,15 @@ def __init__(self, **kwargs): raise TypeError("'logs' must be a list.") self.garbage_collection = kwargs.pop('gc', False) - boolean_kwargs = [self.garbage_collection, self.xHI] - boolean_strings = ['gc', 'xHI'] + boolean_kwargs = [self.garbage_collection] + boolean_strings = ['gc'] for i in range(len(boolean_kwargs)): if type(boolean_kwargs[i]) is not bool: raise TypeError("'" + boolean_strings[i] + "' must be a bool.") - if self.xHI is False: + if self.preprocess_settings['AFB'] is True: self.AFB = np.loadtxt(self.base_dir + 'AFB.txt') + if self.preprocess_settings['std_division'] is True: self.label_stds = np.load(self.base_dir + 'labels_stds.npy') self.original_z = np.loadtxt(self.base_dir + 'z.txt') @@ -200,7 +200,11 @@ def __call__(self, parameters): (self.data_maxs[i] - self.data_mins[i]) for i in range(params.shape[1])]).T - norm_z = np.interp(self.z, self.original_z, self.cdf) + if self.preprocess_settings['resampling'] is True: + norm_z = np.interp(self.z, self.original_z, 
self.cdf) + else: + norm_z = (self.z - self.original_z.min()) / \ + (self.original_z.max() - self.original_z.min()) if isinstance(norm_z, np.ndarray): if len(normalised_params.shape) == 1: x = np.tile(normalised_params, (len(norm_z), 1)) @@ -229,7 +233,7 @@ def __call__(self, parameters): result = self.model(x[np.newaxis, :], training=False).numpy() evaluation = result[0][0] - if self.xHI is False: + if self.preprocess_settings['std_division'] is True: if isinstance(evaluation, np.ndarray): evaluation = [ evaluation[i]*self.label_stds @@ -237,6 +241,10 @@ def __call__(self, parameters): else: evaluation *= self.label_stds + if self.preprocess_settings['AFB'] is True: evaluation += np.interp(self.z, self.original_z, self.AFB) + if type(evaluation) is not np.ndarray: + evaluation = np.array(evaluation) + return evaluation, self.z diff --git a/globalemu/gui_config.py b/globalemu/gui_config.py index 4cdd013..1368242 100644 --- a/globalemu/gui_config.py +++ b/globalemu/gui_config.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd +import pickle class config(): @@ -51,10 +52,6 @@ class config(): **Kwargs:** - xHI: **Bool / default: False** - | If True then ``globalemu`` will act as if it is evaluating a - neutral fraction history emulator. - logs: **list / default: [0, 1, 2]** | The indices corresponding to the astrophysical parameters that @@ -64,20 +61,23 @@ class config(): :math:`{V_c}` (minimum virial circular velocity) and :math:`{f_x}` (X-ray efficieny). + ylabel: **string / default: 'y'** + | y-axis label for gui plot. 
+ """ def __init__(self, base_dir, paramnames, data_dir, **kwargs): for key, values in kwargs.items(): if key not in set( - ['xHI', 'logs']): + ['logs', 'ylabel']): raise KeyError("Unexpected keyword argument in process()") self.base_dir = base_dir self.paramnames = paramnames self.data_dir = data_dir self.logs = kwargs.pop('logs', [0, 1, 2]) - self.xHI = kwargs.pop('xHI', False) + self.ylabel = kwargs.pop('ylabel', 'y') file_kwargs = [self.base_dir, self.data_dir] file_strings = ['base_dir', 'data_dir'] @@ -87,12 +87,12 @@ def __init__(self, base_dir, paramnames, data_dir, **kwargs): elif file_kwargs[i].endswith('/') is False: raise KeyError("'" + file_strings[i] + "' must end with '/'.") + file = open(self.base_dir + "preprocess_settings.pkl", "rb") + self.preprocess_settings = pickle.load(file) + if type(self.paramnames) is not list: raise TypeError("'paramnames' must be a list of strings.") - if type(self.xHI) is not bool: - raise TypeError("'xHI' must be a bool.") - if type(self.logs) is not list: raise TypeError("'logs' must be a list.") @@ -123,7 +123,6 @@ def __init__(self, base_dir, paramnames, data_dir, **kwargs): 'label_max': [test_labels.max()] + ['']*(len(data_maxs)-1), 'logs': full_logs, - 'xHI': - [self.xHI] + ['']*(len(data_maxs)-1)}) + 'ylabel': self.ylabel}) df.to_csv(base_dir + 'gui_configuration.csv', index=False) diff --git a/globalemu/network.py b/globalemu/network.py index af4ecbc..1f3d67a 100644 --- a/globalemu/network.py +++ b/globalemu/network.py @@ -91,6 +91,15 @@ class nn(): | If True then ``globalemu`` will act as if it is training a neutral fraction history emulator. + output_activation: **string / default: 'linear'** + | Determines the output activation function for the network. + Modifying this + is useful if the emulator output is required to be positive or + negative etc. If xHI is True then the output activation is + set to 'relu' else the function is 'linear'. 
See the tensorflow + documentation for more details on the types of activation + functions available. + resume: **Bool / default: False** | If set to ``True`` then ``globalemu`` will look in the ``base_dir`` for a trained model and ``loss_history.txt`` @@ -123,7 +132,7 @@ def __init__(self, **kwargs): 'lr', 'dropout', 'input_shape', 'output_shape', 'layer_sizes', 'base_dir', 'early_stop', 'early_stop_lim', 'xHI', 'resume', - 'random_seed']): + 'random_seed', 'output_activation']): raise KeyError("Unexpected keyword argument in nn()") self.resume = kwargs.pop('resume', False) @@ -206,20 +215,19 @@ def pack_features_vector(features, labels): train_dataset = train_dataset.map(pack_features_vector) + self.output_activation = kwargs.pop('output_activation', 'linear') + if self.xHI is True: + self.output_activation = 'relu' + if self.resume is True: model = keras.models.load_model( self.base_dir + 'model.h5', compile=False) - elif self.xHI is False: - model = network_models().basic_model( - self.input_shape, self.output_shape, - self.layer_sizes, self.activation, self.drop_val, - 'linear') else: model = network_models().basic_model( self.input_shape, self.output_shape, self.layer_sizes, self.activation, self.drop_val, - 'relu') + self.output_activation) def loss(model, x, y, training): y_ = tf.transpose(model(x, training=training))[0] diff --git a/globalemu/preprocess.py b/globalemu/preprocess.py index 96d9dd5..837298b 100644 --- a/globalemu/preprocess.py +++ b/globalemu/preprocess.py @@ -14,6 +14,7 @@ import numpy as np import os import pandas as pd +import pickle from globalemu.cmSim import calc_signal from globalemu.resample import sampling @@ -51,6 +52,27 @@ class process(): | If True then ``globalemu`` will act as if it is training a neutral fraction history emulator. + AFB: **Bool / default: None** + | If True then ``globalemu`` will calculate an astrophysics free + baseline and subtract this from the training data signals. 
+ The AFB is specific to the global 21-cm signal and as + ``globalemu`` is set up to emulate the global signal by + default this parameter is set to True. If xHI is True then + AFB is set to False by default. + + std_division: **Bool / default: None** + | If True then ``globalemu`` will divide the training data by the + standard deviation across the training data. This is + recommended when building an emulator to emulate the global + signal and is set to True by default. If xHI is True then + std_division is set to False by default. + + resampling: **Bool / default: None** + | Controls whether or not the signals will be resampled with + higher sampling at regions of large variation in the training + data set or not. Set to True by default as this is advised for + training both neutral fraction and global signal emulators. + logs: **list / default: [0, 1, 2]** | The indices corresponding to the astrophysical parameters in "train_data.txt" that need to be logged. The default assumes @@ -66,7 +88,8 @@ def __init__(self, num, z, **kwargs): for key, values in kwargs.items(): if key not in set( - ['base_dir', 'data_location', 'xHI', 'logs']): + ['base_dir', 'data_location', 'xHI', 'logs', 'AFB', + 'std_division', 'resampling']): raise KeyError("Unexpected keyword argument in process()") self.num = num @@ -90,8 +113,25 @@ def __init__(self, num, z, **kwargs): raise KeyError("'" + file_strings[i] + "' must end with '/'.") self.xHI = kwargs.pop('xHI', False) - if type(self.xHI) is not bool: - raise TypeError("'xHI' must be a bool.") + if self.xHI is False: + self.preprocess_settings = {'AFB': True, 'std_division': True, + 'resampling': True} + else: + self.preprocess_settings = {'AFB': False, 'std_division': False, + 'resampling': True} + + preprocess_settings_keys = ['AFB', 'std_division', 'resampling'] + for key in preprocess_settings_keys: + if key in kwargs: + self.preprocess_settings[key] = kwargs[key] + + bool_kwargs = [self.xHI, self.preprocess_settings['AFB'], + 
self.preprocess_settings['std_division'], + self.preprocess_settings['resampling']] + bool_strings = ['xHI', 'AFB', 'std_division', 'resampling'] + for i in range(len(bool_kwargs)): + if type(bool_kwargs[i]) is not bool: + raise TypeError(bool_strings[i] + " must be a bool.") self.logs = kwargs.pop('logs', [0, 1, 2]) if type(self.logs) is not list: @@ -100,6 +140,10 @@ def __init__(self, num, z, **kwargs): if not os.path.exists(self.base_dir): os.mkdir(self.base_dir) + file = open(self.base_dir + "preprocess_settings.pkl", "wb") + pickle.dump(self.preprocess_settings, file) + file.close() + np.savetxt(self.base_dir + 'z.txt', self.z) full_train_data = pd.read_csv( @@ -109,7 +153,7 @@ def __init__(self, num, z, **kwargs): self.data_location + 'train_labels.txt', delim_whitespace=True, header=None).values - if self.xHI is False: + if self.preprocess_settings['AFB'] is True: np.save( self.base_dir + 'AFB_norm_factor.npy', full_train_labels[0, -1]*1e-3) @@ -117,7 +161,7 @@ def __init__(self, num, z, **kwargs): if self.num == 'full': train_data = full_train_data.copy() - if self.xHI is False: + if self.preprocess_settings['AFB'] is True: train_labels = full_train_labels.copy() - res.deltaT else: train_labels = full_train_labels.copy() @@ -135,7 +179,7 @@ def __init__(self, num, z, **kwargs): for i in range(len(full_train_labels)): if np.any(ind == i): train_data.append(full_train_data[i, :]) - if self.xHI is False: + if self.preprocess_settings['AFB'] is True: train_labels.append(full_train_labels[i] - res.deltaT) else: train_labels.append(full_train_labels[i]) @@ -153,18 +197,21 @@ def __init__(self, num, z, **kwargs): log_td.append(train_data[:, i]) train_data = np.array(log_td).T - sampling_call = sampling( - self.z, self.base_dir, train_labels) - samples = sampling_call.samples - cdf = sampling_call.cdf + if self.preprocess_settings['resampling'] is True: + sampling_call = sampling( + self.z, self.base_dir, train_labels) + samples = sampling_call.samples + cdf = 
sampling_call.cdf - resampled_labels = [] - for i in range(len(train_labels)): - resampled_labels.append( - np.interp(samples, self.z, train_labels[i])) - train_labels = np.array(resampled_labels) + resampled_labels = [] + for i in range(len(train_labels)): + resampled_labels.append( + np.interp(samples, self.z, train_labels[i])) + train_labels = np.array(resampled_labels) - norm_s = np.interp(samples, self.z, cdf) + norm_s = np.interp(samples, self.z, cdf) + else: + norm_s = (self.z - self.z.min())/(self.z.max() - self.z.min()) data_mins = train_data.min(axis=0) data_maxs = train_data.max(axis=0) @@ -175,7 +222,7 @@ def __init__(self, num, z, **kwargs): (train_data[:, i] - data_mins[i])/(data_maxs[i]-data_mins[i])) norm_train_data = np.array(norm_train_data).T - if self.xHI is False: + if self.preprocess_settings['std_division'] is True: labels_stds = train_labels.std() norm_train_labels = [ train_labels[i, :]/labels_stds diff --git a/notebooks/Training.ipynb b/notebooks/Training.ipynb index 0c245a6..4427074 100644 --- a/notebooks/Training.ipynb +++ b/notebooks/Training.ipynb @@ -6,7 +6,7 @@ "source": [ "# Tutorial\n", "\n", - "This tutorial will show you the basics of training and evaluating an instance of ``globalemu``. If you are just interested in evaluating the released models then take a look at the second part towards the bottom of the page. If you are intending to work with neutral fraction histories then the frame work for training and evaluating models is identical you just need to pass the kwarg ``xHI=True`` to all of the ``globalemu`` functions.\n", + "This tutorial will show you the basics of training and evaluating an instance of ``globalemu``. If you are just interested in evaluating the released models then take a look at the second part towards the bottom of the page. 
If you are intending to work with neutral fraction histories then the framework for training and evaluating models is identical; you just need to pass the kwarg ``xHI=True`` to the pre-processing function, `process()`, and model building function, `nn()`, discussed below.\n", "\n", "## Training A Model\n", "\n", @@ -75,6 +75,8 @@ "\n", "Importantly the pre-processing function takes the data in ``data_dir`` and saves a ``.csv`` file in the ``base_dir`` containing the preprocessed inputs for the neural network. It also saves some files used for normalisation in the ``base_dir`` so that when evaluating the network the inputs and outputs can be properly dealt with.\n", "\n", + "By default the network subtracts an astrophysics-free baseline from the models and resamples the signals at a higher rate in regions of high variation across the training data. Both of these pre-processing techniques are detailed in the `globalemu` MNRAS preprint. Users can prevent this happening by passing the kwargs `AFB=False` and `resampling=False` to `process()` if required.\n", + "\n", "Once pre-processing has been performed we can train our network with the ``nn()`` function in ``globalemu.network``." 
] }, @@ -323,7 +325,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/scripts/globalemu b/scripts/globalemu index 66ab681..9b8c4cf 100644 --- a/scripts/globalemu +++ b/scripts/globalemu @@ -53,10 +53,7 @@ def signal(_, parameters=None): plt.figure(figsize=(4, 3)) plt.plot(z, signal, c='k') plt.xlabel('z') - if xHI is False: - plt.ylabel(r'$\delta T$ [mK]') - else: - plt.ylabel(r'$x_{HI}$') + plt.ylabel(ylabel) plt.ylim([label_min, label_max]) plt.tight_layout() plt.savefig('img/img.png', dpi=100) @@ -83,13 +80,13 @@ base_dir = sys.argv[1] config = pd.read_csv(base_dir + 'gui_configuration.csv') -xHI = config['xHI'][0] logs = config['logs'].tolist() logs = [int(x) for x in logs if x != '--'] label_min = config['label_min'][0] label_max = config['label_max'][0] +ylabel = config['ylabel'][0] -predictor = evaluate(base_dir=base_dir, xHI=xHI, logs=logs) +predictor = evaluate(base_dir=base_dir, logs=logs) window = Tk() window.geometry("800x450") diff --git a/setup.py b/setup.py index 53979b2..0eec250 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ def readme(short=False): setup( name='globalemu', - version='1.3.1', + version='1.4.0', description='globalemu: Robust and Fast Global 21-cm Signal Emulation', long_description=readme(), author='Harry T. J. 
Bevins', diff --git a/tests/test_download.py b/tests/test_download.py index a6fea27..ad81732 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -1,7 +1,38 @@ from globalemu.downloads import download import os import pytest +import requests +import pandas as pd +import numpy as np +def download_21cmGEM_data(): + data_dir = '21cmGEM_data/' + if not os.path.exists(data_dir): + os.mkdir(data_dir) + + files = ['Par_test_21cmGEM.txt', + 'Par_train_21cmGEM.txt', + 'T21_test_21cmGEM.txt', + 'T21_train_21cmGEM.txt'] + saves = ['test_data.txt', + 'train_data.txt', + 'test_labels.txt', + 'train_labels.txt'] + + for i in range(len(files)): + url = 'https://zenodo.org/record/4541500/files/' + files[i] + with open(data_dir + saves[i], 'wb') as f: + f.write(requests.get(url).content) + + td = pd.read_csv( + data_dir + 'train_data.txt', + delim_whitespace=True, header=None).values + tl = pd.read_csv( + data_dir + 'train_labels.txt', + delim_whitespace=True, header=None).values + + np.savetxt(data_dir + 'train_data.txt', td[:500, :]) + np.savetxt(data_dir + 'train_labels.txt', tl[:500, :]) def test_existing_dir(): if os.path.exists('kappa_HH.txt'): @@ -11,3 +42,6 @@ def test_existing_dir(): with pytest.raises(TypeError): download(xHI=2).kappa() + + # for use in later tests... 
+ download_21cmGEM_data() diff --git a/tests/test_eval.py b/tests/test_eval.py index 962e9af..976843d 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -18,7 +18,8 @@ def test_existing_dir(): download().model() download(xHI=True).model() - predictor = evaluate(z=originalz, base_dir='T_release/') + predictor = evaluate(z=originalz, + base_dir='T_release/', gc=True) signal, z = predictor(params) assert(len(signal.shape) == 1) @@ -41,7 +42,12 @@ def test_existing_dir(): predictor = evaluate(z=10, base_dir='T_release/') signal, z = predictor(params) - predictor = evaluate(z=originalz, base_dir='xHI_release/', xHI=True) + preprocess = predictor.preprocess_settings + assert(preprocess['AFB'] is True) + assert(preprocess['resampling'] is True) + assert(preprocess['std_division'] is True) + + predictor = evaluate(z=originalz, base_dir='xHI_release/') signal, z = predictor(params) with pytest.raises(KeyError): @@ -59,5 +65,3 @@ def test_existing_dir(): with pytest.raises(TypeError): predictor = evaluate(z=originalz, base_dir='T_release/', gc='false') - with pytest.raises(TypeError): - predictor = evaluate(z=originalz, base_dir='T_release/', xHI='bar') diff --git a/tests/test_gui_config.py b/tests/test_gui_config.py index 2982f6f..cfe1281 100644 --- a/tests/test_gui_config.py +++ b/tests/test_gui_config.py @@ -7,26 +7,6 @@ import pytest -def download_21cmGEM_data(): - data_dir = '21cmGEM_data/' - if not os.path.exists(data_dir): - os.mkdir(data_dir) - - files = ['Par_test_21cmGEM.txt', - 'Par_train_21cmGEM.txt', - 'T21_test_21cmGEM.txt', - 'T21_train_21cmGEM.txt'] - saves = ['test_data.txt', - 'train_data.txt', - 'test_labels.txt', - 'train_labels.txt'] - - for i in range(len(files)): - url = 'https://zenodo.org/record/4541500/files/' + files[i] - with open(data_dir + saves[i], 'wb') as f: - f.write(requests.get(url).content) - - def test_config(): if os.path.exists('T_release/'): shutil.rmtree('T_release/') @@ -36,8 +16,6 @@ def test_config(): download().model() 
download(xHI=True).model() - download_21cmGEM_data() - paramnames = [r'$\log(f_*)$', r'$\log(V_c)$', r'$\log(f_X)$', r'$\nu_\mathrm{min}$', r'$\tau$', r'$\alpha$', r'$R_\mathrm{mfp}$'] @@ -45,12 +23,11 @@ def test_config(): # Providing this with global signal data as neutral fraction data is # not publicly available. Will not effect efficacy of the test. config('xHI_release/', paramnames, '21cmGEM_data/', - logs=[0, 1, 2], xHI=True) + logs=[0, 1, 2]) assert(os.path.exists('xHI_release/gui_configuration.csv') is True) res = pd.read_csv('xHI_release/gui_configuration.csv') - assert(res['xHI'][0] is True) logs = res['logs'].tolist() logs = [int(x) for x in logs if x != '--'] assert(logs == [0, 1, 2]) @@ -63,17 +40,12 @@ def test_config(): assert(os.path.exists('T_release/gui_configuration.csv') is True) - res = pd.read_csv('T_release/gui_configuration.csv') - assert(res['xHI'][0] is False) - with pytest.raises(KeyError): config('T_release', paramnames, '21cmGEM_data/') with pytest.raises(TypeError): config(10, paramnames, '21cmGEM_data/') with pytest.raises(KeyError): config('T_release/', paramnames, '21cmGEM_data') - with pytest.raises(TypeError): - config('T_release/', paramnames, '21cmGEM_data/', xHI=4) with pytest.raises(TypeError): config('T_release/', paramnames, '21cmGEM_data/', logs='banana') with pytest.raises(TypeError): @@ -81,5 +53,3 @@ def test_config(): with pytest.raises(KeyError): config('T_release/', paramnames, '21cmGEM_data/', color='C0') - - shutil.rmtree('21cmGEM_data/') diff --git a/tests/test_network.py b/tests/test_network.py index 2779ff6..3df4412 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -7,28 +7,7 @@ import pytest -def download_21cmGEM_data(): - data_dir = '21cmGEM_data/' - if not os.path.exists(data_dir): - os.mkdir(data_dir) - - files = ['Par_test_21cmGEM.txt', - 'Par_train_21cmGEM.txt', - 'T21_test_21cmGEM.txt', - 'T21_train_21cmGEM.txt'] - saves = ['test_data.txt', - 'train_data.txt', - 'test_labels.txt', - 
'train_labels.txt'] - - for i in range(len(files)): - url = 'https://zenodo.org/record/4541500/files/' + files[i] - with open(data_dir + saves[i], 'wb') as f: - f.write(requests.get(url).content) - - def test_process_nn(): - download_21cmGEM_data() z = np.arange(5, 50.1, 0.1) process(10, z, data_location='21cmGEM_data/') @@ -39,6 +18,8 @@ def test_process_nn(): process(10, z, data_location='21cmGEM_data/', xHI=True) nn(batch_size=451, layer_sizes=[8], epochs=5, xHI=True) + nn(batch_size=451, layer_sizes=[8], epochs=5, output_activation='linear') + # test early_stop code nn(batch_size=451, layer_sizes=[], epochs=20, early_stop=True) @@ -76,12 +57,14 @@ def test_process_nn(): nn(xHI='false') with pytest.raises(TypeError): nn(resume=10) + with pytest.raises(TypeError): + nn(output_activation=2) process(10, z, data_location='21cmGEM_data/', base_dir='base_dir/') nn(batch_size=451, layer_sizes=[], random_seed=10, base_dir='base_dir/') - dir = ['21cmGEM_data/', 'model_dir/', 'base_dir/'] + dir = ['model_dir/', 'base_dir/'] for i in range(len(dir)): if os.path.exists(dir[i]): shutil.rmtree(dir[i]) diff --git a/tests/test_plotter.py b/tests/test_plotter.py index 84155ea..cecfc72 100644 --- a/tests/test_plotter.py +++ b/tests/test_plotter.py @@ -10,27 +10,6 @@ params = [0.25, 30, 2, 0.056, 1.3, 2, 30] z = np.arange(10, 20, 100) - -def download_21cmGEM_data(): - data_dir = '21cmGEM_data/' - if not os.path.exists(data_dir): - os.mkdir(data_dir) - - files = ['Par_test_21cmGEM.txt', - 'Par_train_21cmGEM.txt', - 'T21_test_21cmGEM.txt', - 'T21_train_21cmGEM.txt'] - saves = ['test_data.txt', - 'train_data.txt', - 'test_labels.txt', - 'train_labels.txt'] - - for i in range(len(files)): - url = 'https://zenodo.org/record/4541500/files/' + files[i] - with open(data_dir + saves[i], 'wb') as f: - f.write(requests.get(url).content) - - def test_existing_dir(): if os.path.exists('T_release/'): shutil.rmtree('T_release/') @@ -42,8 +21,6 @@ def test_existing_dir(): predictor = 
evaluate(base_dir='T_release/') - download_21cmGEM_data() - parameters = np.loadtxt('21cmGEM_data/test_data.txt') labels = np.loadtxt('21cmGEM_data/test_labels.txt') @@ -104,6 +81,3 @@ def loss_func(labels, signals): signal_plot(parameters, labels, 'GEMLoss', predictor, 'T_release/') - - if os.path.exists('21cmGEM_data/'): - shutil.rmtree('21cmGEM_data/') diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 4a50320..2850b71 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -7,35 +7,19 @@ import shutil -def download_21cmGEM_data(): - data_dir = '21cmGEM_data/' - if not os.path.exists(data_dir): - os.mkdir(data_dir) - - files = ['Par_test_21cmGEM.txt', - 'Par_train_21cmGEM.txt', - 'T21_test_21cmGEM.txt', - 'T21_train_21cmGEM.txt'] - saves = ['test_data.txt', - 'train_data.txt', - 'test_labels.txt', - 'train_labels.txt'] - - for i in range(len(files)): - url = 'https://zenodo.org/record/4541500/files/' + files[i] - with open(data_dir + saves[i], 'wb') as f: - f.write(requests.get(url).content) - - -def test_process_nn(): - download_21cmGEM_data() +def test_preprocess(): z = np.arange(5, 50.1, 0.1) process(10, z, data_location='21cmGEM_data/') + process(10, z, data_location='21cmGEM_data/', xHI=True) + process('full', z, data_location='21cmGEM_data/') + process('full', z, data_location='21cmGEM_data/', AFB=False) + process(10, z, data_location='21cmGEM_data/', resampling=False) files = ['AFB_norm_factor.npy', 'AFB.txt', 'cdf.txt', 'data_maxs.txt', 'data_mins.txt', 'indices.txt', 'labels_stds.npy', 'samples.txt', - 'train_data.txt', 'train_dataset.csv', 'train_label.txt', 'z.txt'] + 'train_data.txt', 'train_dataset.csv', 'train_label.txt', 'z.txt', + 'preprocess_settings.pkl'] for i in range(len(files)): assert(os.path.exists('model_dir/' + files[i]) is True) @@ -58,10 +42,14 @@ def test_process_nn(): process(10, z, data_location='data_download') with pytest.raises(TypeError): process(10, z, data_location='21cmGEM_data/', 
base_dir=10) - with pytest.raises(TypeError): - process(10, z, data_location='21cmGEM_data/', xHI=10) with pytest.raises(TypeError): process(10, z, data_location='21cmGEM_data/', logs=True) + with pytest.raises(TypeError): + process(10, z, data_location='21cmGEM_data/', AFB=10) + with pytest.raises(TypeError): + process(10, z, data_location='21cmGEM_data/', resampling=10) + with pytest.raises(TypeError): + process(10, z, data_location='21cmGEM_data/', resampling=10) dir = ['21cmGEM_data/', 'model_dir/'] for i in range(len(dir)): diff --git a/xHI_release/gui_configuration.csv b/xHI_release/gui_configuration.csv index 4802dbc..91c327c 100644 --- a/xHI_release/gui_configuration.csv +++ b/xHI_release/gui_configuration.csv @@ -1,8 +1,8 @@ -names,mins,maxs,label_min,label_max,logs,xHI -$\log(f_*)$,-3.7328977171294224,-0.3010299956639812,0.0,0.9998003620503403,0,True -$\log(V_c)$,0.6232492903979004,1.8836614351536176,,,1, -$\log(f_X)$,-4.0,1.0,,,2, -$\nu_\mathrm{min}$,0.1,3.0,,,--, -$\tau$,0.05509129840335715,0.08820784364958727,,,--, -$\alpha$,1.0,1.5,,,--, -$R_\mathrm{mfp}$,10.0,50.0,,,--, +names,mins,maxs,label_min,label_max,logs,ylabel +$\log(f_*)$,-3.7328977171294224,-0.3010299956639812,0.0,0.9998003620503403,0,$x_{HI}$ +$\log(V_c)$,0.6232492903979004,1.8836614351536176,,,1,$x_{HI}$ +$\log(f_X)$,-4.0,1.0,,,2,$x_{HI}$ +$\nu_\mathrm{min}$,0.1,3.0,,,--,$x_{HI}$ +$\tau$,0.05509129840335715,0.08820784364958727,,,--,$x_{HI}$ +$\alpha$,1.0,1.5,,,--,$x_{HI}$ +$R_\mathrm{mfp}$,10.0,50.0,,,--,$x_{HI}$ diff --git a/xHI_release/preprocess_settings.pkl b/xHI_release/preprocess_settings.pkl new file mode 100644 index 0000000..8fe6d1f Binary files /dev/null and b/xHI_release/preprocess_settings.pkl differ