Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/syna/functions/function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Optional

import numpy as np

Expand All @@ -13,7 +13,7 @@
class Reshape(Function):
"""Reshape tensor to a given shape."""

def __init__(self, shape: Tuple[int, ...]) -> None:
def __init__(self, shape: tuple[int, ...]) -> None:
self.shape = shape

def forward(self, x):
Expand Down Expand Up @@ -118,15 +118,15 @@ def backward(self, gy):
return broadcast_to(gy, self.x_shape)


def sum(x, axis: Optional[Tuple[int, ...]] = None, keepdims=False) -> Tensor:
def sum(x, axis: Optional[tuple[int, ...]] = None, keepdims=False) -> Tensor:
"""Sum elements along given axes."""
return Sum(axis, keepdims)(x)


class SumTo(Function):
"""Sum elements to target shape (inverse of broadcast_to)."""

def __init__(self, shape: Tuple[int, ...]):
def __init__(self, shape: tuple[int, ...]):
self.shape = shape

def forward(self, x):
Expand All @@ -137,7 +137,7 @@ def backward(self, gy):
return broadcast_to(gy, self.x_shape)


def sum_to(x, shape: Tuple[int, ...]) -> Tensor:
def sum_to(x, shape: tuple[int, ...]) -> Tensor:
"""Sum elements of x so result has `shape`."""
if x.shape == shape:
return as_tensor(x)
Expand All @@ -147,7 +147,7 @@ def sum_to(x, shape: Tuple[int, ...]) -> Tensor:
class BroadcastTo(Function):
"""Broadcast x to shape."""

def __init__(self, shape: Tuple[int, ...]) -> None:
def __init__(self, shape: tuple[int, ...]) -> None:
self.shape = shape

def forward(self, x):
Expand All @@ -158,7 +158,7 @@ def backward(self, gy):
return sum_to(gy, self.x_shape)


def broadcast_to(x, shape: Tuple[int, ...]) -> Tensor:
def broadcast_to(x, shape: tuple[int, ...]) -> Tensor:
"""Broadcast x to the given shape."""
if x.shape == shape:
return as_tensor(x)
Expand Down
4 changes: 2 additions & 2 deletions src/syna/functions/math.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Optional

import numpy as np

Expand Down Expand Up @@ -283,7 +283,7 @@ def linear_simple(x, W, b=None) -> Tensor:
return y


def mean(x, axis: Optional[Tuple[int, ...]] = None, keepdims=False) -> Tensor:
def mean(x, axis: Optional[tuple[int, ...]] = None, keepdims=False) -> Tensor:
"""Mean like torch.mean: mean over all elements by default, or over given axis/axes."""
x = as_tensor(x)
if axis is None:
Expand Down
8 changes: 4 additions & 4 deletions src/syna/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import os
import weakref
from typing import Dict, Optional
from typing import Optional

import numpy as np

Expand Down Expand Up @@ -64,7 +64,7 @@ def cleargrads(self):
for param in self.params():
param.cleargrad()

