Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 40 additions & 22 deletions pr_agent/servers/github_action_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import os
from typing import Union

from line_profiler import profile as codeflash_line_profile

codeflash_line_profile.enable(output_prefix="/tmp/codeflash__yxejqcq/baseline_lprof")

from pr_agent.agent.pr_agent import PRAgent
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
Expand All @@ -14,11 +18,20 @@
from pr_agent.tools.pr_reviewer import PRReviewer


@codeflash_line_profile
def is_true(value: Union[str, bool]) -> bool:
if isinstance(value, bool):
# Direct type comparison for fastest builtin bool check
if type(value) is bool:
return value
# Optimize string handling: check type once for string or its subclasses
if isinstance(value, str):
return value.lower() == 'true'
# Check length first for 'true'
if len(value) != 4:
return False
# ASCII, case-insensitive: avoid allocating new string unless comparison matches (str.lower() is very fast for ASCII 4-char)
# Short-circuit by comparing against the 4 canonical variants
# But lower() then compare is fastest on CPython vs. tuple membership
return value.lower() == "true"
return False


Expand All @@ -32,11 +45,11 @@ def get_setting_or_env(key: str, default: Union[str, bool] = None) -> Union[str,

async def run_action():
# Get environment variables
GITHUB_EVENT_NAME = os.environ.get('GITHUB_EVENT_NAME')
GITHUB_EVENT_PATH = os.environ.get('GITHUB_EVENT_PATH')
OPENAI_KEY = os.environ.get('OPENAI_KEY') or os.environ.get('OPENAI.KEY')
OPENAI_ORG = os.environ.get('OPENAI_ORG') or os.environ.get('OPENAI.ORG')
GITHUB_TOKEN = os.environ.get('GITHUB_TOKEN')
GITHUB_EVENT_NAME = os.environ.get("GITHUB_EVENT_NAME")
GITHUB_EVENT_PATH = os.environ.get("GITHUB_EVENT_PATH")
OPENAI_KEY = os.environ.get("OPENAI_KEY") or os.environ.get("OPENAI.KEY")
OPENAI_ORG = os.environ.get("OPENAI_ORG") or os.environ.get("OPENAI.ORG")
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")
# get_settings().set("CONFIG.PUBLISH_OUTPUT_PROGRESS", False)

# Check if required environment variables are set
Expand Down Expand Up @@ -65,7 +78,7 @@ async def run_action():

# Load the event payload
try:
with open(GITHUB_EVENT_PATH, 'r') as f:
with open(GITHUB_EVENT_PATH, "r") as f:
event_payload = json.load(f)
except json.decoder.JSONDecodeError as e:
print(f"Failed to parse JSON: {e}")
Expand All @@ -82,24 +95,25 @@ async def run_action():

# Append the response language in the extra instructions
try:
response_language = get_settings().config.get('response_language', 'en-us')
if response_language.lower() != 'en-us':
get_logger().info(f'User has set the response language to: {response_language}')
response_language = get_settings().config.get("response_language", "en-us")
if response_language.lower() != "en-us":
get_logger().info(f"User has set the response language to: {response_language}")

lang_instruction_text = f"Your response MUST be written in the language corresponding to locale code: '{response_language}'. This is crucial."
separator_text = "\n======\n\nIn addition, "

for key in get_settings():
setting = get_settings().get(key)
if str(type(setting)) == "<class 'dynaconf.utils.boxing.DynaBox'>":
if key.lower() in ['pr_description', 'pr_code_suggestions', 'pr_reviewer']:
if hasattr(setting, 'extra_instructions'):
if key.lower() in ["pr_description", "pr_code_suggestions", "pr_reviewer"]:
if hasattr(setting, "extra_instructions"):
extra_instructions = setting.extra_instructions

if lang_instruction_text not in str(extra_instructions):
updated_instructions = (
str(extra_instructions) + separator_text + lang_instruction_text
if extra_instructions else lang_instruction_text
if extra_instructions
else lang_instruction_text
)
setting.extra_instructions = updated_instructions
except Exception as e:
Expand All @@ -109,7 +123,9 @@ async def run_action():
action = event_payload.get("action")

# Retrieve the list of actions from the configuration
pr_actions = get_settings().get("GITHUB_ACTION_CONFIG.PR_ACTIONS", ["opened", "reopened", "ready_for_review", "review_requested"])
pr_actions = get_settings().get(
"GITHUB_ACTION_CONFIG.PR_ACTIONS", ["opened", "reopened", "ready_for_review", "review_requested"]
)

if action in pr_actions:
pr_url = event_payload.get("pull_request", {}).get("url")
Expand All @@ -127,8 +143,12 @@ async def run_action():

# Set the configuration for auto actions
get_settings().config.is_auto_command = True # Set the flag to indicate that the command is auto
get_settings().pr_description.final_update_message = False # No final update message when auto_describe is enabled
get_logger().info(f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}")
get_settings().pr_description.final_update_message = (
False # No final update message when auto_describe is enabled
)
get_logger().info(
f"Running auto actions: auto_describe={auto_describe}, auto_review={auto_review}, auto_improve={auto_improve}"
)

# invoke by default all three tools
if auto_describe is None or is_true(auto_describe):
Expand All @@ -147,7 +167,7 @@ async def run_action():
comment_body = event_payload.get("comment", {}).get("body")
try:
if GITHUB_EVENT_NAME == "pull_request_review_comment":
if '/ask' in comment_body:
if "/ask" in comment_body:
comment_body = handle_line_comments(event_payload, comment_body)
except Exception as e:
get_logger().error(f"Failed to handle line comments: {e}")
Expand All @@ -172,13 +192,11 @@ async def run_action():
provider = get_git_provider()(pr_url=url)
if is_pr:
await PRAgent().handle_request(
url, body, notify=lambda: provider.add_eyes_reaction(
comment_id, disable_eyes=disable_eyes
)
url, body, notify=lambda: provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes)
)
else:
await PRAgent().handle_request(url, body)


if __name__ == '__main__':
if __name__ == "__main__":
asyncio.run(run_action())