Skip to content

Commit 09c6435

Browse files
committed
Changes for seamless combatibility with pyidi analyses returns
1 parent b6e12e2 commit 09c6435

File tree

3 files changed

+80
-26
lines changed

3 files changed

+80
-26
lines changed

pyidi/GUIs/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import typing
2+
13
try:
24
import PyQt6
35

46
HAS_PYQT6 = True
57
except ImportError:
68
HAS_PYQT6 = False
79

8-
if HAS_PYQT6:
10+
if HAS_PYQT6 or typing.TYPE_CHECKING:
911
from .subset_selection import SelectionGUI
1012
from .result_viewer import ResultViewer
1113
else:

pyidi/GUIs/result_viewer.py

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,47 @@
77
import sys
88

99
class ResultViewer(QtWidgets.QMainWindow):
10-
def __init__(self, video, displacements, grid, fps=30, magnification=1, point_size=10, colormap="cool"):
10+
def __init__(self, video, displacements, points, fps=30, magnification=1, point_size=10, colormap="cool"):
11+
"""
12+
The results from the pyidi analysis can directly be passed to this class:
13+
14+
- ``video``: can be a ``VideoReader`` object (or numpy array of correct shape).
15+
- ``displacements``: directly the return from the ``get_displacements`` method.
16+
- ``points``: the points used for the analysis, which were passed to the ``set_points`` method.
17+
18+
Parameters
19+
----------
20+
video : np.ndarray or VideoReader
21+
Array of shape (n_frames, height, width) containing the video frames.
22+
displacements : np.ndarray
23+
Array of shape (n_frames, n_points, 2) containing the displacement vectors.
24+
points : np.ndarray
25+
Array of shape (n_points, 2) containing the grid points.
26+
fps : int, optional
27+
Frames per second for the video playback, by default 30.
28+
magnification : int, optional
29+
Magnification factor for the displacements, by default 1.
30+
point_size : int, optional
31+
Size of the points in pixels, by default 10.
32+
colormap : str, optional
33+
Name of the colormap to use for the arrows, by default "cool".
34+
"""
35+
# Create QApplication if it doesn't exist
36+
app = QtWidgets.QApplication.instance()
37+
if app is None:
38+
app = QtWidgets.QApplication([])
39+
1140
super().__init__()
12-
self.video = video
13-
self.displacements = displacements
14-
self.grid = grid
41+
42+
# Coordinate transformation to match viewer function behavior
43+
from ..video_reader import VideoReader
44+
if isinstance(video, VideoReader):
45+
self.video = video.get_frames()
46+
else:
47+
self.video = video
48+
49+
self.displacements = displacements[:, :, ::-1] # Flip x,y coordinates
50+
self.grid = points[:, ::-1] + 0.5 # Flip x,y coordinates
1551
self.fps = fps
1652
self.magnification = magnification
1753
self.points_size = point_size
@@ -25,6 +61,14 @@ def __init__(self, video, displacements, grid, fps=30, magnification=1, point_si
2561

2662
self.init_ui()
2763
self.update_frame()
64+
65+
# Start the GUI
66+
self.show()
67+
# Only call sys.exit if not in IPython
68+
if not hasattr(sys, 'ps1'): # Not interactive
69+
sys.exit(app.exec())
70+
else:
71+
app.exec() # Don't raise SystemExit in IPython
2872

2973
def init_ui(self):
3074
# Style
@@ -173,6 +217,7 @@ def init_ui(self):
173217
# === Finalize ===
174218
self.setCentralWidget(central_widget)
175219
self.setWindowTitle("Displacement Viewer")
220+
self.resize(800, 600)
176221

177222

178223
def toggle_playback(self):
@@ -337,22 +382,10 @@ def viewer(frames, displacements, points, fps=30, magnification=1, point_size=10
337382
colormap : str, optional
338383
Name of the colormap to use for the arrows, by default "cool".
339384
"""
340-
points = points[:, ::-1]
341-
displacements = displacements[:, :, ::-1]
342-
343-
app = QtWidgets.QApplication.instance()
344-
if app is None:
345-
app = QtWidgets.QApplication([])
346-
347-
win = ResultViewer(frames, displacements, points, fps=fps, magnification=magnification, point_size=point_size, colormap=colormap)
348-
win.resize(800, 600)
349-
win.show()
350-
351-
# Only call sys.exit if not in IPython
352-
if not hasattr(sys, 'ps1'): # Not interactive
353-
sys.exit(app.exec())
354-
else:
355-
app.exec() # Don't raise SystemExit in IPythonys
385+
# This function is now just a wrapper for backward compatibility
386+
# The ResultViewer class handles everything internally
387+
ResultViewer(frames, displacements, points, fps=fps, magnification=magnification,
388+
point_size=point_size, colormap=colormap)
356389

357390

358391
if __name__ == "__main__":
@@ -362,5 +395,6 @@ def viewer(frames, displacements, points, fps=30, magnification=1, point_size=10
362395
displacements = 2 * (np.random.rand(n_points, n_frames, 2) - 0.5)
363396
grid = np.stack(np.meshgrid(np.linspace(50, 350, int(np.sqrt(n_points))),
364397
np.linspace(50, 250, int(np.sqrt(n_points)))), axis=-1).reshape(-1, 2)[:n_points]
365-
grid = grid[:, ::-1]
366-
viewer(frames, displacements, grid)
398+
399+
# Now you can directly call ResultViewer (no need to flip coordinates here since it's done internally)
400+
ResultViewer(frames, displacements, grid)

pyidi/GUIs/subset_selection.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@ def mouseDragEvent(self, ev, axis=None):
4141

4242
class SelectionGUI(QtWidgets.QMainWindow):
4343
def __init__(self, video):
44+
"""Initialize the selection GUI for manual subset selection.
45+
46+
To extract the points, use the ``get_points`` method or the ``points`` attribute.
47+
48+
Parameters
49+
----------
50+
video : VideoReader or np.ndarray
51+
The video to be analyzed. If a VideoReader object, it should be initialized with the video file.
52+
"""
4453
app = QtWidgets.QApplication.instance()
4554
if app is None:
4655
app = QtWidgets.QApplication([])
@@ -156,7 +165,11 @@ def __init__(self, video):
156165
self.pg_widget.scene().sigMouseClicked.connect(self.on_mouse_click)
157166

158167
# Set the initial image
159-
self.image_item.setImage(video)
168+
from ..video_reader import VideoReader
169+
if isinstance(video, VideoReader):
170+
self.frame = video.get_frame(0)
171+
172+
self.image_item.setImage(self.frame.T) # axis 0 is x, while image axis 0 is y
160173

161174
# Ensure method-specific widgets are visible on startup
162175
self.method_selected(self.button_group.checkedId())
@@ -204,7 +217,7 @@ def create_help_button(self, tooltip_text: str) -> QtWidgets.QToolButton:
204217
def ui_graphics(self):
205218
# Image viewer
206219
self.pg_widget = GraphicsLayoutWidget()
207-
self.view = BrushViewBox(parent_gui=self, lockAspect=True)
220+
self.view = BrushViewBox(parent_gui=self, lockAspect=True, invertY=True)
208221
self.pg_widget.addItem(self.view)
209222

210223

@@ -763,7 +776,12 @@ def set_image(self, img: np.ndarray):
763776

764777
def get_points(self):
765778
"""Get all selected points from manual and polygons."""
766-
return np.array(self.selected_points)
779+
points = np.array(self.selected_points)[:, ::-1] # set axis 0 to y and axis 1 to x
780+
return points
781+
782+
@property
783+
def points(self):
784+
return self.get_points()
767785

768786
def get_filtered_points(self):
769787
"""Get candidate points from automatic filtering."""

0 commit comments

Comments
 (0)