Custom loss function (#6)
* added code to allow the user to pass their own custom loss functions

* bumping version number

* testing new functionality (custom loss functions)
htjb authored Nov 1, 2021
1 parent 49c1b25 commit a8a2598
Showing 4 changed files with 39 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.rst
@@ -7,7 +7,7 @@ Introduction

:globalemu: Robust Global 21-cm Signal Emulation
:Author: Harry Thomas Jones Bevins
:Version: 1.4.1
:Version: 1.5.0
:Homepage: https://github.com/htjb/globalemu
:Documentation: https://globalemu.readthedocs.io/

31 changes: 29 additions & 2 deletions globalemu/network.py
@@ -100,6 +100,24 @@ class nn():
documentation for more details on the types of activation
functions available.
loss_function: **Callable / default: None**
    | By default the code uses an MSE loss, however users are able to
      pass their own loss functions when training the neural network.
      These should be functions that take in the true labels
      (temperatures) and the predicted labels and return some measure
      of loss. Care needs to be taken to ensure that the correct loss
      function is supplied when resuming the training of a previous
      run as ``globalemu`` will not check this. In order for the loss
      function to work it must be built using the tensorflow.keras
      backend. An example would be

      .. code:: python

          from tensorflow.keras import backend as K

          def custom_loss(true_labels, predicted_labels):
              return K.mean(K.abs(true_labels - predicted_labels))
resume: **Bool / default: False**
| If set to ``True`` then ``globalemu`` will look in the
``base_dir`` for a trained model and ``loss_history.txt``
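
For orientation, a minimal usage sketch of the new keyword, assembled from the docstring above and the updated test further down. The redshift grid, data location and training hyperparameters mirror the test and are illustrative only; it assumes the 21cmGEM training data has already been downloaded and is preprocessed with ``globalemu.preprocess.process``.

.. code:: python

    import numpy as np
    from tensorflow.keras import backend as K
    from globalemu.preprocess import process
    from globalemu.network import nn


    def custom_loss(true_labels, predicted_labels):
        # Mean absolute error built on the tensorflow.keras backend,
        # as required by the docstring above.
        return K.mean(K.abs(true_labels - predicted_labels))


    # Redshift grid and preprocessing step, mirroring the test below;
    # '21cmGEM_data/' is an illustrative data location.
    z = np.arange(5, 50.1, 0.1)
    process(10, z, data_location='21cmGEM_data/')

    # Train with the user-supplied loss instead of the default MSE.
    nn(batch_size=451, layer_sizes=[8], epochs=10, loss_function=custom_loss)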
@@ -132,7 +150,8 @@ def __init__(self, **kwargs):
'lr', 'dropout', 'input_shape',
'output_shape', 'layer_sizes', 'base_dir',
'early_stop', 'early_stop_lim', 'xHI', 'resume',
'random_seed', 'output_activation']):
'random_seed', 'output_activation',
'loss_function']):
raise KeyError("Unexpected keyword argument in nn()")

self.resume = kwargs.pop('resume', False)
@@ -189,6 +208,11 @@ def __init__(self, **kwargs):
raise TypeError("'" + float_strings[i] +
"' must be a float.")

loss_function = kwargs.pop('loss_function', None)
if loss_function is not None:
if not callable(loss_function):
raise TypeError('loss_function should be a callable.')

if self.random_seed is not None:
tf.random.set_seed(self.random_seed)

@@ -232,7 +256,10 @@ def pack_features_vector(features, labels):
def loss(model, x, y, training):
y_ = tf.transpose(model(x, training=training))[0]
lf = loss_functions(y, y_)
return lf.mse(), lf.rmse()
if loss_function is None:
return lf.mse(), lf.rmse()
else:
return loss_function(y, y_), lf.rmse()

def grad(model, inputs, targets):
with tf.GradientTape() as tape:
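The ``loss()`` closure above dispatches to the user-supplied callable when one is given, and otherwise falls back to the built-in MSE. Below is a self-contained sketch of that pattern with a toy model and random data; it is not ``globalemu``'s actual training loop, which additionally returns an RMSE diagnostic via ``loss_functions``.

.. code:: python

    import tensorflow as tf
    from tensorflow.keras import backend as K


    def custom_loss(true_labels, predicted_labels):
        # Example loss from the docstring: mean absolute error.
        return K.mean(K.abs(true_labels - predicted_labels))


    # Toy model and data, purely for illustration.
    model = tf.keras.Sequential([tf.keras.Input(shape=(3,)),
                                 tf.keras.layers.Dense(1)])
    x = tf.random.normal((8, 3))
    y = tf.random.normal((8,))


    def loss(model, x, y, training, loss_function=None):
        # Mirrors the dispatch in the diff: fall back to MSE when no
        # custom loss is supplied.
        y_ = tf.transpose(model(x, training=training))[0]
        if loss_function is None:
            return tf.reduce_mean(tf.square(y - y_))
        return loss_function(y, y_)


    def grad(model, inputs, targets, loss_function=None):
        # Standard tf.GradientTape training-step pattern.
        with tf.GradientTape() as tape:
            loss_value = loss(model, inputs, targets, True, loss_function)
        return loss_value, tape.gradient(loss_value, model.trainable_variables)


    loss_value, grads = grad(model, x, y, loss_function=custom_loss)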
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@ def readme(short=False):

setup(
name='globalemu',
version='1.4.1',
version='1.5.0',
description='globalemu: Robust and Fast Global 21-cm Signal Emulation',
long_description=readme(),
author='Harry T. J. Bevins',
9 changes: 8 additions & 1 deletion tests/test_network.py
@@ -1,17 +1,22 @@
import numpy as np
from globalemu.preprocess import process
from globalemu.network import nn
from tensorflow.keras import backend as K
import requests
import os
import shutil
import pytest


def test_process_nn():

def custom_loss(y, y_):
return K.mean(K.abs(y - y_))

z = np.arange(5, 50.1, 0.1)

process(10, z, data_location='21cmGEM_data/')
nn(batch_size=451, layer_sizes=[8], epochs=10)
nn(batch_size=451, layer_sizes=[8], epochs=10, loss_function=custom_loss)

# results of below will not make sense as it is being run on the
# global signal data but it will test the code (xHI data not public)
@@ -59,6 +64,8 @@ def test_process_nn():
nn(resume=10)
with pytest.raises(TypeError):
nn(output_activation=2)
with pytest.raises(TypeError):
nn(loss_function='foobar')

process(10, z, data_location='21cmGEM_data/', base_dir='base_dir/')
nn(batch_size=451, layer_sizes=[], random_seed=10,
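Any callable built on the tensorflow.keras backend with the same ``(true_labels, predicted_labels)`` signature can be passed in the same way. As a further, purely hypothetical illustration (not part of this commit), a Huber-style loss:

.. code:: python

    from tensorflow.keras import backend as K


    def huber_like_loss(true_labels, predicted_labels, delta=1.0):
        # Quadratic near zero, linear for large residuals; delta is an
        # illustrative threshold. nn() only calls the loss with two
        # positional arguments, so delta keeps its default value.
        error = true_labels - predicted_labels
        abs_error = K.abs(error)
        quadratic = K.minimum(abs_error, delta)
        linear = abs_error - quadratic
        return K.mean(0.5 * K.square(quadratic) + delta * linear)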
