Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/performance test #850

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions openadapt/a11y/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""This module provides platform-specific implementations for window and element
interactions using accessibility APIs. It abstracts the platform differences
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove unnecessary indent

and provides a unified interface for retrieving the active window, finding
display elements, and getting element values.
"""

import sys

from loguru import logger

if sys.platform == "darwin":
from . import _macos as impl

role = "AXStaticText"
elif sys.platform in ("win32", "linux"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove linux for now

from . import _windows as impl

role = "Text"
else:
raise Exception(f"Unsupported platform: {sys.platform}")


def get_active_window():
"""Get the active window object.

Returns:
The active window object.
"""
try:
return impl.get_active_window()
except Exception as exc:
logger.warning(f"{exc=}")
return None


def get_element_value(active_window, role=role):
"""Find the display of active_window.

Args:
active_window: The parent window to search within.

Returns:
The found active_window.
"""
try:
return impl.get_element_value(active_window, role)
except Exception as exc:
logger.warning(f"{exc=}")
return None
61 changes: 61 additions & 0 deletions openadapt/a11y/_macos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import AppKit
import ApplicationServices


def get_attribute(element, attribute):
result, value = ApplicationServices.AXUIElementCopyAttributeValue(
element, attribute, None
)
if result == 0:
return value
return None


def find_element_by_attribute(element, attribute, value):
if get_attribute(element, attribute) == value:
return element
children = get_attribute(element, ApplicationServices.kAXChildrenAttribute) or []
for child in children:
found = find_element_by_attribute(child, attribute, value)
if found:
return found
return None


def get_active_window():
"""Get the active window object.

Returns:
AXUIElement: The active window object.
"""
workspace = AppKit.NSWorkspace.sharedWorkspace()
active_app = workspace.frontmostApplication()
app_element = ApplicationServices.AXUIElementCreateApplication(
active_app.processIdentifier()
)

error_code, focused_window = ApplicationServices.AXUIElementCopyAttributeValue(
app_element, ApplicationServices.kAXFocusedWindowAttribute, None
)
if error_code:
raise Exception("Could not get the active window.")
return focused_window


def get_element_value(element, role="AXStaticText"):
"""Get the value of a specific element .

Args:
element: The AXUIElement to search within.

Returns:
str: The value of the element, or an error message if not found.
"""
target_element = find_element_by_attribute(
element, ApplicationServices.kAXRoleAttribute, role
)
if not target_element:
return f"AXStaticText element not found."

value = get_attribute(target_element, ApplicationServices.kAXValueAttribute)
return value if value else f"No value for AXStaticText element."
44 changes: 44 additions & 0 deletions openadapt/a11y/_windows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from loguru import logger
import pywinauto
import re


def get_active_window() -> pywinauto.application.WindowSpecification:
"""Get the active window object.

Returns:
pywinauto.application.WindowSpecification: The active window object.
"""
app = pywinauto.application.Application(backend="uia").connect(active_only=True)
window = app.top_window()
return window.wrapper_object()


def get_element_value(active_window, role="Text"):
"""Find the display element.

Args:
active_window: The parent window to search within.
role (str): The role of the element to search for.

Returns:
The found display element value.

Raises:
ValueError: If the element is not found.
"""
try:
elements = active_window.descendants() # Retrieve all descendants
for elem in elements:
if (
elem.element_info.control_type == role
and elem.element_info.name.startswith("Display is")
):
# Extract the number from the element's name
match = re.search(r"[-+]?\d*\.?\d+", elem.element_info.name)
if match:
return str(match.group())
raise ValueError("Display element not found")
except Exception as exc:
logger.warning(f"Error in get_element_value: {exc}")
return None
2 changes: 1 addition & 1 deletion openadapt/app/dashboard/api/recordings.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def attach_routes(self) -> APIRouter:
def get_recordings() -> dict[str, list[Recording]]:
"""Get all recordings."""
session = crud.get_new_session(read_only=True)
recordings = crud.get_all_recordings(session)
recordings = crud.get_recordings(session)
return {"recordings": recordings}

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion openadapt/app/tray.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def populate_menu(self, menu: QMenu, action: Callable, action_type: str) -> None
action_type (str): The type of action to perform ["visualize", "replay"]
"""
session = crud.get_new_session(read_only=True)
recordings = crud.get_all_recordings(session)
recordings = crud.get_recordings(session)

self.recording_actions[action_type] = []

Expand Down
11 changes: 7 additions & 4 deletions openadapt/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Configuration module for OpenAdapt."""


from enum import Enum
from typing import Any, ClassVar, Type, Union
import json
Expand Down Expand Up @@ -33,6 +32,7 @@
CAPTURE_DIR_PATH = (DATA_DIR_PATH / "captures").absolute()
VIDEO_DIR_PATH = DATA_DIR_PATH / "videos"
DATABASE_LOCK_FILE_PATH = DATA_DIR_PATH / "openadapt.db.lock"
DB_FILE_PATH = (DATA_DIR_PATH / "openadapt.db").absolute()

STOP_STRS = [
"oa.stop",
Expand Down Expand Up @@ -124,7 +124,8 @@ class SegmentationAdapter(str, Enum):

# Database
DB_ECHO: bool = False
DB_URL: ClassVar[str] = f"sqlite:///{(DATA_DIR_PATH / 'openadapt.db').absolute()}"
DB_FILE_PATH: str = str(DB_FILE_PATH)
DB_URL: ClassVar[str] = f"sqlite:///{DB_FILE_PATH}"

# Error reporting
ERROR_REPORTING_ENABLED: bool = True
Expand Down Expand Up @@ -428,11 +429,13 @@ def show_alert() -> None:
"""Show an alert to the user."""
msg = QMessageBox()
msg.setIcon(QMessageBox.Warning)
msg.setText("""
msg.setText(
"""
An error has occurred. The development team has been notified.
Please join the discord server to get help or send an email to
[email protected]
""")
"""
)
discord_button = QPushButton("Join the discord server")
discord_button.clicked.connect(
lambda: webbrowser.open("https://discord.gg/yF527cQbDG")
Expand Down
26 changes: 23 additions & 3 deletions openadapt/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,22 +281,27 @@ def delete_recording(session: SaSession, recording: Recording) -> None:
delete_video_file(recording_timestamp)


def get_all_recordings(session: SaSession) -> list[Recording]:
def get_recordings(session: SaSession, max_rows=None) -> list[Recording]:
"""Get all recordings.

Args:
session (sa.orm.Session): The database session.
max_rows: The number of recordings to return, starting from the most recent.
Defaults to all if max_rows is not specified.

Returns:
list[Recording]: A list of all original recordings.
"""
return (
query = (
session.query(Recording)
.filter(Recording.original_recording_id == None) # noqa: E711
.order_by(sa.desc(Recording.timestamp))
.all()
)

if max_rows:
query = query.limit(max_rows)
return query.all()


def get_all_scrubbed_recordings(
session: SaSession,
Expand Down Expand Up @@ -352,6 +357,21 @@ def get_recording(session: SaSession, timestamp: float) -> Recording:
return session.query(Recording).filter(Recording.timestamp == timestamp).first()


def get_recordings_by_desc(session: SaSession, description_str: str) -> list[Recording]:
"""Get recordings by task description.
Args:
session (sa.orm.Session): The database session.
task_description (str): The task description to search for.
Returns:
list[Recording]: A list of recordings whose task descriptions contain the given string.
"""
return (
session.query(Recording)
.filter(Recording.task_description.contains(description_str))
.all()
)


BaseModelType = TypeVar("BaseModelType")


Expand Down
Loading
Loading