# Import necessary libraries
import jax
import flax.linen as nn


class Config:
    def __init__(self):
        # NeuroFlex settings
        self.neuroflex_features = [64, 32, 10]
        self.use_cnn = True
        self.use_rnn = True
        self.use_gan = True
        self.fairness_constraint = 0.1
        self.use_quantum = True
        self.backend = 'jax'  # Choose from 'jax', 'tensorflow', 'pytorch'

        # JAX settings
        self.jax_seed = 42
        self.jax_precision = 'float32'

        # Flax settings
        self.flax_optimizer = 'adam'
        self.flax_learning_rate = 1e-3

        # RouteLLM and Ollama settings
        self.use_routellm = True
        self.ollama_model = 'llama2'

        # Llama3 settings
        self.use_llama3 = False
        self.llama3_model_size = 'base'  # Options: 'base', 'large', 'xl'
        self.llama3_context_length = 2048

        # Gemma2 settings
        self.use_gemma2 = False
        self.gemma2_model_size = 'base'  # Options: 'base', 'medium', 'large'
        self.gemma2_attention_heads = 12

        # Mistral3 settings
        self.use_mistral3 = False
        self.mistral3_model_size = 'small'  # Options: 'small', 'medium', 'large'
        self.mistral3_layers = 12

    def display(self):
        print("NeuroFlex Configuration:")
        print(f"Features: {self.neuroflex_features}")
        print(f"Use CNN: {self.use_cnn}")
        print(f"Use RNN: {self.use_rnn}")
        print(f"Use GAN: {self.use_gan}")
        print(f"Fairness Constraint: {self.fairness_constraint}")
        print(f"Use Quantum: {self.use_quantum}")
        print(f"Backend: {self.backend}")

        print("\nJAX Configuration:")
        print(f"Seed: {self.jax_seed}")
        print(f"Precision: {self.jax_precision}")

        print("\nFlax Configuration:")
        print(f"Optimizer: {self.flax_optimizer}")
        print(f"Learning Rate: {self.flax_learning_rate}")

        print("\nRouteLLM and Ollama Configuration:")
        print(f"Use RouteLLM: {self.use_routellm}")
        print(f"Ollama Model: {self.ollama_model}")

        print("\nLlama3 Configuration:")
        print(f"Use Llama3: {self.use_llama3}")
        print(f"Model Size: {self.llama3_model_size}")
        print(f"Context Length: {self.llama3_context_length}")

        print("\nGemma2 Configuration:")
        print(f"Use Gemma2: {self.use_gemma2}")
        print(f"Model Size: {self.gemma2_model_size}")
        print(f"Attention Heads: {self.gemma2_attention_heads}")

        print("\nMistral3 Configuration:")
        print(f"Use Mistral3: {self.use_mistral3}")
        print(f"Model Size: {self.mistral3_model_size}")
        print(f"Layers: {self.mistral3_layers}")

# Example usage
if __name__ == "__main__":
    config = Config()
    config.display()