-
Notifications
You must be signed in to change notification settings - Fork 1
/
hyper_para.py
67 lines (51 loc) · 1.93 KB
/
hyper_para.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.nn as nn
import numpy as np
############################################
## This code is for initializing the system dimension
## and training (NN) parameters
###########################################
############################################
# Default floating-point dtype.
# float64 (double) is used by default; for GPU training switch
# DEFAULT_DTYPE to torch.float32.
############################################
DEFAULT_DTYPE = torch.float64
# torch.set_default_dtype also makes the legacy torch.Tensor constructor
# produce tensors of this dtype, so the former extra call to
# torch.set_default_tensor_type (deprecated since PyTorch 2.1) is not needed.
torch.set_default_dtype(DEFAULT_DTYPE)

FINE_TUNE = 1  # set to 1 for fine-tuning a pre-trained model

############################################
# Network architecture
############################################
N_H_B = 1  # number of hidden layers of the barrier network
D_H_B = 20  # number of neurons in each hidden layer of the barrier network

############################################
# Barrier function conditions
############################################
gamma = 0  # bound for the first condition (<= 0)
# Margin required for the strict inequality, enforced as ">= lamda".
# (Spelled "lamda" because `lambda` is a reserved Python keyword.)
lamda = 0.00001

############################################
# Loss-function tolerances
############################################
TOL_SAFE = 0.0  # tolerance for the safe-set condition
TOL_UNSAFE = 0.0  # tolerance for the unsafe-set condition
TOL_LIE = 0.0  # tolerance for the last condition (presumably the Lie-derivative one)
TOL_DATA_GEN = 1e-16  # tolerance used for data generation

############################################
# Lipschitz bounds used during training
# NOTE(review): the h / dh / d2h naming suggests bounds on the barrier
# function, its first and its second derivative — confirm against the trainer.
############################################
lip_h = 1
lip_dh = 1
lip_d2h = 2

############################################
# Training schedule
############################################
EPOCHS = 500  # number of training epochs

# Rate schedule (presumably the learning rate — confirm against the trainer).
ALPHA = 0.1
BETA = 0  # if BETA equals 0 the rate stays constant at ALPHA
GAMMA = 0  # when BETA is nonzero, a larger GAMMA gives a faster drop of the rate
# NOTE: Python names are case-sensitive — GAMMA here is a different constant
# from the lower-case barrier-condition `gamma` defined above.

# Weights for the loss-function terms.
DECAY_LIE = 1  # decay of the Lie-term weight
# NOTE(review): the original comment read "0.1 works, 1 does not work",
# yet the value is 1 — confirm which setting is intended.
DECAY_SAFE = 1
DECAY_UNSAFE = 1