diff --git a/.flake8 b/.flake8
index 914c021..b50e23b 100644
--- a/.flake8
+++ b/.flake8
@@ -1,4 +1,4 @@
 [flake8]
-ignore = E8, E402, W503 # Add other codes you want to ignore
-max-line-length = 88 # Adjust line length as needed
+ignore = E801,E802,E803,E402,W503
+max-line-length = 88
 exclude = .git,__pycache__,dist,build # Exclude directories as needed
diff --git a/quantum_deep_learning/quantum_generative_models.py b/NeuroFlex/NeuroFlex/quantum_deep_learning/quantum_generative_models.py
similarity index 100%
rename from quantum_deep_learning/quantum_generative_models.py
rename to NeuroFlex/NeuroFlex/quantum_deep_learning/quantum_generative_models.py
diff --git a/quantum_deep_learning/quantum_reinforcement_learning.py b/NeuroFlex/NeuroFlex/quantum_deep_learning/quantum_reinforcement_learning.py
similarity index 100%
rename from quantum_deep_learning/quantum_reinforcement_learning.py
rename to NeuroFlex/NeuroFlex/quantum_deep_learning/quantum_reinforcement_learning.py
diff --git a/quantum_deep_learning/variational_quantum_circuit.py b/NeuroFlex/NeuroFlex/quantum_deep_learning/variational_quantum_circuit.py
similarity index 100%
rename from quantum_deep_learning/variational_quantum_circuit.py
rename to NeuroFlex/NeuroFlex/quantum_deep_learning/variational_quantum_circuit.py
diff --git a/NeuroFlex/advanced_models/multi_modal_learning.py b/NeuroFlex/advanced_models/multi_modal_learning.py
index 3176ada..f85150e 100644
--- a/NeuroFlex/advanced_models/multi_modal_learning.py
+++ b/NeuroFlex/advanced_models/multi_modal_learning.py
@@ -154,6 +154,7 @@ def fuse_modalities(self, encoded_modalities: Dict[str, torch.Tensor]) -> torch.
     def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
         """Forward pass through the multi-modal learning model."""
         logger.debug(f"Input types: {[(name, type(tensor)) for name, tensor in inputs.items()]}")
+        logger.debug(f"Input shapes: {[(name, tensor.shape) for name, tensor in inputs.items()]}")

         if not inputs:
             raise ValueError("Input dictionary is empty")
@@ -161,23 +162,25 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
         if len(inputs) == 1:
             raise ValueError("At least two modalities are required for fusion")

-        # Ensure all inputs are tensors
+        # Ensure all inputs are tensors and have correct dtype
         for name, tensor in inputs.items():
             if not isinstance(tensor, torch.Tensor):
-                inputs[name] = torch.tensor(tensor, dtype=torch.float32)
-            inputs[name] = inputs[name].float()  # Ensure all inputs are float tensors
-            logger.debug(f"Input {name} shape: {inputs[name].shape}, type: {type(inputs[name])}")
+                inputs[name] = torch.tensor(tensor)
+            if name == 'text':
+                inputs[name] = inputs[name].long()
+            else:
+                inputs[name] = inputs[name].float()
+            logger.debug(f"Input {name} shape: {inputs[name].shape}, type: {inputs[name].dtype}")

         # Check for batch size consistency across all input modalities
         batch_sizes = [tensor.size(0) for tensor in inputs.values()]
         if len(set(batch_sizes)) > 1:
             raise ValueError(f"Inconsistent batch sizes across modalities: {dict(zip(inputs.keys(), batch_sizes))}")

-        # Handle individual modality inputs
-        if set(inputs.keys()) != set(self.modalities.keys()):
-            missing_modalities = set(self.modalities.keys()) - set(inputs.keys())
-            for modality in missing_modalities:
-                inputs[modality] = torch.zeros((batch_sizes[0],) + self.modalities[modality]['input_shape'], dtype=torch.float32)
+        # Handle missing modalities
+        for modality in set(self.modalities.keys()) - set(inputs.keys()):
+            inputs[modality] = torch.zeros((batch_sizes[0],) + self.modalities[modality]['input_shape'], dtype=torch.float32)
+            logger.debug(f"Created zero tensor for missing modality {modality}: shape {inputs[modality].shape}")

         max_batch_size = batch_sizes[0]
@@ -193,73 +196,66 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
             if inputs[name].shape[1:] != modality['input_shape']:
                 raise ValueError(f"Input shape for {name} {inputs[name].shape} does not match the defined shape (batch_size, {modality['input_shape']})")

