|
2 | 2 | import logging
|
3 | 3 | import os
|
4 | 4 | import sys
|
5 |
| -from typing import NamedTuple |
| 5 | +from typing import NamedTuple, List |
6 | 6 |
|
7 | 7 | import dolfinx as dfx
|
| 8 | +import copy |
8 | 9 | import numpy as np
|
9 | 10 | import precice
|
10 |
| -from turbine.materials import Material |
11 | 11 |
|
12 | 12 | logger = logging.getLogger("precice")
|
13 | 13 |
|
@@ -106,34 +106,19 @@ def write_data(self):
|
106 | 106 | class SolverState:
|
107 | 107 | """Stores the state of the solver, including displacement, velocity, acceleration, and time."""
|
108 | 108 |
|
109 |
| - def __init__(self, u, v, a, t): |
110 |
| - """Initialize the SolverState object. |
111 |
| -
|
112 |
| - Parameters |
113 |
| - ---------- |
114 |
| - u : dfx.fem.Function |
115 |
| - Displacement function. |
116 |
| - v : dfx.fem.Function |
117 |
| - Velocity function. |
118 |
| - a : dfx.fem.Function |
119 |
| - Acceleration function. |
120 |
| - t : float |
121 |
| - Time. |
122 |
| - """ |
123 |
| - self.u = u |
124 |
| - self.v = v |
125 |
| - self.a = a |
126 |
| - self.t = t |
| 109 | + def __init__(self, states): |
| 110 | + """Initialize the SolverState object.""" |
| 111 | + states_cp = [] |
| 112 | + for state in states: |
| 113 | + if isinstance(state, dfx.fem.Function): |
| 114 | + states_cp.append(state.copy()) |
| 115 | + else: |
| 116 | + states_cp.append(copy.deepcopy(state)) |
| 117 | + self.__state = states_cp |
127 | 118 |
|
128 | 119 | def get_state(self):
|
129 |
| - """Returns the state of the solver. |
130 |
| -
|
131 |
| - Returns |
132 |
| - ------- |
133 |
| - tuple |
134 |
| - A tuple containing displacement, velocity, acceleration, and time. |
135 |
| - """ |
136 |
| - return self.u, self.v, self.a, self.t |
| 120 | + """Returns the state of the solver.""" |
| 121 | + return self.__state |
137 | 122 |
|
138 | 123 |
|
139 | 124 | class Adapter:
|
@@ -346,21 +331,9 @@ def write_data(self, write_function):
|
346 | 331 | write_data = write_function.vector[self.interface_dof]
|
347 | 332 | self._interface.write_data(mesh_name, write_data_name, self._precice_vertex_ids, write_data)
|
348 | 333 |
|
349 |
| - def store_checkpoint(self, u, v, a, t): |
350 |
| - """Stores the current state as a checkpoint. |
351 |
| -
|
352 |
| - Parameters |
353 |
| - ---------- |
354 |
| - u : dfx.fem.Function |
355 |
| - Displacement function. |
356 |
| - v : dfx.fem.Function |
357 |
| - Velocity function. |
358 |
| - a : dfx.fem.Function |
359 |
| - Acceleration function. |
360 |
| - t : float |
361 |
| - Time. |
362 |
| - """ |
363 |
| - self._checkpoint = SolverState(u.copy(), v.copy(), a.copy(), t) |
| 334 | + def store_checkpoint(self, states: List): |
| 335 | + """Stores the current state as a checkpoint.""" |
| 336 | + self._checkpoint = SolverState(states) |
364 | 337 |
|
365 | 338 | def retrieve_checkpoint(self):
|
366 | 339 | """Retrieves the stored checkpoint state.
|
|
0 commit comments