Skip to content

Commit

Permalink
Merge pull request #404 from skim0119/wip/401
Browse files Browse the repository at this point in the history
Use OperatorGroup for constrain and callback features
  • Loading branch information
skim0119 authored Jan 21, 2025
2 parents 61c53e9 + a7c1de0 commit 2ce3bde
Show file tree
Hide file tree
Showing 26 changed files with 464 additions and 355 deletions.
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ flake8:
.PHONY: autoflake-check
autoflake-check:
poetry run autoflake --version
poetry run autoflake $(AUTOFLAKE_ARGS) elastica tests examples
poetry run autoflake --check $(AUTOFLAKE_ARGS) elastica tests examples

.PHONY: autoflake-format
Expand Down
5 changes: 1 addition & 4 deletions elastica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,9 @@
from elastica.utils import isqrt
from elastica.timestepper import (
integrate,
PositionVerlet,
PEFRL,
RungeKutta4,
EulerForward,
extend_stepper_interface,
)
from elastica.timestepper.symplectic_steppers import PositionVerlet, PEFRL
from elastica.memory_block.memory_block_rigid_body import MemoryBlockRigidBody
from elastica.memory_block.memory_block_rod import MemoryBlockCosseratRod
from elastica.restart import save_state, load_state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
from elastica.typing import (
SystemType,
SystemCollectionType,
OperatorType,
StepType,
SteppersOperatorsType,
StateType,
)
from elastica.systems.protocol import ExplicitSystemProtocol
from .protocol import ExplicitStepperProtocol, MemoryProtocol
from elastica.experimental.timestepper.protocol import (
ExplicitSystemProtocol,
ExplicitStepperProtocol,
MemoryProtocol,
)


"""
Expand Down Expand Up @@ -166,10 +169,10 @@ class EulerForward(ExplicitStepperMixin):
Classical Euler Forward stepper. Stateless, coordinates operations only.
"""

def get_stages(self) -> list[OperatorType]:
def get_stages(self) -> list[StepType]:
return [self._first_stage]

def get_updates(self) -> list[OperatorType]:
def get_updates(self) -> list[StepType]:
return [self._first_update]