-            if name == 'image':
-                # For image modality, preserve the 4D structure
-                encoded_modalities[name] = modality['encoder'](inputs[name])
-                logger.debug(f"Encoded image shape: {encoded_modalities[name].shape}")
-            elif name == 'text':
-                # For text modality, ensure long type for embedding and float type for LSTM
-                text_input = inputs[name].long().clamp(0, 29999)  # Clamp to valid range
-                logger.debug(f"Text input shape: {text_input.shape}, type: {type(text_input)}")
-                embedded = modality['encoder'][0](text_input)
-                logger.debug(f"Embedded shape: {embedded.shape}, type: {type(embedded)}")
-                lstm_out, _ = modality['encoder'][1](embedded.float())  # Unpack LSTM output
-                logger.debug(f"Raw LSTM output shape: {lstm_out.shape}, type: {type(lstm_out)}")
-                lstm_out = lstm_out[:, -1, :]  # Use last time step output
-                logger.debug(f"LSTM output shape: {lstm_out.shape}, type: {type(lstm_out)}")
-                lstm_out = lstm_out.contiguous().view(lstm_out.size(0), -1)  # Ensure correct shape
-                logger.debug(f"Reshaped LSTM output shape: {lstm_out.shape}, type: {type(lstm_out)}")
-                encoded_modalities[name] = modality['encoder'][2](lstm_out)
-                logger.debug(f"Encoded text shape: {encoded_modalities[name].shape}, type: {type(encoded_modalities[name])}")
-            elif name == 'time_series':
-                # For time series, ensure 3D input (batch_size, channels, sequence_length)
-                if inputs[name].dim() == 2:
-                    inputs[name] = inputs[name].unsqueeze(1)
-                logger.debug(f"Time series input shape: {inputs[name].shape}, type: {type(inputs[name])}")
-                encoded_modalities[name] = modality['encoder'](inputs[name])
-                logger.debug(f"Encoded time series shape: {encoded_modalities[name].shape}")
-            elif name == 'tabular':
-                # For tabular data, ensure 2D input (batch_size, features)
-                logger.debug(f"Tabular input shape: {inputs[name].shape}, type: {type(inputs[name])}")
-                encoded_modalities[name] = modality['encoder'](inputs[name].view(inputs[name].size(0), -1))
-                logger.debug(f"Encoded tabular shape: {encoded_modalities[name].shape}")
-            else:
-                # For other modalities, flatten the input
-                logger.debug(f"Other modality input shape: {inputs[name].shape}, type: {type(inputs[name])}")
-                encoded_modalities[name] = modality['encoder'](inputs[name].view(inputs[name].size(0), -1))
-                logger.debug(f"Encoded other modality shape: {encoded_modalities[name].shape}")
+            try:
+                if name == 'image':
+                    # For image modality, preserve the 4D structure
+                    encoded_modalities[name] = modality['encoder'](inputs[name])
+                elif name == 'text':
+                    # For text modality, ensure long type for embedding and float type for LSTM
+                    text_input = inputs[name].long().clamp(0, 29999)  # Clamp to valid range
+                    embedded = modality['encoder'][0](text_input)
+                    lstm_out, _ = modality['encoder'][1](embedded.float())
+                    lstm_out = lstm_out[:, -1, :]  # Use last time step output
+                    encoded_modalities[name] = modality['encoder'][2](lstm_out)
+                elif name == 'time_series':
+                    # For time series, ensure 3D input (batch_size, channels, sequence_length)
+                    if inputs[name].dim() == 2:
+                        inputs[name] = inputs[name].unsqueeze(1)
+                    encoded_modalities[name] = modality['encoder'](inputs[name])
+                elif name == 'tabular':
+                    # For tabular data, ensure 2D input (batch_size, features)
+                    encoded_modalities[name] = modality['encoder'](inputs[name].view(inputs[name].size(0), -1))
+                else:
+                    # For other modalities, flatten the input
+                    encoded_modalities[name] = modality['encoder'](inputs[name].view(inputs[name].size(0), -1))

-            logger.debug(f"Encoded {name} shape: {encoded_modalities[name].shape}, type: {type(encoded_modalities[name])}")
+                logger.debug(f"Encoded {name} shape: {encoded_modalities[name].shape}, type: {encoded_modalities[name].dtype}")
+            except Exception as e:
+                logger.error(f"Error processing modality {name}: {str(e)}")
+                raise

         # Ensure all encoded modalities have the same batch size and are 2D tensors
         encoded_modalities = {name: tensor.view(max_batch_size, -1) for name, tensor in encoded_modalities.items()}
         logger.debug(f"Encoded modalities shapes after reshaping: {[(name, tensor.shape) for name, tensor in encoded_modalities.items()]}")

