Skip to content

Commit

Permalink
Update WorkingMemory and CognitiveProcessIntegration classes and tests
Browse files: browse the repository at this point in the history

* **models/memory.py**
  - Update `scan_fn` function in `WorkingMemory` class to maintain a consistent carry structure
  - Initialize the carry correctly in the `WorkingMemory` class
  - Ensure `LSTMCell` initialization and carry structure are correctly handled in the `WorkingMemory` class

* **models/consciousness_state.py**
  - Store an attention map, keyed by `{target_modality}-{source_modality}`, while computing cross-modal attention in the `CognitiveProcessIntegration` class
  - Return attention maps from the `CognitiveProcessIntegration` class

* **tests/benchmarks/test_arc_reasoning.py**
  - Update tests to check that `attention_maps` is present in the returned metrics and that each map sums to 1 along its last axis
  • Loading branch information
kasinadhsarma committed Dec 21, 2024
1 parent 22a87ac commit b9d66ca
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
1 change: 1 addition & 0 deletions models/consciousness_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __call__(self, inputs: Dict[str, jnp.ndarray], deterministic: bool = True):
for modality_input in inputs.values():
attended = attention(modality_input, modality_input, mask=mask, deterministic=deterministic)
cross_modal_contexts.append(attended)
attention_maps[f"{target_modality}-{source_modality}"] = attended

# Ensure tensor shapes match before combining
if cross_modal_contexts:
Expand Down
11 changes: 7 additions & 4 deletions models/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,21 @@ def __call__(self, inputs, initial_state=None, mask=None, deterministic=True):

# Process sequence using pure function for JAX compatibility
def scan_fn(carry, x):
h, _ = carry
h_new, y = rnn_cell(h, x)
return (h_new, y), y
h, y = carry
h_new, y_new = rnn_cell(h, x)
return (h_new, y_new), y_new

# Ensure inputs and state are float32
inputs = jnp.asarray(inputs, dtype=jnp.float32)
initial_state = jnp.asarray(initial_state, dtype=jnp.float32)

# Initialize carry correctly
initial_carry = (initial_state, jnp.zeros((batch_size, self.hidden_dim)))

# Use scan with explicit axis for sequence processing
(final_state, _), outputs = jax.lax.scan(
scan_fn,
init=(initial_state, None),
init=initial_carry,
xs=inputs.swapaxes(0, 1)
)
outputs = outputs.swapaxes(0, 1)
Expand Down
22 changes: 22 additions & 0 deletions tests/benchmarks/test_arc_reasoning.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ def test_pattern_recognition(self, key, consciousness_model):
assert 'attention_weights' in metrics
assert metrics['attention_weights'].ndim >= 3 # (batch, heads, seq)

# Validate attention maps
assert 'attention_maps' in metrics
for attn_map in metrics['attention_maps'].values():
assert jnp.allclose(
jnp.sum(attn_map, axis=-1),
jnp.ones((batch_size, 8, 64)) # (batch, heads, seq_length)
)

except Exception as e:
pytest.fail(f"Pattern recognition test failed: {str(e)}")

Expand Down Expand Up @@ -161,5 +169,19 @@ def test_conscious_adaptation(self, key, consciousness_model):
assert 'attention_weights' in simple_metrics
assert 'attention_weights' in complex_metrics

# Validate attention maps
assert 'attention_maps' in simple_metrics
assert 'attention_maps' in complex_metrics
for attn_map in simple_metrics['attention_maps'].values():
assert jnp.allclose(
jnp.sum(attn_map, axis=-1),
jnp.ones((batch_size, 8, 64)) # (batch, heads, seq_length)
)
for attn_map in complex_metrics['attention_maps'].values():
assert jnp.allclose(
jnp.sum(attn_map, axis=-1),
jnp.ones((batch_size, 8, 64)) # (batch, heads, seq_length)
)

except Exception as e:
pytest.fail(f"Conscious adaptation test failed: {str(e)}")

0 comments on commit b9d66ca

Please sign in to comment.