Skip to content

Commit

Permalink
Add additional typing to qt components
Browse files Browse the repository at this point in the history
  • Loading branch information
JHolba committed Dec 19, 2024
1 parent 8f67e6a commit fcb1ddc
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 30 deletions.
20 changes: 10 additions & 10 deletions src/ert/gui/model/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@dataclass
class _Node(ABC):
class Node(ABC):
id_: str
parent: RootNode | IterNode | RealNode | None = None
children: (
Expand All @@ -24,7 +24,7 @@ def __repr__(self) -> str:
return f"Node<{type(self).__name__}>@{self.id_} with {parent}parent and {children}children"

@abstractmethod
def add_child(self, node: _Node) -> None:
def add_child(self, node: Node) -> None:
pass

def row(self) -> int:
Expand All @@ -37,12 +37,12 @@ def row(self) -> int:


@dataclass
class RootNode(_Node):
class RootNode(Node):
parent: None = field(default=None, init=False)
children: dict[str, IterNode] = field(default_factory=dict)
max_memory_usage: int | None = None

def add_child(self, node: _Node) -> None:
def add_child(self, node: Node) -> None:
node = cast(IterNode, node)
node.parent = self
self.children[node.id_] = node
Expand All @@ -55,12 +55,12 @@ class IterNodeData:


@dataclass
class IterNode(_Node):
class IterNode(Node):
parent: RootNode | None = None
data: IterNodeData = field(default_factory=IterNodeData)
children: dict[str, RealNode] = field(default_factory=dict)

def add_child(self, node: _Node) -> None:
def add_child(self, node: Node) -> None:
node = cast(RealNode, node)
node.parent = self
self.children[node.id_] = node
Expand All @@ -80,21 +80,21 @@ class RealNodeData:


@dataclass
class RealNode(_Node):
class RealNode(Node):
parent: IterNode | None = None
data: RealNodeData = field(default_factory=RealNodeData)
children: dict[str, ForwardModelStepNode] = field(default_factory=dict)

def add_child(self, node: _Node) -> None:
def add_child(self, node: Node) -> None:
node = cast(ForwardModelStepNode, node)
node.parent = self
self.children[node.id_] = node


@dataclass
class ForwardModelStepNode(_Node):
class ForwardModelStepNode(Node):
parent: RealNode | None
data: FMStepSnapshot = field(default_factory=lambda: FMStepSnapshot()) # noqa: PLW0108

def add_child(self, node: _Node) -> None:
def add_child(self, node: Node) -> None:
pass
11 changes: 6 additions & 5 deletions src/ert/gui/simulation/run_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
from ert.gui.ertnotifier import ErtNotifier
from ert.gui.ertwidgets.message_box import ErtMessageBox
from ert.gui.model.fm_step_list import FMStepListProxyModel
from ert.gui.model.node import Node
from ert.gui.model.real_list import RealListModel
from ert.gui.model.snapshot import (
FM_STEP_COLUMNS,
FileRole,
Expand Down Expand Up @@ -179,8 +181,8 @@ def __init__(
QFrame.__init__(self, parent)
self.output_path = output_path
self.setAttribute(Qt.WidgetAttribute.WA_DeleteOnClose)
self.setWindowFlags(Qt.WindowType.Window)
self.setWindowFlags(self.windowFlags() & ~Qt.WindowContextHelpButtonHint) # type: ignore
self.setWindowFlag(Qt.WindowType.Window, True)
self.setWindowFlag(Qt.WindowType.WindowContextHelpButtonHint, False)
self.setWindowTitle(f"Experiment - {config_file} {find_ert_info()}")

self._snapshot_model = SnapshotModel(self)
Expand Down Expand Up @@ -304,8 +306,7 @@ def on_snapshot_new_iteration(
) -> None:
if not parent.isValid():
index = self._snapshot_model.index(start, 0, parent)
# iteration = index.data(IterNum)
iteration = index.internalPointer().id_ # type: ignore
iteration = cast(Node, index.internalPointer()).id_
iter_row = start
self._iteration_progress_label.setText(
f"Progress for iteration {iteration}"
Expand All @@ -325,7 +326,7 @@ def on_snapshot_new_iteration(
def _select_real(self, index: QModelIndex) -> None:
if index.isValid():
real = index.row()
iter_ = index.model().get_iter() # type: ignore
iter_ = cast(RealListModel, index.model()).get_iter()
exec_hosts = None

iter_node = self._snapshot_model.root.children.get(str(iter_), None)
Expand Down
34 changes: 21 additions & 13 deletions src/ert/gui/tools/manage_experiments/storage_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import IntEnum
from typing import Any, overload
from typing import Any, cast, overload
from uuid import UUID

import humanize
Expand Down Expand Up @@ -52,7 +52,9 @@ def row(self) -> int:
return self._parent._children.index(self)
return 0

def data(self, index: QModelIndex, role: Qt.ItemDataRole) -> Any:
def data(
self, index: QModelIndex | QPersistentModelIndex, role: Qt.ItemDataRole
) -> Any:
if not index.isValid():
return None

Expand All @@ -79,7 +81,9 @@ def row(self) -> int:
return self._parent._children.index(self)
return 0

def data(self, index: QModelIndex, role: Qt.ItemDataRole) -> Any:
def data(
self, index: QModelIndex | QPersistentModelIndex, role: Qt.ItemDataRole
) -> Any:
if not index.isValid():
return None

Expand Down Expand Up @@ -113,7 +117,9 @@ def row(self) -> int:
return 0

def data(
self, index: QModelIndex, role: Qt.ItemDataRole = Qt.ItemDataRole.DisplayRole
self,
index: QModelIndex | QPersistentModelIndex,
role: Qt.ItemDataRole = Qt.ItemDataRole.DisplayRole,
) -> Any:
if not index.isValid():
return None
Expand All @@ -139,6 +145,9 @@ def data(
return None


ChildModel = ExperimentModel | EnsembleModel | RealizationModel


class StorageModel(QAbstractItemModel):
def __init__(self, storage: Storage):
super().__init__(None)
Expand Down Expand Up @@ -180,12 +189,9 @@ def columnCount(
def rowCount(
self, parent: QModelIndex | QPersistentModelIndex | None = None
) -> int:
if parent is None:
parent = QModelIndex()
if parent.isValid():
if isinstance(parent.internalPointer(), RealizationModel):
return 0
return len(parent.internalPointer()._children) # type: ignore
if parent is not None and parent.isValid():
data = cast(ChildModel | StorageModel, parent.internalPointer())
return 0 if isinstance(data, RealizationModel) else len(data._children)
else:
return len(self._children)

Expand All @@ -200,8 +206,8 @@ def parent(
if child is None or not child.isValid():
return QModelIndex()

child_item = child.internalPointer()
parentItem = child_item._parent # type: ignore
child_item = cast(ChildModel, child.internalPointer())
parentItem = child_item._parent

if parentItem == self:
return QModelIndex()
Expand Down Expand Up @@ -229,7 +235,9 @@ def data(
if not index.isValid():
return None

return index.internalPointer().data(index, role) # type:ignore
return cast(ChildModel | StorageModel, index.internalPointer()).data(
index, cast(Qt.ItemDataRole, role)
)

@override
def index(
Expand Down
4 changes: 2 additions & 2 deletions tests/ert/unit_tests/gui/simulation/view/test_realization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
FORWARD_MODEL_STATE_START,
REALIZATION_STATE_UNKNOWN,
)
from ert.gui.model.node import _Node
from ert.gui.model.node import Node
from ert.gui.model.snapshot import SnapshotModel
from ert.gui.simulation.view.realization import RealizationWidget
from tests.ert import SnapshotBuilder
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_selection_success(large_snapshot, qtbot):

def check_selection_cb(index):
node = index.internalPointer()
return isinstance(node, _Node) and str(node.id_) == str(selection_id)
return isinstance(node, Node) and str(node.id_) == str(selection_id)

with qtbot.waitSignal(
widget.itemClicked, timeout=30000, check_params_cb=check_selection_cb
Expand Down

0 comments on commit fcb1ddc

Please sign in to comment.