-        if self.fusion_method == 'concatenation':
-            fused = torch.cat(list(encoded_modalities.values()), dim=1)
-        elif self.fusion_method == 'attention':
-            fused = self.fuse_modalities(encoded_modalities)
-        else:
-            raise ValueError(f"Unsupported fusion method: {self.fusion_method}")
+        try:
+            if self.fusion_method == 'concatenation':
+                fused = torch.cat(list(encoded_modalities.values()), dim=1)
+            elif self.fusion_method == 'attention':
+                fused = self.fuse_modalities(encoded_modalities)
+            else:
+                raise ValueError(f"Unsupported fusion method: {self.fusion_method}")

-        logger.debug(f"Fused tensor shape: {fused.shape}, type: {type(fused)}")
+            logger.debug(f"Fused tensor shape: {fused.shape}, type: {fused.dtype}")

-        # Ensure fused tensor is 2D and matches the classifier's input size
-        if fused.dim() != 2 or fused.size(1) != self.classifier.in_features:
-            fused = fused.view(max_batch_size, -1)
-            fused = nn.functional.adaptive_avg_pool1d(fused.unsqueeze(1), self.classifier.in_features).squeeze(1)
+            # Ensure fused tensor is 2D and matches the classifier's input size
+            if fused.dim() != 2 or fused.size(1) != self.classifier.in_features:
+                fused = fused.view(max_batch_size, -1)
+                fused = nn.functional.adaptive_avg_pool1d(fused.unsqueeze(1), self.classifier.in_features).squeeze(1)

-        # Ensure fused tensor is a valid input for the classifier
-        fused = fused.float()  # Convert to float if not already
+            # Ensure fused tensor is a valid input for the classifier
+            fused = fused.float()  # Convert to float if not already

-        logger.debug(f"Final fused tensor shape: {fused.shape}, type: {type(fused)}")
+            logger.debug(f"Final fused tensor shape: {fused.shape}, type: {fused.dtype}")
+            logger.debug(f"Classifier input shape: {fused.shape}")

-        # Ensure input to classifier is a tensor
-        if not isinstance(fused, torch.Tensor):
-            fused = torch.tensor(fused, dtype=torch.float32)
+            output = self.classifier(fused)
+            logger.debug(f"Final output shape: {output.shape}")

-        logger.debug(f"Classifier input shape: {fused.shape}, type: {type(fused)}")
-        return self.classifier(fused)
+            return output
+        except Exception as e:
+            logger.error(f"Error during fusion or classification: {str(e)}")
+            raise

     def fit(self, data: Dict[str, torch.Tensor], labels: torch.Tensor, val_data: Dict[str, torch.Tensor] = None, val_labels: torch.Tensor = None, epochs: int = 10, lr: float = 0.001, patience: int = 5, batch_size: int = 32):
         """Train the multi-modal learning model."""
