Skip to content

Feat/performance test #850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions openadapt/a11y/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""This module provides platform-specific implementations for window and element
interactions using accessibility APIs. It abstracts the platform differences
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove unnecessary indent

and provides a unified interface for retrieving the active window, finding
display elements, and getting element values.
"""

import sys

from loguru import logger

if sys.platform == "darwin":
from . import _macos as impl

role = "AXStaticText"
elif sys.platform in ("win32", "linux"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove linux for now

from . import _windows as impl

role = "Text"
else:
raise Exception(f"Unsupported platform: {sys.platform}")


def get_active_window():
"""Get the active window object.

Returns:
The active window object.
"""
try:
return impl.get_active_window()
except Exception as exc:
logger.warning(f"{exc=}")
return None


def get_element_value(active_window, role=role):
"""Find the display of active_window.

Args:
active_window: The parent window to search within.

Returns:
The found active_window.
"""
try:
return impl.get_element_value(active_window, role)
except Exception as exc:
logger.warning(f"{exc=}")
return None
61 changes: 61 additions & 0 deletions openadapt/a11y/_macos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import AppKit
import ApplicationServices


def get_attribute(element, attribute):
result, value = ApplicationServices.AXUIElementCopyAttributeValue(
element, attribute, None
)
if result == 0:
return value
return None


def find_element_by_attribute(element, attribute, value):
if get_attribute(element, attribute) == value:
return element
children = get_attribute(element, ApplicationServices.kAXChildrenAttribute) or []
for child in children:
found = find_element_by_attribute(child, attribute, value)
if found:
return found
return None


def get_active_window():
"""Get the active window object.

Returns:
AXUIElement: The active window object.
"""
workspace = AppKit.NSWorkspace.sharedWorkspace()
active_app = workspace.frontmostApplication()
app_element = ApplicationServices.AXUIElementCreateApplication(
active_app.processIdentifier()
)

error_code, focused_window = ApplicationServices.AXUIElementCopyAttributeValue(
app_element, ApplicationServices.kAXFocusedWindowAttribute, None
)
if error_code:
raise Exception("Could not get the active window.")
return focused_window


def get_element_value(element, role="AXStaticText"):
"""Get the value of a specific element .

Args:
element: The AXUIElement to search within.

Returns:
str: The value of the element, or an error message if not found.
"""
target_element = find_element_by_attribute(
element, ApplicationServices.kAXRoleAttribute, role
)
if not target_element:
return f"AXStaticText element not found."

value = get_attribute(target_element, ApplicationServices.kAXValueAttribute)
return value if value else f"No value for AXStaticText element."
44 changes: 44 additions & 0 deletions openadapt/a11y/_windows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from loguru import logger
import pywinauto
import re


def get_active_window() -> pywinauto.application.WindowSpecification:
"""Get the active window object.

Returns:
pywinauto.application.WindowSpecification: The active window object.
"""
app = pywinauto.application.Application(backend="uia").connect(active_only=True)
window = app.top_window()
return window.wrapper_object()


def get_element_value(active_window, role="Text"):
"""Find the display element.

Args:
active_window: The parent window to search within.
role (str): The role of the element to search for.

Returns:
The found display element value.

Raises:
ValueError: If the element is not found.
"""
try:
elements = active_window.descendants() # Retrieve all descendants
for elem in elements:
if (
elem.element_info.control_type == role
and elem.element_info.name.startswith("Display is")
):
# Extract the number from the element's name
match = re.search(r"[-+]?\d*\.?\d+", elem.element_info.name)
if match:
return str(match.group())
raise ValueError("Display element not found")
except Exception as exc:
logger.warning(f"Error in get_element_value: {exc}")
return None
2 changes: 1 addition & 1 deletion openadapt/app/dashboard/api/recordings.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ def attach_routes(self) -> APIRouter:
def get_recordings() -> dict[str, list[Recording]]:
"""Get all recordings."""
session = crud.get_new_session(read_only=True)
recordings = crud.get_all_recordings(session)
recordings = crud.get_recordings(session)
return {"recordings": recordings}

@staticmethod
2 changes: 1 addition & 1 deletion openadapt/app/tray.py
Original file line number Diff line number Diff line change
@@ -463,7 +463,7 @@ def populate_menu(self, menu: QMenu, action: Callable, action_type: str) -> None
action_type (str): The type of action to perform ["visualize", "replay"]
"""
session = crud.get_new_session(read_only=True)
recordings = crud.get_all_recordings(session)
recordings = crud.get_recordings(session)

self.recording_actions[action_type] = []

