Skip to content

Commit 5695abf

Browse files
committed
Store states as dynamic list of states
1 parent fa29673 commit 5695abf

File tree

1 file changed

+16
-43
lines changed

1 file changed

+16
-43
lines changed

fenics-adapter/fenicsxprecice/interface.py

Lines changed: 16 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import logging
33
import os
44
import sys
5-
from typing import NamedTuple
5+
from typing import NamedTuple, List
66

77
import dolfinx as dfx
8+
import copy
89
import numpy as np
910
import precice
10-
from turbine.materials import Material
1111

1212
logger = logging.getLogger("precice")
1313

@@ -106,34 +106,19 @@ def write_data(self):
106106
class SolverState:
107107
"""Stores the state of the solver, including displacement, velocity, acceleration, and time."""
108108

109-
def __init__(self, u, v, a, t):
110-
"""Initialize the SolverState object.
111-
112-
Parameters
113-
----------
114-
u : dfx.fem.Function
115-
Displacement function.
116-
v : dfx.fem.Function
117-
Velocity function.
118-
a : dfx.fem.Function
119-
Acceleration function.
120-
t : float
121-
Time.
122-
"""
123-
self.u = u
124-
self.v = v
125-
self.a = a
126-
self.t = t
109+
def __init__(self, states):
110+
"""Initialize the SolverState object."""
111+
states_cp = []
112+
for state in states:
113+
if isinstance(state, dfx.fem.Function):
114+
states_cp.append(state.copy())
115+
else:
116+
states_cp.append(copy.deepcopy(state))
117+
self.__state = states_cp
127118

128119
def get_state(self):
129-
"""Returns the state of the solver.
130-
131-
Returns
132-
-------
133-
tuple
134-
A tuple containing displacement, velocity, acceleration, and time.
135-
"""
136-
return self.u, self.v, self.a, self.t
120+
"""Returns the state of the solver."""
121+
return self.__state
137122

138123

139124
class Adapter:
@@ -346,21 +331,9 @@ def write_data(self, write_function):
346331
write_data = write_function.vector[self.interface_dof]
347332
self._interface.write_data(mesh_name, write_data_name, self._precice_vertex_ids, write_data)
348333

349-
def store_checkpoint(self, u, v, a, t):
350-
"""Stores the current state as a checkpoint.
351-
352-
Parameters
353-
----------
354-
u : dfx.fem.Function
355-
Displacement function.
356-
v : dfx.fem.Function
357-
Velocity function.
358-
a : dfx.fem.Function
359-
Acceleration function.
360-
t : float
361-
Time.
362-
"""
363-
self._checkpoint = SolverState(u.copy(), v.copy(), a.copy(), t)
334+
def store_checkpoint(self, states: List):
335+
"""Stores the current state as a checkpoint."""
336+
self._checkpoint = SolverState(states)
364337

365338
def retrieve_checkpoint(self):
366339
"""Retrieves the stored checkpoint state.

0 commit comments

Comments
 (0)