@@ -415,6 +411,10 @@ def _train_epoch(self, data: Dict[str, torch.Tensor], labels: torch.Tensor, opti
         correct_predictions = 0
         total_samples = 0

+        logger.debug(f"_train_epoch input - data types: {[(k, type(v)) for k, v in data.items()]}")
+        logger.debug(f"_train_epoch input - data shapes: {[(k, v.shape) for k, v in data.items()]}")
+        logger.debug(f"_train_epoch input - labels type: {type(labels)}, shape: {labels.shape}")
+
         if not data or not labels.numel():
             logger.warning("Empty data or labels provided for training epoch.")
             return 0.0, 0.0
@@ -425,6 +425,10 @@ def _train_epoch(self, data: Dict[str, torch.Tensor], labels: torch.Tensor, opti
         num_batches = (num_samples + batch_size - 1) // batch_size

         for i, (batch_data, batch_labels) in enumerate(self._batch_data(data, labels, batch_size)):
+            logger.debug(f"Batch {i+1}/{num_batches} - data types: {[(k, type(v)) for k, v in batch_data.items()]}")
+            logger.debug(f"Batch {i+1}/{num_batches} - data shapes: {[(k, v.shape) for k, v in batch_data.items()]}")
+            logger.debug(f"Batch {i+1}/{num_batches} - labels type: {type(batch_labels)}, shape: {batch_labels.shape}")
+
             if not batch_data or not batch_labels.numel():
                 logger.warning(f"Empty batch encountered at iteration {i+1}/{num_batches}. Skipping.")
                 continue
@@ -432,7 +436,9 @@ def _train_epoch(self, data: Dict[str, torch.Tensor], labels: torch.Tensor, opti
             print(f"\rTraining: {i+1}/{num_batches}", end="", flush=True)
             optimizer.zero_grad()
             outputs = self.forward(batch_data)
+            logger.debug(f"Batch {i+1}/{num_batches} - outputs type: {type(outputs)}, shape: {outputs.shape}")
             loss = criterion(outputs, batch_labels)
+            logger.debug(f"Batch {i+1}/{num_batches} - loss type: {type(loss)}, value: {loss.item()}")
             loss.backward()

             # Gradient clipping
@@ -465,6 +471,7 @@ def _train_epoch(self, data: Dict[str, torch.Tensor], labels: torch.Tensor, opti
         except Exception as e:
             logger.error(f"Error in _train_epoch: {str(e)}")
+            logger.exception("Traceback:")
             return 0.0, 0.0

     def _validate(self, data: Dict[str, torch.Tensor], labels: torch.Tensor, criterion: nn.Module, batch_size: int = 32) -> Tuple[float, float]:
diff --git a/NeuroFlex/cognitive_architectures/global_workspace_theory.py b/NeuroFlex/cognitive_architectures/global_workspace_theory.py
index 29a1c03..5fdd7ad 100644
--- a/NeuroFlex/cognitive_architectures/global_workspace_theory.py
+++ b/NeuroFlex/cognitive_architectures/global_workspace_theory.py
@@ -31,7 +31,7 @@
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
-from flax.core.frozen_dict import FrozenDict
+from flax.core import FrozenDict, freeze, unfreeze
 from jax import random
 from collections.abc import Callable
@@ -40,89 +40,71 @@ class GWTModel(nn.Module):
     workspace_size: int

     def setup(self):
-        self.specialized_processes = [nn.Dense(self.workspace_size, kernel_init=nn.initializers.normal(stddev=0.01)) for _ in range(self.num_processes)]
-        self.global_workspace = nn.Dense(self.workspace_size, kernel_init=nn.initializers.normal(stddev=0.01))
         self.weights = self.param('weights', nn.initializers.uniform(), (self.num_processes,))
-        self.consciousness_layer = nn.Dense(self.workspace_size, kernel_init=nn.initializers.normal(stddev=0.01))
-        self.bias_mitigation_layer = nn.Dense(self.workspace_size, kernel_init=nn.initializers.normal(stddev=0.01))
+        self.specialized_processes = [nn.Dense(self.workspace_size, name=f'specialized_process_{i}') for i in range(self.num_processes)]
+        self.global_workspace = nn.Dense(self.workspace_size, name='global_workspace')
+        self.consciousness_layer = nn.Dense(self.workspace_size, name='consciousness_layer')
+        self.bias_mitigation_layer = nn.Dense(self.workspace_size, name='bias_mitigation_layer')

+    @nn.compact
     def __call__(self, inputs):
-        # Process inputs without assuming specific attributes
-        if callable(inputs):
-            return inputs(self)
-        inputs = jnp.atleast_2d(inputs)  # Ensure inputs are at least 2D
+        # Process inputs
+        inputs = jnp.atleast_2d(inputs)  # Ensure inputs are at least 2D
+
         specialized_outputs = [process(inputs) for process in self.specialized_processes]
-        integrated_workspace = self.integrate_workspace(specialized_outputs)
+        integrated_workspace = self.integrate_workspace(specialized_outputs, self.global_workspace, self.weights)
         broadcasted_workspace = self.broadcast_workspace(integrated_workspace, specialized_outputs)
-        return broadcasted_workspace, integrated_workspace

-    def integrate_workspace(self, specialized_outputs):
+        # Apply GWT formula
+        gwt_output = jnp.sum(jnp.sum(jnp.stack(specialized_outputs) * self.weights[:, jnp.newaxis], axis=0))
+
+        return freeze({
+            'params': self.variables['params'],
+            'broadcasted_workspace': broadcasted_workspace,
+            'integrated_workspace': integrated_workspace,
+            'gwt_output': jnp.array([gwt_output])  # Ensure gwt_output has shape (1,)
+        })
+
+    def integrate_workspace(self, specialized_outputs, global_workspace, weights):
         """
         Integrate information from specialized processes into the global workspace.
         """
-        weights = self.variables['params']['weights']
         weighted_sum = jnp.sum(jnp.stack(specialized_outputs) * weights[:, jnp.newaxis], axis=0)
-        integrated = self.global_workspace(weighted_sum)
-        print(f"Integrated shape: {integrated.shape}, Specialized output shape: {specialized_outputs[0].shape}")
-        print(f"Weights shape: {weights.shape}, Weighted sum shape: {weighted_sum.shape}")
+        integrated = global_workspace(weighted_sum)
         # Ensure the integrated output has shape (1, workspace_size)
         integrated = integrated.mean(axis=0, keepdims=True)  # Average across processes and keep dims
-        print(f"Final integrated shape: {integrated.shape}")
         return integrated

+    def update_weights(self, new_weights):
+        """
+        Update the weights for each specialized process.
+        """
+        if new_weights.shape != (self.num_processes,):
+            raise ValueError("Number of weights must match number of processes")
+        normalized_weights = new_weights / jnp.sum(new_weights)
+        new_variables = unfreeze(self.variables)
+        new_variables['params']['weights'] = normalized_weights
+        return freeze(new_variables)
+
     def broadcast_workspace(self, integrated_workspace, specialized_outputs):
         """
         Broadcast the contents of the global workspace to all specialized processes.
         """
-        print("Integrated workspace shape:", integrated_workspace.shape)
-        print("Specialized outputs shapes:", [output.shape for output in specialized_outputs])
         broadcasted = [jnp.broadcast_to(integrated_workspace[i], output.shape) for i, output in enumerate(specialized_outputs)]
         return broadcasted

-    def apply_gwt_formula(self, input_stimulus):
-        """
-        Apply the GWT formula: G(x) = sum(w_i * f(x_i))
-
-        Args:
-            input_stimulus (jax.numpy.array): Input stimulus to the model.
-
-        Returns:
-            jax.numpy.array: The result of applying the GWT formula.
-        """
-        specialized_outputs = [process(input_stimulus) for process in self.specialized_processes]
-        weights = self.variables['params']['weights']
-        return jnp.sum(jnp.stack(specialized_outputs) * weights[:, jnp.newaxis], axis=0)
-
-    def update_weights(self, new_weights):
-        """
-        Update the weights for each specialized process.
-
-        Args:
-            new_weights (jax.numpy.array or callable): New weights for the specialized processes or a function to update them.
-        """
-        current_weights = self.variables['params']['weights']
-        if callable(new_weights):
-            updated_weights = new_weights(current_weights)
-        else:
-            updated_weights = new_weights
-
-        if isinstance(updated_weights, jnp.ndarray):
-            if updated_weights.shape != (self.num_processes,):
-                raise ValueError("Number of weights must match number of processes")
-            normalized_weights = updated_weights / jnp.sum(updated_weights)  # Normalize weights
-        else:
-            raise ValueError("Updated weights must be a JAX numpy array")
-
-        # Return a new FrozenDict with updated weights
-        return FrozenDict({'params': {'weights': normalized_weights}})
-
     @property
     def current_weights(self):
         return self.variables['params']['weights']

-    @property
-    def current_weights(self):
-        return self.variables['params']['weights']
+# Example usage:
+if __name__ == "__main__":
+    key = random.PRNGKey(0)
+    model = GWTModel(num_processes=5, workspace_size=100)
+    x = random.normal(key, (1, 100))
+    variables = model.init(key, x)
+    y = model.apply(variables, x)
+    print(y)

 # Example usage:
 if __name__ == "__main__":
@@ -134,15 +116,12 @@ def current_weights(self):
     # Generate some dummy input data
     inputs = jnp.array(np.random.randn(1, 100))

-    # Initialize parameters
+    # Initialize the model parameters
     key = jax.random.PRNGKey(0)
     params = model.init(key, inputs)

     # Run the model
-    broadcasted_workspace, integrated_workspace = model.apply(params, inputs)
-    print("Broadcasted workspace shape:", [bw.shape for bw in broadcasted_workspace])
-    print("Integrated workspace shape:", integrated_workspace.shape)
-
-    # Apply GWT formula
-    gwt_output = model.apply(params, inputs, method=model.apply_gwt_formula)
-    print("GWT formula output shape:", gwt_output.shape)
+    output = model.apply(params, inputs)
+    print("Broadcasted workspace shape:", [bw.shape for bw in output['broadcasted_workspace']])
+    print("Integrated workspace shape:", output['integrated_workspace'].shape)
+    print("GWT formula output shape:", output['gwt_output'].shape)
diff --git a/quantum_consciousness/documentation.md b/NeuroFlex/quantum_consciousness/documentation.md
similarity index 100%
rename from quantum_consciousness/documentation.md
rename to NeuroFlex/quantum_consciousness/documentation.md
diff --git a/quantum_consciousness/orch_or_simulation.py b/NeuroFlex/quantum_consciousness/orch_or_simulation.py
similarity index 100%
rename from quantum_consciousness/orch_or_simulation.py
rename to NeuroFlex/quantum_consciousness/orch_or_simulation.py
diff --git a/quantum_consciousness/quantum_mind_hypothesis_simulation.py b/NeuroFlex/quantum_consciousness/quantum_mind_hypothesis_simulation.py
similarity index 100%
rename from quantum_consciousness/quantum_mind_hypothesis_simulation.py
rename to NeuroFlex/quantum_consciousness/quantum_mind_hypothesis_simulation.py
diff --git a/quantum_consciousness/quantum_theories_simulation.py b/NeuroFlex/quantum_consciousness/quantum_theories_simulation.py
similarity index 100%
rename from quantum_consciousness/quantum_theories_simulation.py
rename to NeuroFlex/quantum_consciousness/quantum_theories_simulation.py
diff --git a/NeuroFlex/quantum_deep_learning/quantum_boltzmann_machine.py b/NeuroFlex/quantum_deep_learning/quantum_boltzmann_machine.py
index 8aa6726..0a35736 100644
--- a/NeuroFlex/quantum_deep_learning/quantum_boltzmann_machine.py
+++ b/NeuroFlex/quantum_deep_learning/quantum_boltzmann_machine.py
@@ -53,10 +53,11 @@ def energy(self, visible_state, hidden_state):
                 if entangled_state.ndim == 1 and entangled_state.shape[0] >= 4:
                     # Use the absolute value of the last element of entangled_state as the interaction strength
                     interaction_strength = abs(float(entangled_state[-1]))
-                    energy += interaction_strength * float(visible_state[i]) * float(hidden_state[j])
+                    energy -= interaction_strength * float(visible_state[i]) * float(hidden_state[j])  # Negate the interaction term
                 else:
                     raise ValueError(f"Unexpected shape of entangled_state: {entangled_state.shape}")
-        return float(-energy)  # Return negative energy as float to align with minimization objective
+        print(f"Energy calculation: visible_state={visible_state}, hidden_state={hidden_state}, energy={energy}")
+        return float(energy)  # Return non-positive energy as float (energy is already negative or zero)

     def sample_hidden(self, visible_state):
         hidden_probs = np.zeros(self.num_hidden)
diff --git a/NeuroFlex/quantum_neural_networks/quantum_module.py b/NeuroFlex/quantum_neural_networks/quantum_module.py
index b7d9263..4184033 100644
--- a/NeuroFlex/quantum_neural_networks/quantum_module.py
+++ b/NeuroFlex/quantum_neural_networks/quantum_module.py
@@ -20,9 +20,15 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.

+print("Attempting to import jax...")
 import jax
+print(f"JAX version: {jax.__version__}")
+print(f"JAX location: {jax.__file__}")
+print("Attempting to import jax.numpy...")
 import jax.numpy as jnp
-from jax.config import config
+print("Attempting to import jax.config...")
+from jax import config
+print("Successfully imported jax.config")
 config.update("jax_enable_x64", True)
 import flax.linen as nn
 from typing import List, Tuple, Dict, Any
diff --git a/tests/cognitive_architectures/test_cognitive_models.py b/tests/cognitive_architectures/test_cognitive_models.py
index 07e171c..a7283bd 100644
--- a/tests/cognitive_architectures/test_cognitive_models.py
+++ b/tests/cognitive_architectures/test_cognitive_models.py
@@ -1,33 +1,32 @@
 import pytest
 import jax.numpy as jnp
 from jax import random
-from flax import linen as nn
-from flax.training import train_state
-from flax.training.train_state import TrainState
-import optax
-from flax.core.frozen_dict import FrozenDict
-from flax.core import freeze, unfreeze
 from NeuroFlex.cognitive_architectures.attention_schema_theory import ASTModel
 from NeuroFlex.cognitive_architectures.global_workspace_theory import GWTModel
 from NeuroFlex.cognitive_architectures.higher_order_thoughts import HOTModel
 from NeuroFlex.cognitive_architectures.integrated_information_theory import IITModel

+
 @pytest.fixture
 def ast_model():
     return ASTModel(attention_dim=10, hidden_dim=20)

+
 @pytest.fixture
 def gwt_model():
     return GWTModel(num_processes=5, workspace_size=100)

+
 @pytest.fixture
 def hot_model():
     return HOTModel(num_layers=3, hidden_dim=10)

+
 @pytest.fixture
 def iit_model():
     return IITModel(num_components=5)

+
 def test_ast_model(ast_model):
     key = random.PRNGKey(0)
     x = random.normal(key, (1, 10))
@@ -37,53 +36,76 @@ def test_ast_model(ast_model):
     assert output.shape == (1, 10)  # Updated to match the input shape
     assert jnp.isfinite(output).all()

+
 def test_gwt_model(gwt_model):
     key = random.PRNGKey(0)
-    input_stimulus = random.normal(key, (1, 100))
-
-    variables = gwt_model.init({"params": key}, input_stimulus)
-    # Ensure the correct structure is used for the apply function
-    assert isinstance(variables['params'], FrozenDict), "variables['params'] should be a FrozenDict"
+    input_stimulus = random.normal(key, (1, gwt_model.workspace_size))

-    state = train_state.TrainState.create(
-        apply_fn=gwt_model.apply,
-        params=variables['params'],
-        tx=optax.adam(learning_rate=1e-3)
+    variables = gwt_model.init(key, input_stimulus)
+    assert 'params' in variables, "variables should contain 'params'"
+    assert 'weights' in variables['params'], "variables['params'] should contain 'weights'"
+    assert variables['params']['weights'].shape == (gwt_model.num_processes,), (
+        f"weights shape should be ({gwt_model.num_processes},)"
     )

-    # Ensure the model is bound before accessing variables
     bound_gwt_model = gwt_model.bind(variables)
-    # Pass the PRNG key correctly during model application
-    broadcasted_workspace, integrated_workspace = bound_gwt_model.apply(variables, input_stimulus, rngs={"params": key})
+    output = bound_gwt_model.apply(variables, input_stimulus)

-    assert len(broadcasted_workspace) == gwt_model.num_processes
-    assert all(bw.shape == (1, gwt_model.workspace_size) for bw in broadcasted_workspace)
-    assert integrated_workspace.shape == (1, gwt_model.workspace_size)
-    assert all(jnp.isfinite(bw).all() for bw in broadcasted_workspace)
-    assert jnp.isfinite(integrated_workspace).all()
+    assert 'broadcasted_workspace' in output, "Output should contain 'broadcasted_workspace'"
+    assert 'integrated_workspace' in output, "Output should contain 'integrated_workspace'"
+    assert 'gwt_output' in output, "Output should contain 'gwt_output'"

-    # Verify weights initialization
-    assert 'weights' in variables['params']
-    assert variables['params']['weights'].shape == (gwt_model.num_processes,)
+    broadcasted_workspace = output['broadcasted_workspace']
+    integrated_workspace = output['integrated_workspace']
+    gwt_output = output['gwt_output']

-    # Investigate the GWT model's weight update issue
+    assert len(broadcasted_workspace) == gwt_model.num_processes, (
+        f"Expected {gwt_model.num_processes} broadcasted workspaces"
+    )
+    assert all(
+        bw.shape == (1, gwt_model.workspace_size) for bw in broadcasted_workspace
+    ), (
+        f"Each broadcasted workspace should have shape (1, {gwt_model.workspace_size})"
+    )
+    assert integrated_workspace.shape == (1, gwt_model.workspace_size), (
+        f"Integrated workspace should have shape (1, {gwt_model.workspace_size})"
+    )
+    assert gwt_output.shape == (1,), "GWT output should have shape (1,)"
+    assert all(jnp.isfinite(bw).all() for bw in broadcasted_workspace), (
+        "Broadcasted workspaces contain non-finite values"
+    )
+    assert jnp.isfinite(integrated_workspace).all(), (
+        "Integrated workspace contains non-finite values"
+    )
+    assert jnp.isfinite(gwt_output).all(), "GWT output contains non-finite values"
+
+    # Test weight update
     new_weights = jnp.array([0.1, 0.2, 0.3, 0.2, 0.2])
-    print(f"Initial weights: {variables['params']['weights']}")
-    updated_variables = bound_gwt_model.apply({'params': variables['params']}, new_weights, method=bound_gwt_model.update_weights, rngs={'params': key})
+    assert new_weights.shape == (gwt_model.num_processes,), (
+        f"New weights shape should be ({gwt_model.num_processes},)"
+    )
+
+    initial_weights = variables['params']['weights']
+    print(f"Initial weights: {initial_weights[:3]}...")  # Show only first 3 elements
+
+    updated_variables = bound_gwt_model.apply(
+        variables, new_weights, method=bound_gwt_model.update_weights
+    )
     updated_weights = updated_variables['params']['weights']
     expected_weights = new_weights / jnp.sum(new_weights)
+
+    print(f"Updated weights: {updated_weights}")
     print(f"Expected weights: {expected_weights}")
print(f"Difference: {jnp.abs(updated_weights - expected_weights)}") - assert jnp.allclose(updated_weights, expected_weights, atol=1e-5) -def test_ast_model(ast_model): - key = random.PRNGKey(0) - x = random.normal(key, (1, 10)) # Input shape (1, 10) - variables = ast_model.init(key, x) - output = ast_model.apply(variables, x) - assert output.shape == (1, 10) # Ensure output shape matches input shape + assert jnp.allclose(updated_weights, expected_weights, atol=1e-5), ( + "Weight update did not produce expected results" + ) + assert jnp.isclose(jnp.sum(updated_weights), 1.0, atol=1e-5), ( + "Updated weights should sum to 1" + ) + def test_hot_model(hot_model): key = random.PRNGKey(0) @@ -91,17 +113,25 @@ def test_hot_model(hot_model): params = hot_model.init(key, x) output = hot_model.apply(params, x) - assert output.shape == (1, hot_model.output_dim), f"Expected output shape (1, {hot_model.output_dim}), but got {output.shape}" + assert output.shape == (1, hot_model.output_dim), ( + f"Expected output shape (1, {hot_model.output_dim}), but got {output.shape}" + ) assert jnp.isfinite(output).all(), "Output contains non-finite values" # Verify the model's dimensions - assert hot_model.input_dim == hot_model.output_dim, f"Expected input_dim to match output_dim, but got input_dim={hot_model.input_dim} and output_dim={hot_model.output_dim}" + assert hot_model.input_dim == hot_model.output_dim, ( + f"Expected input_dim to match output_dim, but got input_dim={hot_model.input_dim} " + f"and output_dim={hot_model.output_dim}" + ) assert hot_model.hidden_dim == 10, f"Expected hidden_dim 10, but got {hot_model.hidden_dim}" - print(f"HOT model dimensions: input_dim={hot_model.input_dim}, hidden_dim={hot_model.hidden_dim}, output_dim={hot_model.output_dim}") + print( + f"HOT model dimensions: input_dim={hot_model.input_dim}, " + f"hidden_dim={hot_model.hidden_dim}, output_dim={hot_model.output_dim}" + ) + def test_iit_model(iit_model): key = random.PRNGKey(0) - state = random.normal(key, (5,)) # Initialize the model params = iit_model.init(key, None) @@ -113,6 +143,7 @@ def test_iit_model(iit_model): assert jnp.isfinite(phi) assert phi >= 0 + def test_ast_model_training(ast_model): key = random.PRNGKey(0) x = random.normal(key, (10, 10)) # Batch size of 10, input dimension of 10 @@ -130,6 +161,7 @@ def test_ast_model_training(ast_model): assert jnp.isfinite(loss) assert loss >= 0 + def test_gwt_model_update_weights(gwt_model): key = random.PRNGKey(0) input_stimulus = random.normal(key, (1, gwt_model.workspace_size)) @@ -139,10 +171,18 @@ def test_gwt_model_update_weights(gwt_model): new_weights = jnp.array([0.1, 0.2, 0.3, 0.2, 0.2]) assert new_weights.shape == (gwt_model.num_processes,), "New weights shape mismatch" - updated_variables = bound_gwt_model.apply({'params': variables['params']}, new_weights, method=bound_gwt_model.update_weights, rngs={'params': key}) + updated_variables = bound_gwt_model.apply( + variables, new_weights, method=bound_gwt_model.update_weights + ) updated_weights = updated_variables['params']['weights'] expected_weights = new_weights / jnp.sum(new_weights) - assert jnp.allclose(updated_weights, expected_weights, atol=1e-5), f"Expected {expected_weights}, but got {updated_weights}" + assert jnp.allclose(updated_weights, expected_weights, atol=1e-5), ( + f"Expected {expected_weights}, but got {updated_weights}" + ) + assert jnp.isclose(jnp.sum(updated_weights), 1.0, atol=1e-5), ( + "Updated weights should sum to 1" + ) + def test_hot_model_higher_order_thought(hot_model): key = 
random.PRNGKey(0) @@ -150,20 +190,23 @@ def test_hot_model_higher_order_thought(hot_model): params = hot_model.init(key, x) first_order_thought = hot_model.apply(params, x) - higher_order_thought = hot_model.generate_higher_order_thought(params, first_order_thought) + higher_order_thought = hot_model.generate_higher_order_thought( + params, first_order_thought + ) assert higher_order_thought.shape == (1, hot_model.output_dim) assert jnp.isfinite(higher_order_thought).all() assert first_order_thought.shape == (1, hot_model.output_dim) + def test_iit_model_cause_effect_structure(iit_model): key = random.PRNGKey(0) - state = random.normal(key, (5,)) # Initialize the model - params = iit_model.init(key, state) + params = iit_model.init(key, None) initialized_iit_model = iit_model.bind(params) + state = random.normal(key, (iit_model.num_components,)) ces = initialized_iit_model.compute_cause_effect_structure(state) assert isinstance(ces, dict) assert len(ces) > 0 diff --git a/tests/test_quantum_models.py b/tests/quantum_consciousness/test_quantum_models.py similarity index 100% rename from tests/test_quantum_models.py rename to tests/quantum_consciousness/test_quantum_models.py diff --git a/tests/test_quantum_module.py b/tests/quantum_consciousness/test_quantum_module.py similarity index 100% rename from tests/test_quantum_module.py rename to tests/quantum_consciousness/test_quantum_module.py diff --git a/tests/test_quantum_protein_development.py b/tests/quantum_consciousness/test_quantum_protein_development.py similarity index 100% rename from tests/test_quantum_protein_development.py rename to tests/quantum_consciousness/test_quantum_protein_development.py diff --git a/tests/test_quantum_deep_learning.py b/tests/quantum_deep_learning/test_quantum_deep_learning.py similarity index 100% rename from tests/test_quantum_deep_learning.py rename to tests/quantum_deep_learning/test_quantum_deep_learning.py