11 changes: 7 additions & 4 deletions openadapt/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Configuration module for OpenAdapt."""


from enum import Enum
from typing import Any, ClassVar, Type, Union
import json
@@ -33,6 +32,7 @@
CAPTURE_DIR_PATH = (DATA_DIR_PATH / "captures").absolute()
VIDEO_DIR_PATH = DATA_DIR_PATH / "videos"
DATABASE_LOCK_FILE_PATH = DATA_DIR_PATH / "openadapt.db.lock"
DB_FILE_PATH = (DATA_DIR_PATH / "openadapt.db").absolute()

STOP_STRS = [
"oa.stop",
@@ -124,7 +124,8 @@ class SegmentationAdapter(str, Enum):

# Database
DB_ECHO: bool = False
DB_URL: ClassVar[str] = f"sqlite:///{(DATA_DIR_PATH / 'openadapt.db').absolute()}"
DB_FILE_PATH: str = str(DB_FILE_PATH)
DB_URL: ClassVar[str] = f"sqlite:///{DB_FILE_PATH}"

# Error reporting
ERROR_REPORTING_ENABLED: bool = True
@@ -428,11 +429,13 @@ def show_alert() -> None:
"""Show an alert to the user."""
msg = QMessageBox()
msg.setIcon(QMessageBox.Warning)
msg.setText("""
msg.setText(
"""
An error has occurred. The development team has been notified.
Please join the discord server to get help or send an email to
[email protected]
""")
"""
)
discord_button = QPushButton("Join the discord server")
discord_button.clicked.connect(
lambda: webbrowser.open("https://discord.gg/yF527cQbDG")
26 changes: 23 additions & 3 deletions openadapt/db/crud.py
Original file line number Diff line number Diff line change
@@ -281,22 +281,27 @@ def delete_recording(session: SaSession, recording: Recording) -> None:
delete_video_file(recording_timestamp)


def get_all_recordings(session: SaSession) -> list[Recording]:
def get_recordings(session: SaSession, max_rows=None) -> list[Recording]:
"""Get all recordings.

Args:
session (sa.orm.Session): The database session.
max_rows: The number of recordings to return, starting from the most recent.
Defaults to all if max_rows is not specified.

