-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* major update to the gui * bug fix in gui * function to generate gui config file * switched from argparser to sys.argv * docstring added for gui_config.py * bug fixs * bumping version number and editing README GUI description * gui configuration files for released models * flake8 tidy * bug fix in gui_config/update to T_release gui_config file * kwarg catches added to gui_config and relevant doc strings to docs * bug fixes and tests for gui_config * bug fix in tests
- Loading branch information
Showing
8 changed files
with
345 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
names,mins,maxs,label_min,label_max,logs,xHI | ||
$\log(f_*)$,-3.4579971262630043,-0.3010299956639812,-246.84562,32.171596,0,False | ||
$\log(V_c)$,0.6232492903979004,1.8836614351536176,,,1, | ||
$\log(f_X)$,-6.0,0.9977593286204041,,,2, | ||
$\tau$,0.05550117,0.09999531,,,--, | ||
$\alpha$,1.0,1.5,,,--, | ||
$\nu_\mathrm{min}$,0.1,3.0,,,--, | ||
$R_\mathrm{mfp}$,10.0,50.0,,,--, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
""" | ||
This function can be used to generate a configurate file for the GUI that | ||
is specific to a given trained model. The file gets saved into the supplied | ||
``base_dir`` which should contain the relevant trained model. The user | ||
also needs to supply a path to the ``data_dir`` that contains the relevant | ||
testing and training data. Additional arguments are described below. | ||
A GUI config file is required to be able to visualise the signals with the | ||
GUI and once generated the gui can be run from the command line | ||
.. code:: bash | ||
globalemu /path/to/base_dir/containing/model/and/config/ | ||
""" | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
class config(): | ||
|
||
r""" | ||
**Parameters:** | ||
base_dir: **string** | ||
| The path to the file containing the trained tensorflow model | ||
that the user wishes to visualise with the GUI. Must end | ||
in '/'. | ||
paramnames: **list of strings** | ||
| This should be a list of parameter names in the correct input | ||
order. For example for the released global signal model this | ||
would correspond to | ||
.. code: python | ||
paramnames = [r'$\log(f_*)$', r'$\log(V_c)$', | ||
r'$\log(f_X)$', | ||
r'$\tau$', r'$\alpha$', | ||
r'$\nu_\mathrm{min}$', | ||
r'$R_\mathrm{mfp}$'] | ||
Latex strings can be provided as above. | ||
data_dir: **string** | ||
| The file path to the training and test data which is used to set | ||
the y lims of the GUI graph and ranges/intervals of GUI | ||
sliders. | ||
**Kwargs:** | ||
xHI: **Bool / default: False** | ||
| If True then ``globalemu`` will act as if it is evaluating a | ||
neutral fraction history emulator. | ||
logs: **list / default: [0, 1, 2]** | ||
| The indices corresponding to the astrophysical | ||
parameters that | ||
were logged during training. The default assumes | ||
that the first three columns in "train_data.txt" are | ||
:math:`{f_*}` (star formation efficiency), | ||
:math:`{V_c}` (minimum virial circular velocity) and | ||
:math:`{f_x}` (X-ray efficieny). | ||
""" | ||
|
||
def __init__(self, base_dir, paramnames, data_dir, **kwargs): | ||
|
||
for key, values in kwargs.items(): | ||
if key not in set( | ||
['xHI', 'logs']): | ||
raise KeyError("Unexpected keyword argument in process()") | ||
|
||
self.base_dir = base_dir | ||
self.paramnames = paramnames | ||
self.data_dir = data_dir | ||
self.logs = kwargs.pop('logs', [0, 1, 2]) | ||
self.xHI = kwargs.pop('xHI', False) | ||
|
||
file_kwargs = [self.base_dir, self.data_dir] | ||
file_strings = ['base_dir', 'data_dir'] | ||
for i in range(len(file_kwargs)): | ||
if type(file_kwargs[i]) is not str: | ||
raise TypeError("'" + file_strings[i] + "' must be a sting.") | ||
elif file_kwargs[i].endswith('/') is False: | ||
raise KeyError("'" + file_strings[i] + "' must end with '/'.") | ||
|
||
if type(self.paramnames) is not list: | ||
raise TypeError("'paramnames' must be a list of strings.") | ||
|
||
if type(self.xHI) is not bool: | ||
raise TypeError("'xHI' must be a bool.") | ||
|
||
if type(self.logs) is not list: | ||
raise TypeError("'logs' must be a list.") | ||
|
||
test_data = np.loadtxt(data_dir + 'test_data.txt') | ||
test_labels = np.loadtxt(data_dir + 'test_labels.txt') | ||
for i in range(test_data.shape[1]): | ||
if i in self.logs: | ||
for j in range(test_data.shape[0]): | ||
if test_data[j, i] == 0: | ||
test_data[j, i] = 1e-6 | ||
test_data[:, i] = np.log10(test_data[:, i]) | ||
|
||
data_mins = test_data.min(axis=0) | ||
data_maxs = test_data.max(axis=0) | ||
|
||
full_logs = [] | ||
for i in range(len(data_maxs)): | ||
if i in set(self.logs): | ||
full_logs.append(i) | ||
else: | ||
full_logs.append('--') | ||
|
||
df = pd.DataFrame({'names': self.paramnames, | ||
'mins': data_mins, | ||
'maxs': data_maxs, | ||
'label_min': | ||
[test_labels.min()] + ['']*(len(data_maxs)-1), | ||
'label_max': | ||
[test_labels.max()] + ['']*(len(data_maxs)-1), | ||
'logs': full_logs, | ||
'xHI': | ||
[self.xHI] + ['']*(len(data_maxs)-1)}) | ||
|
||
df.to_csv(base_dir + 'gui_configuration.csv', index=False) |
Oops, something went wrong.