From fcb1ddcd472555bf3ae46cab40b8393d6ad3a704 Mon Sep 17 00:00:00 2001 From: Jon Holba Date: Thu, 19 Dec 2024 11:30:20 +0100 Subject: [PATCH] Add additional typing to qt components --- src/ert/gui/model/node.py | 20 +++++------ src/ert/gui/simulation/run_dialog.py | 11 +++--- .../tools/manage_experiments/storage_model.py | 34 ++++++++++++------- .../gui/simulation/view/test_realization.py | 4 +-- 4 files changed, 39 insertions(+), 30 deletions(-) diff --git a/src/ert/gui/model/node.py b/src/ert/gui/model/node.py index 9d02dd220c9..90e051fb642 100644 --- a/src/ert/gui/model/node.py +++ b/src/ert/gui/model/node.py @@ -10,7 +10,7 @@ @dataclass -class _Node(ABC): +class Node(ABC): id_: str parent: RootNode | IterNode | RealNode | None = None children: ( @@ -24,7 +24,7 @@ def __repr__(self) -> str: return f"Node<{type(self).__name__}>@{self.id_} with {parent}parent and {children}children" @abstractmethod - def add_child(self, node: _Node) -> None: + def add_child(self, node: Node) -> None: pass def row(self) -> int: @@ -37,12 +37,12 @@ def row(self) -> int: @dataclass -class RootNode(_Node): +class RootNode(Node): parent: None = field(default=None, init=False) children: dict[str, IterNode] = field(default_factory=dict) max_memory_usage: int | None = None - def add_child(self, node: _Node) -> None: + def add_child(self, node: Node) -> None: node = cast(IterNode, node) node.parent = self self.children[node.id_] = node @@ -55,12 +55,12 @@ class IterNodeData: @dataclass -class IterNode(_Node): +class IterNode(Node): parent: RootNode | None = None data: IterNodeData = field(default_factory=IterNodeData) children: dict[str, RealNode] = field(default_factory=dict) - def add_child(self, node: _Node) -> None: + def add_child(self, node: Node) -> None: node = cast(RealNode, node) node.parent = self self.children[node.id_] = node @@ -80,21 +80,21 @@ class RealNodeData: @dataclass -class RealNode(_Node): +class RealNode(Node): parent: IterNode | None = None data: RealNodeData = field(default_factory=RealNodeData) children: dict[str, ForwardModelStepNode] = field(default_factory=dict) - def add_child(self, node: _Node) -> None: + def add_child(self, node: Node) -> None: node = cast(ForwardModelStepNode, node) node.parent = self self.children[node.id_] = node @dataclass -class ForwardModelStepNode(_Node): +class ForwardModelStepNode(Node): parent: RealNode | None data: FMStepSnapshot = field(default_factory=lambda: FMStepSnapshot()) # noqa: PLW0108 - def add_child(self, node: _Node) -> None: + def add_child(self, node: Node) -> None: pass diff --git a/src/ert/gui/simulation/run_dialog.py b/src/ert/gui/simulation/run_dialog.py index e49b90fdced..787140311fa 100644 --- a/src/ert/gui/simulation/run_dialog.py +++ b/src/ert/gui/simulation/run_dialog.py @@ -39,6 +39,8 @@ from ert.gui.ertnotifier import ErtNotifier from ert.gui.ertwidgets.message_box import ErtMessageBox from ert.gui.model.fm_step_list import FMStepListProxyModel +from ert.gui.model.node import Node +from ert.gui.model.real_list import RealListModel from ert.gui.model.snapshot import ( FM_STEP_COLUMNS, FileRole, @@ -179,8 +181,8 @@ def __init__( QFrame.__init__(self, parent) self.output_path = output_path self.setAttribute(Qt.WidgetAttribute.WA_DeleteOnClose) - self.setWindowFlags(Qt.WindowType.Window) - self.setWindowFlags(self.windowFlags() & ~Qt.WindowContextHelpButtonHint) # type: ignore + self.setWindowFlag(Qt.WindowType.Window, True) + self.setWindowFlag(Qt.WindowType.WindowContextHelpButtonHint, False) self.setWindowTitle(f"Experiment - {config_file} {find_ert_info()}") self._snapshot_model = SnapshotModel(self) @@ -304,8 +306,7 @@ def on_snapshot_new_iteration( ) -> None: if not parent.isValid(): index = self._snapshot_model.index(start, 0, parent) - # iteration = index.data(IterNum) - iteration = index.internalPointer().id_ # type: ignore + iteration = cast(Node, index.internalPointer()).id_ iter_row = start self._iteration_progress_label.setText( f"Progress for iteration {iteration}" @@ -325,7 +326,7 @@ def on_snapshot_new_iteration( def _select_real(self, index: QModelIndex) -> None: if index.isValid(): real = index.row() - iter_ = index.model().get_iter() # type: ignore + iter_ = cast(RealListModel, index.model()).get_iter() exec_hosts = None iter_node = self._snapshot_model.root.children.get(str(iter_), None) diff --git a/src/ert/gui/tools/manage_experiments/storage_model.py b/src/ert/gui/tools/manage_experiments/storage_model.py index 24462969cf1..80bdfbbe572 100644 --- a/src/ert/gui/tools/manage_experiments/storage_model.py +++ b/src/ert/gui/tools/manage_experiments/storage_model.py @@ -1,5 +1,5 @@ from enum import IntEnum -from typing import Any, overload +from typing import Any, cast, overload from uuid import UUID import humanize @@ -52,7 +52,9 @@ def row(self) -> int: return self._parent._children.index(self) return 0 - def data(self, index: QModelIndex, role: Qt.ItemDataRole) -> Any: + def data( + self, index: QModelIndex | QPersistentModelIndex, role: Qt.ItemDataRole + ) -> Any: if not index.isValid(): return None @@ -79,7 +81,9 @@ def row(self) -> int: return self._parent._children.index(self) return 0 - def data(self, index: QModelIndex, role: Qt.ItemDataRole) -> Any: + def data( + self, index: QModelIndex | QPersistentModelIndex, role: Qt.ItemDataRole + ) -> Any: if not index.isValid(): return None @@ -113,7 +117,9 @@ def row(self) -> int: return 0 def data( - self, index: QModelIndex, role: Qt.ItemDataRole = Qt.ItemDataRole.DisplayRole + self, + index: QModelIndex | QPersistentModelIndex, + role: Qt.ItemDataRole = Qt.ItemDataRole.DisplayRole, ) -> Any: if not index.isValid(): return None @@ -139,6 +145,9 @@ def data( return None +ChildModel = ExperimentModel | EnsembleModel | RealizationModel + + class StorageModel(QAbstractItemModel): def __init__(self, storage: Storage): super().__init__(None) @@ -180,12 +189,9 @@ def columnCount( def rowCount( self, parent: QModelIndex | QPersistentModelIndex | None = None ) -> int: - if parent is None: - parent = QModelIndex() - if parent.isValid(): - if isinstance(parent.internalPointer(), RealizationModel): - return 0 - return len(parent.internalPointer()._children) # type: ignore + if parent is not None and parent.isValid(): + data = cast(ChildModel | StorageModel, parent.internalPointer()) + return 0 if isinstance(data, RealizationModel) else len(data._children) else: return len(self._children) @@ -200,8 +206,8 @@ def parent( if child is None or not child.isValid(): return QModelIndex() - child_item = child.internalPointer() - parentItem = child_item._parent # type: ignore + child_item = cast(ChildModel, child.internalPointer()) + parentItem = child_item._parent if parentItem == self: return QModelIndex() @@ -229,7 +235,9 @@ def data( if not index.isValid(): return None - return index.internalPointer().data(index, role) # type:ignore + return cast(ChildModel | StorageModel, index.internalPointer()).data( + index, cast(Qt.ItemDataRole, role) + ) @override def index( diff --git a/tests/ert/unit_tests/gui/simulation/view/test_realization.py b/tests/ert/unit_tests/gui/simulation/view/test_realization.py index 8f35c5e25e2..015fb83f578 100644 --- a/tests/ert/unit_tests/gui/simulation/view/test_realization.py +++ b/tests/ert/unit_tests/gui/simulation/view/test_realization.py @@ -12,7 +12,7 @@ FORWARD_MODEL_STATE_START, REALIZATION_STATE_UNKNOWN, ) -from ert.gui.model.node import _Node +from ert.gui.model.node import Node from ert.gui.model.snapshot import SnapshotModel from ert.gui.simulation.view.realization import RealizationWidget from tests.ert import SnapshotBuilder @@ -104,7 +104,7 @@ def test_selection_success(large_snapshot, qtbot): def check_selection_cb(index): node = index.internalPointer() - return isinstance(node, _Node) and str(node.id_) == str(selection_id) + return isinstance(node, Node) and str(node.id_) == str(selection_id) with qtbot.waitSignal( widget.itemClicked, timeout=30000, check_params_cb=check_selection_cb