Returns:
list[Recording]: A list of all original recordings.
"""
return (
query = (
session.query(Recording)
.filter(Recording.original_recording_id == None) # noqa: E711
.order_by(sa.desc(Recording.timestamp))
.all()
)

if max_rows:
query = query.limit(max_rows)
return query.all()


def get_all_scrubbed_recordings(
session: SaSession,
@@ -352,6 +357,21 @@ def get_recording(session: SaSession, timestamp: float) -> Recording:
return session.query(Recording).filter(Recording.timestamp == timestamp).first()


def get_recordings_by_desc(session: SaSession, description_str: str) -> list[Recording]:
"""Get recordings by task description.
Args:
session (sa.orm.Session): The database session.
task_description (str): The task description to search for.
Returns:
list[Recording]: A list of recordings whose task descriptions contain the given string.
"""
return (
session.query(Recording)
.filter(Recording.task_description.contains(description_str))
.all()
)


BaseModelType = TypeVar("BaseModelType")


151 changes: 151 additions & 0 deletions openadapt/scripts/generate_db_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from sqlalchemy import create_engine, inspect
from openadapt.db.db import Base
from openadapt.config import PARENT_DIR_PATH, RECORDING_DIR_PATH
import openadapt.db.crud as crud
from loguru import logger


def get_session():
"""
Establishes a database connection and returns a session and engine.
Returns:
tuple: A tuple containing the SQLAlchemy session and engine.
"""
db_url = RECORDING_DIR_PATH / "recording.db"
logger.info(f"Database URL: {db_url}")
engine = create_engine(f"sqlite:///{db_url}")
Base.metadata.create_all(bind=engine)
session = crud.get_new_session(read_only=True)
logger.info("Database connection established.")
return session, engine


def check_tables_exist(engine):
"""
Checks if the expected tables exist in the database.
Args:
engine: SQLAlchemy engine object.
Returns:
list: A list of table names in the database.
"""
inspector = inspect(engine)
tables = inspector.get_table_names()
expected_tables = [
"recording",
"action_event",
"screenshot",
"window_event",
"performance_stat",
"memory_stat",
]
for table_name in expected_tables:
table_exists = table_name in tables
logger.info(f"{table_name=} {table_exists=}")
return tables


def fetch_data(session):
"""
Fetches the most recent recordings and related data from the database.
Args:
session: SQLAlchemy session object.
Returns:
dict: A dictionary containing fetched data.
"""
# get the most recent three recordings
recordings = crud.get_recordings(session, max_rows=3)

action_events = []
screenshots = []
window_events = []
performance_stats = []
memory_stats = []

for recording in recordings:
action_events.extend(crud.get_action_events(session, recording))
screenshots.extend(crud.get_screenshots(session, recording))
window_events.extend(crud.get_window_events(session, recording))
performance_stats.extend(crud.get_perf_stats(session, recording))
memory_stats.extend(crud.get_memory_stats(session, recording))

data = {
"recordings": recordings,
"action_events": action_events,
"screenshots": screenshots,
"window_events": window_events,
"performance_stats": performance_stats,
"memory_stats": memory_stats,
}

# Debug prints to verify data fetching
logger.info(f"Recordings: {len(data['recordings'])} found.")
logger.info(f"Action Events: {len(data['action_events'])} found.")
logger.info(f"Screenshots: {len(data['screenshots'])} found.")
logger.info(f"Window Events: {len(data['window_events'])} found.")
logger.info(f"Performance Stats: {len(data['performance_stats'])} found.")
logger.info(f"Memory Stats: {len(data['memory_stats'])} found.")

return data


def format_sql_insert(table_name, rows):
"""
Formats SQL insert statements for a given table and rows.
Args:
table_name (str): The name of the table.
rows (list): A list of SQLAlchemy ORM objects representing the rows.
Returns:
str: A string containing the SQL insert statements.
"""
if not rows:
return ""

columns = rows[0].__table__.columns.keys()
sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES\n"
values = []

for row in rows:
row_values = [getattr(row, col) for col in columns]
row_values = [
f"'{value}'" if isinstance(value, str) else str(value)
for value in row_values
]
values.append(f"({', '.join(row_values)})")

sql += ",\n".join(values) + ";\n"
return sql


def dump_to_fixtures(filepath):
"""
Dumps the fetched data into an SQL file.
Args:
filepath (str): The path to the SQL file.
"""
session, engine = get_session()
check_tables_exist(engine)
rows_by_table_name = fetch_data(session)

for table_name, rows in rows_by_table_name.items():
if not rows:
logger.warning(f"No rows for {table_name=}")
continue
with open(filepath, "a", encoding="utf-8") as file:
logger.info(f"Writing {len(rows)=} to {filepath=} for {table_name=}")
file.write(f"-- Insert sample rows for {table_name}\n")
file.write(format_sql_insert(table_name, rows))


if __name__ == "__main__":

fixtures_path = PARENT_DIR_PATH / "tests/assets/fixtures.sql"

dump_to_fixtures(fixtures_path)
4 changes: 2 additions & 2 deletions openadapt/scripts/reset_db.py
Original file line number Diff line number Diff line change
@@ -14,8 +14,8 @@

def reset_db() -> None:
"""Clears the database by removing the db file and running a db migration."""
if os.path.exists(config.DB_FPATH):
os.remove(config.DB_FPATH)
if os.path.exists(config.DB_FILE_PATH):
os.remove(config.DB_FILE_PATH)

# Prevents duplicate logging of config values by piping stderr
# and filtering the output.
54 changes: 54 additions & 0 deletions tests/openadapt/test_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
from loguru import logger
from openadapt.db.crud import (
get_recordings_by_desc,
get_new_session,
)
from openadapt.replay import replay
from openadapt.a11y import (
get_active_window,
get_element_value,
)


# parametrized tests
@pytest.mark.parametrize(
"task_description, replay_strategy, expected_value, instructions",
[
("test_calculator", "VisualReplayStrategy", "6", " "),
("test_calculator", "VisualReplayStrategy", "8", "calculate 9-8+7"),
# ("test_spreadsheet", "NaiveReplayStrategy"),
# ("test_powerpoint", "NaiveReplayStrategy")
],
)
def test_replay(task_description, replay_strategy, expected_value, instructions):
# Get recordings which contain the string "test_calculator"
session = get_new_session(read_only=True)
recordings = get_recordings_by_desc(session, task_description)

assert (
len(recordings) > 0
), f"No recordings found with task description: {task_description}"
recording = recordings[0]

result = replay(
strategy_name=replay_strategy,
recording=recording,
instructions=instructions,
)
assert result is True, f"Replay failed for recording: {recording.id}"

active_window = get_active_window()
element_value = get_element_value(active_window)
logger.info(element_value)

assert (
element_value == expected_value
), f"Value mismatch: expected '{expected_value}', got '{element_value}'"

result_message = f"Value match: '{element_value}' == '{expected_value}'"
logger.info(result_message)


if __name__ == "__main__":
pytest.main()