Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enhance take_screenshot for multi-monitor support #792

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ src

dist/
build/
openadapt/error.log
3 changes: 3 additions & 0 deletions openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ class SegmentationAdapter(str, Enum):
"children",
]

# Screenshot capture
CAPTURE_ALL_MONITORS: bool = False

@field_validator("SCRUB_FILL_COLOR")
@classmethod
def validate_scrub_fill_color(cls, v: Union[str, int]) -> int: # noqa: ANN102
Expand Down
73 changes: 66 additions & 7 deletions openadapt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from jinja2 import Environment, FileSystemLoader
from PIL import Image, ImageEnhance
from posthog import Posthog
import pyautogui
import argparse
import time

from openadapt.build_utils import is_running_from_executable, redirect_stdout_stderr
from openadapt.custom_logger import logger
Expand Down Expand Up @@ -48,7 +51,9 @@
from openadapt.custom_logger import filter_log_messages
from openadapt.db import db
from openadapt.models import ActionEvent
from config import Config

config = Config()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be necessary. Please replace with from openadapt.config import config like it's used elsewhere.

# TODO: move to constants.py
EMPTY = (None, [], {}, "")
SCT = mss.mss()
Expand Down Expand Up @@ -412,17 +417,71 @@ def evenly_spaced(arr: list, N: list) -> list:
return [val for idx, val in enumerate(arr) if idx in idxs]




def get_current_monitor(monitors: list) -> dict:
    """Determine the monitor where the cursor is currently located.

    Args:
        monitors (list): List of monitor dictionaries, each providing
            "left", "top", "width" and "height" keys.

    Returns:
        dict: The monitor dictionary whose bounds contain the cursor, or the
            first entry of ``monitors`` when no monitor contains it.

    Raises:
        IndexError: If ``monitors`` is empty, since there is no monitor to
            fall back to.
    """
    cursor_x, cursor_y = pyautogui.position()

    for monitor in monitors:
        left, top = monitor["left"], monitor["top"]
        # The right and bottom edges are exclusive so that a cursor sitting on
        # the boundary between two adjacent monitors matches only one of them.
        if (
            left <= cursor_x < left + monitor["width"]
            and top <= cursor_y < top + monitor["height"]
        ):
            return monitor

    # If not found, default to the first monitor. Index 0 (not 1) is the first
    # entry of the list that was passed in; index 1 would be the second
    # monitor and does not exist at all on a single-monitor setup.
    logger.warning(
        f"Cursor position ({cursor_x}, {cursor_y}) not found in any monitor."
        " Defaulting to first monitor."
    )
    return monitors[0]
onyedikachi-david marked this conversation as resolved.
Show resolved Hide resolved

def take_screenshot() -> Image.Image:
    """Take a screenshot of the current monitor or all monitors.

    If ``config.CAPTURE_ALL_MONITORS`` is set, the union of all monitors is
    grabbed once and each monitor's region is copied onto a fresh canvas.
    Otherwise only the monitor returned by ``get_current_monitor`` is grabbed.

    Returns:
        PIL.Image.Image: The screenshot image.
    """
    # Skip the first entry, which is the union of all monitors.
    monitors = SCT.monitors[1:]

    if not config.CAPTURE_ALL_MONITORS:
        # Capture only the current monitor.
        sct_img = SCT.grab(get_current_monitor(monitors))
        return Image.frombytes("RGB", sct_img.size, sct_img.bgra, "raw", "BGRX")

    # Grab the union of all monitors in a single call.
    all_monitors = SCT.monitors[0]
    sct_img = SCT.grab(all_monitors)
    full_img = Image.frombytes("RGB", sct_img.size, sct_img.bgra, "raw", "BGRX")

    if not monitors:
        # No individual monitors reported; there is nothing to composite.
        return full_img

    # Determine the bounds of the combined image.
    min_left = min(monitor["left"] for monitor in monitors)
    min_top = min(monitor["top"] for monitor in monitors)
    max_right = max(monitor["left"] + monitor["width"] for monitor in monitors)
    max_bottom = max(monitor["top"] + monitor["height"] for monitor in monitors)

    # Image.new fills with black by default, so any area of the bounding box
    # that no monitor covers has a well-defined colour in the result.
    combined_image = Image.new("RGB", (max_right - min_left, max_bottom - min_top))

    # Pixel (0, 0) of full_img corresponds to the top-left corner of the
    # grabbed union region, not to virtual-desktop coordinate (0, 0). Crop
    # boxes must therefore be expressed relative to that origin; using the
    # absolute monitor coordinates mis-crops whenever a monitor sits left of
    # or above the primary one (negative "left"/"top").
    origin_left = all_monitors["left"]
    origin_top = all_monitors["top"]

    for monitor in monitors:
        src_left = monitor["left"] - origin_left
        src_top = monitor["top"] - origin_top
        monitor_img = full_img.crop(
            (
                src_left,
                src_top,
                src_left + monitor["width"],
                src_top + monitor["height"],
            )
        )
        combined_image.paste(
            monitor_img,
            (monitor["left"] - min_left, monitor["top"] - min_top),
        )

    return combined_image



def get_strategy_class_by_name() -> dict:
Expand Down
33 changes: 33 additions & 0 deletions tests/openadapt/test_monitors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Tests the take_screenshot function in openadapt/utils.py"""

import mss
import pytest
from unittest.mock import MagicMock, patch

from openadapt.utils import take_screenshot
from PIL import Image

def test_take_screenshot():
    """Test that take_screenshot returns a non-empty PIL image."""
    image = take_screenshot()
    assert isinstance(image, Image.Image)
    # The resolution depends on the machine running the test, so only check
    # that a non-empty image came back instead of a hard-coded 1920x1080.
    width, height = image.size
    assert width > 0
    assert height > 0

@patch("openadapt.utils.get_current_monitor")
@patch("openadapt.utils.config")
@patch("openadapt.utils.SCT")
def test_take_screenshot_multiple_monitors(
    mock_sct, mock_config, mock_get_current_monitor
):
    """Test that take_screenshot grabs only the current monitor by default.

    ``openadapt.utils.SCT`` is created once at import time, so patching
    ``mss.mss`` afterwards has no effect on it; the module-level instance is
    patched directly instead, which keeps the test off the real display.
    """
    union = {"left": 0, "top": 0, "width": 3840, "height": 1080}
    primary = {"left": 0, "top": 0, "width": 1920, "height": 1080}
    secondary = {"left": 1920, "top": 0, "width": 1920, "height": 1080}

    # Exercise the single-monitor branch of take_screenshot.
    mock_config.CAPTURE_ALL_MONITORS = False

    # Simulate a two-monitor layout with the cursor on the primary monitor.
    mock_sct.monitors = [union, primary, secondary]
    mock_get_current_monitor.return_value = primary

    # Fake raw BGRA pixel data of the primary monitor's size.
    mock_screenshot = MagicMock()
    mock_screenshot.size = (1920, 1080)
    mock_screenshot.bgra = b"\x00" * (1920 * 1080 * 4)
    mock_sct.grab.return_value = mock_screenshot

    image = take_screenshot()

    assert isinstance(image, Image.Image)
    # Only the current monitor should have been grabbed, and the resulting
    # image should have that monitor's dimensions.
    mock_sct.grab.assert_called_once_with(primary)
    assert image.size == (1920, 1080)