def _first_stage(
Expand Down Expand Up @@ -198,15 +201,15 @@ class RungeKutta4(ExplicitStepperMixin):
to be externally managed and allocated.
"""

def get_stages(self) -> list[OperatorType]:
def get_stages(self) -> list[StepType]:
return [
self._first_stage,
self._second_stage,
self._third_stage,
self._fourth_stage,
]

def get_updates(self) -> list[OperatorType]:
def get_updates(self) -> list[StepType]:
return [
self._first_update,
self._second_update,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Iterator, TypeVar, Generic, Type
from elastica.timestepper.protocol import ExplicitStepperProtocol
from elastica.typing import SystemCollectionType
from elastica.experimental.timestepper.explicit_steppers import (
RungeKutta4,
EulerForward,
)
from elastica.experimental.timestepper.protocol import ExplicitStepperProtocol

from copy import copy

Expand All @@ -12,11 +16,6 @@ def make_memory_for_explicit_stepper(
) -> "MemoryCollection":
# TODO Automated logic (class creation, memory management logic) agnostic of stepper details (RK, AB etc.)

from elastica.timestepper.explicit_steppers import (
RungeKutta4,
EulerForward,
)

# is_this_system_a_collection = is_system_a_collection(system)

memory_cls: Type
Expand Down
86 changes: 86 additions & 0 deletions elastica/experimental/timestepper/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Protocol

from elastica.typing import StepType, StateType
from elastica.systems.protocol import SystemProtocol, SlenderBodyGeometryProtocol
from elastica.timestepper.protocol import StepperProtocol

import numpy as np


class ExplicitSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol):
# TODO: Temporarily made to handle explicit stepper.
# Need to be refactored as the explicit stepper is further developed.
def __call__(self, time: np.float64, dt: np.float64) -> np.float64: ...
@property
def state(self) -> StateType: ...
@state.setter
def state(self, state: StateType) -> None: ...
@property
def n_elems(self) -> int: ...


class MemoryProtocol(Protocol):
@property
def initial_state(self) -> bool: ...


class ExplicitStepperProtocol(StepperProtocol, Protocol):
"""symplectic stepper protocol."""

def get_stages(self) -> list[StepType]: ...

def get_updates(self) -> list[StepType]: ...


# class _LinearExponentialIntegratorMixin:
# """
# Linear Exponential integrator mixin wrapper.
# """
#
# def __init__(self):
# pass
#
# def _do_stage(self, System, Memory, time, dt):
# # TODO : Make more general, system should not be calculating what the state
# # transition matrix directly is, but rather it should just give
# Memory.linear_operator = System.get_linear_state_transition_operator(time, dt)
#
# def _do_update(self, System, Memory, time, dt):
# # FIXME What's the right formula when doing update?
# # System.linearly_evolving_state = _batch_matmul(
# # System.linearly_evolving_state,
# # Memory.linear_operator
# # )
# System.linearly_evolving_state = np.einsum(
# "ijk,ljk->ilk", System.linearly_evolving_state, Memory.linear_operator
# )
# return time + dt
#
# def _first_prefactor(self, dt):
# """Prefactor call to satisfy interface of SymplecticStepper. Should never
# be used in actual code.
#
# Parameters
# ----------
# dt : the time step of simulation
#
# Raises
# ------
# RuntimeError
# """
# raise RuntimeError(
# "Symplectic prefactor of LinearExponentialIntegrator should not be called!"
# )
#
# # Code repeat!
# # Easy to avoid, but keep for performance.
# def _do_one_step(self, System, time, prefac):
# System.linearly_evolving_state = np.einsum(
# "ijk,ljk->ilk",
# System.linearly_evolving_state,
# System.get_linear_state_transition_operator(time, prefac),
# )
# return (
# time # TODO fix hack that treats time separately here. Shuold be time + dt
# )
# # return time + dt
19 changes: 14 additions & 5 deletions elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Basic coordinating for multiple, smaller systems that have an independently integrable
interface (i.e. works with symplectic or explicit routines `timestepper.py`.)
"""
from typing import Type, Generator, Iterable, Any, overload
from typing import Type, Generator, Any, overload
from typing import final
from elastica.typing import (
SystemType,
Expand All @@ -27,6 +27,7 @@

from .memory_block import construct_memory_block_structures
from .operator_group import OperatorGroupFIFO
from .protocol import ModuleProtocol


class BaseSystemCollection(MutableSequence):
Expand Down Expand Up @@ -55,10 +56,18 @@ def __init__(self) -> None:
# Collection of functions. Each group is executed as a collection at the different steps.
# Each component (Forcing, Connection, etc.) registers the executable (callable) function
# in the group that that needs to be executed. These should be initialized before mixin.
self._feature_group_synchronize: Iterable[OperatorType] = OperatorGroupFIFO()
self._feature_group_constrain_values: list[OperatorType] = []
self._feature_group_constrain_rates: list[OperatorType] = []
self._feature_group_callback: list[OperatorCallbackType] = []
self._feature_group_synchronize: OperatorGroupFIFO[
OperatorType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_constrain_values: OperatorGroupFIFO[
OperatorType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_constrain_rates: OperatorGroupFIFO[
OperatorType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_callback: OperatorGroupFIFO[
OperatorCallbackType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_finalize: list[OperatorFinalizeType] = []
# We need to initialize our mixin classes
super().__init__()
Expand Down
34 changes: 16 additions & 18 deletions elastica/modules/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from elastica.typing import SystemType, SystemIdxType, OperatorFinalizeType
from .protocol import ModuleProtocol

import functools

import numpy as np

from elastica.callback_functions import CallBackBaseClass
Expand All @@ -29,9 +31,7 @@ class CallBacks:

def __init__(self: SystemCollectionProtocol) -> None:
self._callback_list: list[ModuleProtocol] = []
self._callback_operators: list[tuple[int, CallBackBaseClass]] = []
super(CallBacks, self).__init__()
self._feature_group_callback.append(self._callback_execution)
self._feature_group_finalize.append(self._finalize_callback)

def collect_diagnostics(
Expand All @@ -54,30 +54,28 @@ def collect_diagnostics(
sys_idx: SystemIdxType = self.get_system_index(system)

# Create _Constraint object, cache it and return to user
_callbacks: ModuleProtocol = _CallBack(sys_idx)
self._callback_list.append(_callbacks)
_callback: ModuleProtocol = _CallBack(sys_idx)
self._callback_list.append(_callback)
self._feature_group_callback.append_id(_callback)

return _callbacks
return _callback

def _finalize_callback(self: SystemCollectionProtocol) -> None:
# dev : the first index stores the rod index to collect data.
self._callback_operators = [
(callback.id(), callback.instantiate()) for callback in self._callback_list
]
for callback in self._callback_list:
sys_id = callback.id()
callback_instance = callback.instantiate()

callback_operator = functools.partial(
callback_instance.make_callback, system=self[sys_id]
)
self._feature_group_callback.add_operators(callback, [callback_operator])

self._callback_list.clear()
del self._callback_list

# First callback execution
time = np.float64(0.0)
self._callback_execution(time=time, current_step=0)

def _callback_execution(
self: SystemCollectionProtocol,
time: np.float64,
current_step: int,
) -> None:
for sys_id, callback in self._callback_operators:
callback.make_callback(self[sys_id], time, current_step)
self.apply_callbacks(time=np.float64(0.0), current_step=0)


class _CallBack:
Expand Down
48 changes: 30 additions & 18 deletions elastica/modules/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import Any, Type, cast
from typing_extensions import Self

import functools

import numpy as np

from elastica.boundary_conditions import ConstraintBase
Expand Down Expand Up @@ -36,8 +38,6 @@ class Constraints:
def __init__(self: SystemCollectionProtocol) -> None:
self._constraints_list: list[ModuleProtocol] = []
super(Constraints, self).__init__()
self._feature_group_constrain_values.append(self._constrain_values)
self._feature_group_constrain_rates.append(self._constrain_rates)
self._feature_group_finalize.append(self._finalize_constraints)

def constrain(
Expand All @@ -62,6 +62,8 @@ def constrain(
# Create _Constraint object, cache it and return to user
_constraint: ModuleProtocol = _Constraint(sys_idx)
self._constraints_list.append(_constraint)
self._feature_group_constrain_values.append_id(_constraint)
self._feature_group_constrain_rates.append_id(_constraint)

return _constraint

Expand All @@ -71,11 +73,14 @@ def _finalize_constraints(self: SystemCollectionProtocol) -> None:
periodic boundaries, a new constrain for memory block rod added called as _ConstrainPeriodicBoundaries. This
constrain will synchronize the only periodic boundaries of position, director, velocity and omega variables.
"""
from elastica._synchronize_periodic_boundary import _ConstrainPeriodicBoundaries

for block in self.block_systems():
# append the memory block to the simulation as a system. Memory block is the final system in the simulation.
if hasattr(block, "ring_rod_flag"):
from elastica._synchronize_periodic_boundary import (
_ConstrainPeriodicBoundaries,
)

# Apply the constrain to synchronize the periodic boundaries of the memory rod. Find the memory block
# sys idx among other systems added and then apply boundary conditions.
memory_block_idx = self.get_system_index(block)
Expand All @@ -89,31 +94,38 @@ def _finalize_constraints(self: SystemCollectionProtocol) -> None:

# dev : the first index stores the rod index to apply the boundary condition
# to.
self._constraints_operators = [
(constraint.id(), constraint.instantiate(self[constraint.id()]))
for constraint in self._constraints_list
]

# Sort from lowest id to highest id for potentially better memory access
# _constraints contains list of tuples. First element of tuple is rod number and
# following elements are the type of boundary condition such as
# [(0, ConstraintBase, OneEndFixedBC), (1, HelicalBucklingBC), ... ]
# Thus using lambda we iterate over the list of tuples and use rod number (x[0])
# to sort constraints.
self._constraints_operators.sort(key=lambda x: x[0])
self._constraints_list.sort(key=lambda x: x.id())
for constraint in self._constraints_list:
sys_id = constraint.id()
constraint_instance = constraint.instantiate(self[sys_id])

constrain_values = functools.partial(
constraint_instance.constrain_values, system=self[sys_id]
)
constrain_rates = functools.partial(
constraint_instance.constrain_rates, system=self[sys_id]
)

self._feature_group_constrain_values.add_operators(
constraint, [constrain_values]
)
self._feature_group_constrain_rates.add_operators(
constraint, [constrain_rates]
)

# At t=0.0, constrain all the boundary conditions (for compatability with
# initial conditions)
self._constrain_values(time=np.float64(0.0))
self._constrain_rates(time=np.float64(0.0))

def _constrain_values(self: SystemCollectionProtocol, time: np.float64) -> None:
for sys_id, constraint in self._constraints_operators:
constraint.constrain_values(self[sys_id], time)
self.constrain_values(time=np.float64(0.0))
self.constrain_rates(time=np.float64(0.0))

def _constrain_rates(self: SystemCollectionProtocol, time: np.float64) -> None:
for sys_id, constraint in self._constraints_operators:
constraint.constrain_rates(self[sys_id], time)
self._constraints_list = []
del self._constraints_list


class _Constraint:
Expand Down
Loading

0 comments on commit 2ce3bde

Please sign in to comment.