Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(clean up conversion) #101

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyhype/blocks/quad_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def from_block(self, from_block: BaseBlockGhost) -> None:
the interior and ghost blocks from the input block, which is of type `PrimitiveState`.

:type from_block: BaseBlock_With_Ghost
:param from_block: Block whos interior and ghost block states are used to update self.
:param from_block: Block whose interior and ghost block states are used to update self.

:return: None
"""
Expand Down
30 changes: 10 additions & 20 deletions pyhype/fvm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def ghost_bc_type(self, direction: int):
"""
Get the boundary condition type associated with the given direction

:param direction: Direction to check BC type
:param direction: Direction to check bc type
:return: bc type
"""
return self.parent_block.ghost[direction].bc_type
Expand Down Expand Up @@ -280,26 +280,16 @@ def _get_left_right_riemann_states(
:rtype: tuple(PrimitiveState, PrimitiveState)
:return: PrimitiveStates that hold the left and right states for the flux calculation
"""
left_arr = np.concatenate((left_ghost_state.data, left_state.data), axis=1)
right_arr = np.concatenate((right_state.data, right_ghost_state.data), axis=1)

if self.config.reconstruction_type is PrimitiveState:
left_state = PrimitiveState(self.config.fluid, array=left_arr)
right_state = PrimitiveState(self.config.fluid, array=right_arr)
return left_state, right_state

left_state = PrimitiveState(
left_state = self.config.reconstruction_type(
fluid=self.config.fluid,
state=self.config.reconstruction_type(
fluid=self.config.fluid, array=left_arr
),
)
right_state = PrimitiveState(
array=np.concatenate((left_ghost_state.data, left_state.data), axis=1),
).to_type(to_type=PrimitiveState, copy=False)

right_state = self.config.reconstruction_type(
fluid=self.config.fluid,
state=self.config.reconstruction_type(
fluid=self.config.fluid, array=right_arr
),
)
array=np.concatenate((right_state.data, right_ghost_state.data), axis=1),
).to_type(to_type=PrimitiveState, copy=False)

return left_state, right_state

def _get_boundary_flux_states(self, direction: int) -> [State]:
Expand Down Expand Up @@ -487,7 +477,7 @@ def _evaluate_north_south_flux(self) -> None:

def evaluate_flux(self) -> None:
"""
Calculates the fluxes at all cell boundaries. Solves the 1-D riemann problem along all of the rows and columns
Calculates the fluxes at all cell boundaries. Solves the 1-D riemann problem along all the rows and columns
of cells on the blocks in a sweeping (but unsplit) fashion.

:rtype: None
Expand Down
8 changes: 4 additions & 4 deletions pyhype/states/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,12 @@ def reset(self, shape: tuple[int] = None):
def clear_cache(self) -> None:
self.cache.clear()

def from_state(self, state: State):
self.converter.from_state(state=self, from_state=state)
def from_state(self, state: State, copy: bool = True):
self.converter.from_state(state=self, from_state=state, copy=copy)
self.clear_cache()

def to_type(self, to_type: Type[State]):
return self.converter.to_type(state=self, to_type=to_type)
def to_type(self, to_type: Type[State], copy: bool = True):
return self.converter.to_type(state=self, to_type=to_type, copy=copy)

def from_array(self, array: np.ndarray):
if not isinstance(array, np.ndarray):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

import os
from abc import ABC, abstractmethod
from abc import ABC
from typing import TYPE_CHECKING, Type, Callable

import pyhype.states as states
Expand All @@ -36,7 +36,7 @@ class ConverterLogic(ABC):
@classmethod
def get_func(
cls, state_type: Type[states.State]
) -> Callable[[states.State], np.ndarray]:
) -> Callable[[states.State, bool], np.ndarray]:
"""
Returns the conversion function that converts a Base type to a state_type.