def _flatten_params(self, params_dict: Dict[str, Parameter], parent_key: str = ""):
def _flatten_params(self, params_dict: dict[str, Parameter], parent_key: str = ""):
"""Populate params_dict with flattened parameter names -> Parameter."""
for name in self._params:
obj = self.__dict__[name]
Expand All @@ -76,7 +76,7 @@ def _flatten_params(self, params_dict: Dict[str, Parameter], parent_key: str = "

def save_weights(self, path: str):
"""Save layer parameters to a compressed .npz file."""
params_dict: Dict[str, Parameter] = {}
params_dict: dict[str, Parameter] = {}
self._flatten_params(params_dict)
array_dict = {k: p.data for k, p in params_dict.items() if p is not None}
try:
Expand All @@ -90,7 +90,7 @@ def save_weights(self, path: str):
def load_weights(self, path: str):
"""Load parameters from a .npz file created by save_weights()."""
npz = np.load(path)
params_dict: Dict[str, Parameter] = {}
params_dict: dict[str, Parameter] = {}
self._flatten_params(params_dict)
for key, param in params_dict.items():
param.data = npz[key]
Expand Down
6 changes: 2 additions & 4 deletions src/syna/optim/adadelta.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Dict

import numpy as np

from syna.optim.optimizer import Optimizer
Expand All @@ -12,8 +10,8 @@ def __init__(self, rho: float = 0.95, eps: float = 1e-6) -> None:
super().__init__()
self.rho = rho
self.eps = eps
self._msg: Dict[int, np.ndarray] = {}
self._msdx: Dict[int, np.ndarray] = {}
self._msg: dict[int, np.ndarray] = {}
self._msdx: dict[int, np.ndarray] = {}

def update_one(self, param) -> None:
msg = self._state(self._msg, param)
Expand Down
4 changes: 1 addition & 3 deletions src/syna/optim/adagrad.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Dict

import numpy as np

from syna.optim.optimizer import Optimizer
Expand All @@ -12,7 +10,7 @@ def __init__(self, lr: float = 0.001, eps: float = 1e-8) -> None:
super().__init__()
self.lr = lr
self.eps = eps
self._hs: Dict[int, np.ndarray] = {}
self._hs: dict[int, np.ndarray] = {}

def update_one(self, param) -> None:
h = self._state(self._hs, param)
Expand Down
5 changes: 2 additions & 3 deletions src/syna/optim/adam.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import math
from typing import Dict

import numpy as np

Expand All @@ -26,8 +25,8 @@ def __init__(
self.beta1 = beta1
self.beta2 = beta2
self.eps = eps
self._ms: Dict[int, np.ndarray] = {}
self._vs: Dict[int, np.ndarray] = {}
self._ms: dict[int, np.ndarray] = {}
self._vs: dict[int, np.ndarray] = {}

def update(self, *args, **kwargs) -> None:
"""Increment time step and perform parameter updates."""
Expand Down
6 changes: 3 additions & 3 deletions src/syna/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

import math
from typing import Callable, Dict, Iterable, List
from typing import Callable, Iterable

import numpy as np

Expand All @@ -23,7 +23,7 @@ class Optimizer:

def __init__(self) -> None:
self.target = None
self.hooks: List[Callable[[Iterable], None]] = []
self.hooks: list[Callable[[Iterable], None]] = []

def setup(self, target):
"""Attach optimizer to a target (model) which must provide params()."""
Expand All @@ -47,7 +47,7 @@ def add_hook(self, f: Callable[[Iterable], None]) -> None:
self.hooks.append(f)

# Utility for managing per-parameter state dicts (e.g., moments, accumulators).
def _state(self, store: Dict[int, np.ndarray], param: object) -> np.ndarray:
def _state(self, store: dict[int, np.ndarray], param: object) -> np.ndarray:
key = id(param)
if key not in store:
store[key] = np.zeros_like(param.data)
Expand Down
12 changes: 6 additions & 6 deletions src/syna/utils/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""

import math
from typing import Any, Iterable, List, Tuple
from typing import Any, Iterable

import numpy as np

Expand All @@ -26,7 +26,7 @@ class DataLoader:
"""

def __init__(
self, dataset: Iterable[Tuple[Any, Any]], batch_size: int, shuffle: bool = True
self, dataset: Iterable[tuple[Any, Any]], batch_size: int, shuffle: bool = True
):
self.dataset = list(dataset)
self.batch_size = int(batch_size)
Expand All @@ -48,7 +48,7 @@ def reset(self) -> None:
def __iter__(self):
return self

def __next__(self) -> Tuple[np.ndarray, np.ndarray]:
def __next__(self) -> tuple[np.ndarray, np.ndarray]:
"""
Return the next batch (x, t) as NumPy arrays.
Raises StopIteration at the end of an epoch and resets internally.
Expand Down Expand Up @@ -90,17 +90,17 @@ class SeqDataLoader(DataLoader):
- Iteration yields exactly data_size // jump steps (i.e., max_iter inherited).
"""

def __init__(self, dataset: Iterable[Tuple[Any, Any]], batch_size: int):
def __init__(self, dataset: Iterable[tuple[Any, Any]], batch_size: int):
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=False)

def __next__(self) -> Tuple[np.ndarray, np.ndarray]:
def __next__(self) -> tuple[np.ndarray, np.ndarray]:
if self.iteration >= self.max_iter:
self.reset()
raise StopIteration

# jump sets the offset between streams to evenly partition the data
jump = max(1, self.data_size // max(1, self.batch_size))
indices: List[int] = [
indices: list[int] = [
(i * jump + self.iteration) % self.data_size for i in range(self.batch_size)
]
batch = [self.dataset[i] for i in indices]
Expand Down
4 changes: 2 additions & 2 deletions src/syna/utils/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import random
from collections import deque
from typing import Deque, Tuple
from typing import Deque

import gymnasium as gym
import matplotlib.pyplot as plt
Expand All @@ -28,7 +28,7 @@ class ReplayBuffer:
"""

def __init__(self, buffer_size: int, batch_size: int):
self.buffer: Deque[Tuple[np.ndarray, int, float, np.ndarray, bool]] = deque(
self.buffer: Deque[tuple[np.ndarray, int, float, np.ndarray, bool]] = deque(
maxlen=buffer_size
)
self.batch_size = batch_size
Expand Down
12 changes: 6 additions & 6 deletions src/syna/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import os
import urllib.request
from typing import Optional, Tuple, Union
from typing import Optional, Union

import numpy as np

Expand All @@ -16,7 +16,7 @@
# --- array/tensor helpers -----------------------------------------------------------


def sum_to(x: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:
def sum_to(x: np.ndarray, shape: tuple[int, ...]) -> np.ndarray:
"""
Sum elements of array `x` so that the result has shape `shape`.
This implements broadcasting-compatible sum reduction.
Expand All @@ -33,8 +33,8 @@ def sum_to(x: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:

def reshape_sum_backward(
gy: np.ndarray,
x_shape: Tuple[int, ...],
axis: Optional[Union[int, Tuple[int, ...]]],
x_shape: tuple[int, ...],
axis: Optional[Union[int, tuple[int, ...]]],
keepdims: bool,
) -> np.ndarray:
"""
Expand Down Expand Up @@ -74,7 +74,7 @@ def logsumexp(x: np.ndarray, axis: int = 1) -> np.ndarray:


def max_backward_shape(
x: np.ndarray, axis: Optional[Union[int, Tuple[int, ...]]]
x: np.ndarray, axis: Optional[Union[int, tuple[int, ...]]]
) -> list:
"""
Compute the shape of gradient for max reduction so the result can be
Expand Down Expand Up @@ -216,7 +216,7 @@ def get_file(url: str, file_name: Optional[str] = None) -> str:
return file_path


def pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
def pair(x: Union[int, tuple[int, int]]) -> tuple[int, int]:
"""Ensure `x` is a pair (tuple of two ints)."""
if isinstance(x, int):
return (x, x)
Expand Down