Expand All @@ -49,32 +49,36 @@ def get_func(
return cls.to_conservative

@staticmethod
@abstractmethod
def to_primitive(state: states.State) -> np.ndarray:
def to_primitive(state: states.State, copy: bool = True) -> np.ndarray:
"""
Defines the conversion logic to convert from the base state type to the
primitive state type. This shall return a numpy array filled with the new
state values. The array then gets built into the correct State object type
inside the StateConverter.

Defaults to returning the input state array, or its copy.

:param state: The State object to convert
:param copy: To copy the state array if converting to the same type
:return: Numpy array with the correct data values
"""
raise NotImplementedError
return state.data.copy() if copy else state.data

@staticmethod
@abstractmethod
def to_conservative(state: states.State) -> np.ndarray:
def to_conservative(state: states.State, copy: bool = True) -> np.ndarray:
"""
Defines the conversion logic to convert from the base state type to the
conservative state type. This shall return a numpy array filled with the new
state values. The array then gets built into the correct State object type
inside the StateConverter.

Defaults to returning the input state array, or its copy.

:param state: The State object to convert
:param copy: To copy the state array if converting to the same type
:return: Numpy array with the correct data values
"""
raise NotImplementedError
return state.data.copy() if copy else state.data


class ConservativeConverter(ConverterLogic):
Expand All @@ -83,11 +87,12 @@ class ConservativeConverter(ConverterLogic):
"""

@staticmethod
def to_primitive(state: states.ConservativeState) -> np.ndarray:
def to_primitive(state: states.ConservativeState, copy: bool = True) -> np.ndarray:
"""
Logic that converts from a conservative state into a primitive state.

:param state: ConservativeState object to convert
:param copy: To copy the state array if converting to the same type
:return: Numpy array with the equivalent state in the primitive basis
"""
return np.dstack(
Expand All @@ -99,36 +104,15 @@ def to_primitive(state: states.ConservativeState) -> np.ndarray:
)
)

@staticmethod
def to_conservative(state: states.ConservativeState) -> np.ndarray:
"""
Logic that converts from a conservative state into a conservative state.
This simply returns a copy of the state array.

:param state: ConservativeState object to convert
:return: Numpy array with the equivalent state in the conservative basis
"""
return state.data.copy()


class PrimitiveConverter(ConverterLogic):
@staticmethod
def to_primitive(state: states.PrimitiveState):
"""
Logic that converts from a primitive state into a primitive state.
This simply returns a copy of the state array.

:param state: PrimitveState object to convert
:return: Numpy array with the equivalent state in the primitive basis
"""
return state.data.copy()

@staticmethod
def to_conservative(state: states.PrimitiveState) -> np.ndarray:
def to_conservative(state: states.PrimitiveState, copy: bool = True) -> np.ndarray:
"""
Logic that converts from a primitive state into a conservative state.

:param state: PrimitiveState object to convert
:param copy: To copy the state array if converting to the same type
:return: Numpy array with the equivalent state in the primitive basis
"""
return np.dstack(
Expand Down
57 changes: 36 additions & 21 deletions pyhype/states/converter/state_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,44 @@
from __future__ import annotations

from abc import ABC
from typing import TYPE_CHECKING, Type
from typing import TYPE_CHECKING

import pyhype.states as states
from pyhype.states.converter.concrete_defs import (
from pyhype.states.converter.converter_logic import (
PrimitiveConverter,
ConservativeConverter,
)

if TYPE_CHECKING:
from pyhype.states.base import State
from pyhype.states.converter.concrete_defs import ConverterLogic
from typing import Callable, Type
import numpy as np


class StateConverter(ABC):
@staticmethod
def get_converter(state_type: Type[states.State]) -> Type[ConverterLogic]:
def __init__(self):
self._converter_map = {
states.PrimitiveState: PrimitiveConverter,
states.ConservativeState: ConservativeConverter,
}

def _get_conversion_func(
self, from_type: Type[states.State], to_type: Type[states.State]
) -> Callable[[states.State, bool], np.ndarray]:
"""
Returns the converter type associated with state_type
Gets the conversion function needed to go from state type from_type to state
types to_type from the appropriate state converter logic class.

:param state_type: Type of state to get converter for
:return:
:param from_type: The type of class being converted from
:param to_type: The type of class being converter to.
:return: the conversion function that takes in a state of a certain type
and returns the data array for an equivalent state of a different type.
"""
if state_type == states.PrimitiveState:
return PrimitiveConverter
if state_type == states.ConservativeState:
return ConservativeConverter
return self._converter_map[from_type].get_func(state_type=to_type)

def from_state(self, state: states.State, from_state: states.State) -> None:
def from_state(
self, state: states.State, from_state: states.State, copy: bool = True
) -> None:
"""
Copies the data from from_state into state, while converting the data's variable
basis from from_state's type to state's type.
Expand All @@ -55,30 +65,35 @@ def from_state(self, state: states.State, from_state: states.State) -> None:

:param state: The state to copy data into
:param from_state: The state to copy data from
:param copy: To copy the state array if converting to the same type
:return: None
"""
if not state.shape == from_state.shape:
raise ValueError(
f"States must have equal shape, but state has {state.shape} and from_state has {from_state.shape}"
)
converter = self.get_converter(type(from_state))
func = converter.get_func(state_type=type(state))
state.data = func(state=from_state)
func = self._get_conversion_func(
from_type=type(from_state),
to_type=type(state),
)
state.data = func(state=from_state, copy=copy)

def to_type(
self,
state: states.State,
to_type: Type[states.State],
copy: bool = True,
) -> states.State:
"""
Creates a new State from state, with type to_type.

:param state: The state to create from
:param to_type: The type of the new state
:param copy: To copy the state array if converting to the same type
:return: State with type to_type
"""
converter = self.get_converter(type(state))
func = converter.get_func(state_type=to_type)
array = func(state=state)
created = to_type(fluid=state.fluid, array=array)
return created
func = self._get_conversion_func(
from_type=type(state),
to_type=to_type,
)
return to_type(fluid=state.fluid, array=func(state=state, copy=copy))