Skip to content

Commit

Permalink
refactor: 优化 handler 相关函数
Browse files Browse the repository at this point in the history
  • Loading branch information
he0119 committed Nov 30, 2024
1 parent 20b9048 commit c4848fd
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 342 deletions.
16 changes: 16 additions & 0 deletions src/plugins/github/models/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
class GitHandler(BaseModel):
"""Git 操作"""

def checkout_branch(self, branch_name: str):
"""检出分支"""

run_shell_command(["git", "checkout", branch_name])

def checkout_remote_branch(self, branch_name: str):
"""检出远程分支"""

run_shell_command(["git", "fetch", "origin", branch_name])
run_shell_command(["git", "checkout", branch_name])

def commit_and_push(self, message: str, branch_name: str, author: str):
"""提交并推送"""

Expand Down Expand Up @@ -36,3 +47,8 @@ def delete_origin_branch(self, branch_name: str):
"""删除远程分支"""

run_shell_command(["git", "push", "origin", "--delete", branch_name])

def switch_branch(self, branch_name: str):
"""切换分支"""

run_shell_command(["git", "switch", "-C", branch_name])
98 changes: 44 additions & 54 deletions src/plugins/github/models/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async def update_issue_content(self, body: str, issue_number: int):
async def create_dispatch_event(
self, event_type: str, client_payload: dict, repo: RepoInfo | None = None
):
"""创建触发事件"""
if repo is None:
repo = self.repo_info

Expand Down Expand Up @@ -100,6 +101,20 @@ async def get_pull_requests_by_label(self, label: str) -> list[PullRequestSimple
pull for pull in pulls if label in [label.name for label in pull.labels]
]

async def get_pull_request_by_branch(self, branch_name: str) -> PullRequestSimple:
"""根据分支的名称获取对应拉取请求"""
pulls = (
await self.bot.rest.pulls.async_list(
**self.repo_info.model_dump(),
head=f"{self.repo_info.owner}:{branch_name}",
)
).parsed_data

if not pulls:
raise ValueError(f"找不到分支 {branch_name} 对应的拉取请求")

return pulls[0]

async def get_pull_request(self, pull_number: int):
"""获取拉取请求"""
return (
Expand All @@ -109,16 +124,14 @@ async def get_pull_request(self, pull_number: int):
).parsed_data

async def draft_pull_request(self, branch_name: str):
"""
将拉取请求转换为草稿
"""
pulls = (
await self.bot.rest.pulls.async_list(
**self.repo_info.model_dump(),
head=f"{self.repo_info.owner}:{branch_name}",
)
).parsed_data
if pulls and (pull := pulls[0]) and not pull.draft:
"""将拉取请求转换为草稿"""
try:
pull = await self.get_pull_request_by_branch(branch_name)
except ValueError:
logger.info("未找到对应的拉取请求,无需转换")
return

if not pull.draft:
await self.bot.async_graphql(
query="""mutation convertPullRequestToDraft($pullRequestId: ID!) {
convertPullRequestToDraft(input: {pullRequestId: $pullRequestId}) {
Expand All @@ -128,6 +141,8 @@ async def draft_pull_request(self, branch_name: str):
variables={"pullRequestId": pull.node_id},
)
logger.info("没通过检查,已将之前的拉取请求转换为草稿")
else:
logger.info("拉取请求已为草稿状态,无需转换")

Check warning on line 145 in src/plugins/github/models/github.py

View check run for this annotation

Codecov / codecov/patch

src/plugins/github/models/github.py#L145

Added line #L145 was not covered by tests

async def merge_pull_request(
self,
Expand All @@ -142,44 +157,22 @@ async def merge_pull_request(
)
logger.info(f"拉取请求 #{pull_number} 已合并")

async def get_pull_request_by_branch(self, branch_name: str) -> PullRequestSimple:
"""根据分支的名称获取对应的拉取请求实例"""
return (
await self.bot.rest.pulls.async_list(
**self.repo_info.model_dump(),
head=f"{self.repo_info.owner}:{branch_name}",
)
).parsed_data[0]

async def update_pull_request_status(self, title: str, branch_name: str):
"""拉取请求若为草稿状态则标记为可评审,若标题不符则修改标题"""
pull = await self.get_pull_request_by_branch(branch_name)
if pull.title != title:
await self.bot.rest.pulls.async_update(
**self.repo_info.model_dump(), pull_number=pull.number, title=title
)
logger.info(f"拉取请求标题已修改为 {title}")
await self.update_pull_request_title(title, pull.number)
if pull.draft:
await self.bot.async_graphql(
query="""mutation markPullRequestReadyForReview($pullRequestId: ID!) {
markPullRequestReadyForReview(input: {pullRequestId: $pullRequestId}) {
clientMutationId
}
}""",
variables={"pullRequestId": pull.node_id},
)
logger.info("拉取请求已标记为可评审")
await self.ready_pull_request(pull.node_id)

async def create_pull_request(
self,
base_branch: str,
title: str,
branch_name: str,
label: str | list[str],
body: str = "",
) -> int:
"""创建拉取请求并分配标签,若存在请求会导致 raise RequestFailed"""

"""创建拉取请求并分配标签,若当前分支已存在对应拉去请求,会报错 RequestFailed"""
resp = await self.bot.rest.pulls.async_create(
**self.repo_info.model_dump(),
title=title,
Expand All @@ -189,16 +182,21 @@ async def create_pull_request(
)
pull = resp.parsed_data

# 自动给拉取请求添加标签
logger.info("拉取请求创建完毕")
return pull.number

async def add_labels(self, issue_number: int, labels: str | list[str]):
"""添加标签"""
labels = [labels] if isinstance(labels, str) else labels

await self.bot.rest.issues.async_add_labels(
**self.repo_info.model_dump(),
issue_number=pull.number,
labels=[label] if isinstance(label, str) else label,
issue_number=issue_number,
labels=labels,
)
logger.info("拉取请求创建完毕")
return pull.number
logger.info(f"标签 {labels} 已添加")

async def ready_pull_request(self, node_id: int):
async def ready_pull_request(self, node_id: str):
"""将拉取请求标记为可评审"""
await self.bot.async_graphql(
query="""mutation markPullRequestReadyForReview($pullRequestId: ID!) {
Expand All @@ -210,20 +208,12 @@ async def ready_pull_request(self, node_id: int):
)
logger.info("拉取请求已标记为可评审")

async def update_pull_request_title(self, title: str, branch_name: int):
async def update_pull_request_title(self, title: str, pull_number: int) -> None:
"""修改拉取请求标题"""
pull = (
await self.bot.rest.pulls.async_list(
**self.repo_info.model_dump(),
head=f"{self.repo_info.owner}:{branch_name}",
)
).parsed_data[0]

if pull.title != title:
await self.bot.rest.pulls.async_update(
**self.repo_info.model_dump(), pull_number=pull.number, title=title
)
logger.info(f"拉取请求标题已修改为 {title}")
await self.bot.rest.pulls.async_update(
**self.repo_info.model_dump(), pull_number=pull_number, title=title
)
logger.info(f"拉取请求标题已修改为 {title}")

async def get_user_name(self, account_id: int):
"""根据用户 ID 获取用户名"""
Expand Down
5 changes: 1 addition & 4 deletions src/plugins/github/models/issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,12 @@ async def create_pull_request(
base_branch: str,
title: str,
branch_name: str,
label: str | list[str],
body: str = "",
):
if not body:
body = f"resolve #{self.issue_number}"

return await super().create_pull_request(
base_branch, title, branch_name, label, body
)
return await super().create_pull_request(base_branch, title, branch_name, body)

async def should_skip_test(self) -> bool:
"""判断评论是否包含跳过的标记"""
Expand Down
31 changes: 14 additions & 17 deletions src/plugins/github/plugins/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,22 @@

from src.plugins.github.constants import CONFIG_LABEL, TITLE_MAX_LENGTH
from src.plugins.github.depends import (
RepoInfo,
bypass_git,
get_github_handler,
get_installation_id,
get_issue_handler,
get_repo_info,
get_type_by_labels_name,
install_pre_commit_hooks,
is_bot_triggered_workflow,
)
from src.plugins.github.models import IssueHandler
from src.plugins.github.models.github import GithubHandler
from src.plugins.github.plugins.publish.render import render_comment
from src.plugins.github.plugins.publish.utils import (
ensure_issue_plugin_test_button,
ensure_issue_plugin_test_button_in_progress,
)
from src.plugins.github.plugins.remove.depends import check_labels
from src.plugins.github.typing import IssuesEvent
from src.plugins.github.utils import run_shell_command
from src.providers.validation.models import PublishType

from .constants import BRANCH_NAME_PREFIX, COMMIT_MESSAGE_PREFIX, RESULTS_BRANCH
Expand Down Expand Up @@ -80,8 +78,7 @@ async def handle_remove_check(
await ensure_issue_plugin_test_button_in_progress(handler)

# 需要先切换到结果分支
run_shell_command(["git", "fetch", "origin", RESULTS_BRANCH])
run_shell_command(["git", "checkout", RESULTS_BRANCH])
handler.checkout_remote_branch(RESULTS_BRANCH)

# 检查是否满足发布要求
# 仅在通过检查的情况下创建拉取请求
Expand All @@ -103,16 +100,19 @@ async def handle_remove_check(
commit_message = f"{COMMIT_MESSAGE_PREFIX} {result.type.value.lower()} {result.name} (#{handler.issue_number})"

# 创建新分支
run_shell_command(["git", "switch", "-C", branch_name])
handler.switch_branch(branch_name)
# 更新文件
update_file(result)
handler.commit_and_push(commit_message, branch_name, handler.author)
# 创建拉取请求
try:
await handler.create_pull_request(
pull_number = await handler.create_pull_request(
RESULTS_BRANCH,
title,
branch_name,
)
await handler.add_labels(
pull_number,
[result.type.value, CONFIG_LABEL],
)
except RequestFailed:
Expand Down Expand Up @@ -146,21 +146,18 @@ async def review_submitted_rule(
auto_merge_matcher = on_type(PullRequestReviewSubmitted, rule=review_submitted_rule)


@auto_merge_matcher.handle(
parameterless=[Depends(bypass_git), Depends(install_pre_commit_hooks)]
)
@auto_merge_matcher.handle(parameterless=[Depends(bypass_git)])
async def handle_auto_merge(
bot: GitHubBot,
event: PullRequestReviewSubmitted,
installation_id: int = Depends(get_installation_id),
repo_info: RepoInfo = Depends(get_repo_info),
handler: GithubHandler = Depends(get_github_handler),
) -> None:
async with bot.as_installation(installation_id):
pull_number = event.payload.pull_request.number

# 如果有冲突的话,不会触发 Github Actions
# 所以直接合并即可
await bot.rest.pulls.async_merge(
**repo_info.model_dump(),
pull_number=event.payload.pull_request.number,
merge_method="rebase",
)
await handler.merge_pull_request(pull_number, "rebase")

logger.info(f"已自动合并 #{event.payload.pull_request.number}")
32 changes: 7 additions & 25 deletions src/plugins/github/plugins/publish/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@
get_labels_name,
get_related_issue_handler,
get_related_issue_number,
get_repo_info,
install_pre_commit_hooks,
is_bot_triggered_workflow,
)
from src.plugins.github.models import GithubHandler, IssueHandler, RepoInfo
from src.plugins.github.models import GithubHandler, IssueHandler
from src.providers.validation.models import PublishType, ValidationDict

from .depends import (
Expand All @@ -44,7 +43,6 @@
ensure_issue_plugin_test_button,
ensure_issue_plugin_test_button_in_progress,
process_pull_request,
resolve_conflict_pull_requests,
trigger_registry_update,
)
from .validation import (
Expand Down Expand Up @@ -126,7 +124,6 @@ async def handle_publish_plugin_check(
# 确保插件重测按钮存在
await ensure_issue_plugin_test_button(handler)

state["handler"] = handler
state["validation"] = result


Expand All @@ -149,7 +146,6 @@ async def handle_adapter_publish_check(
# 仅在通过检查的情况下创建拉取请求
result = await validate_adapter_info_from_issue(handler.issue)

state["handler"] = handler
state["validation"] = result


Expand Down Expand Up @@ -264,30 +260,16 @@ async def review_submitted_rule(
)


@auto_merge_matcher.handle(
parameterless=[Depends(bypass_git), Depends(install_pre_commit_hooks)]
)
@auto_merge_matcher.handle(parameterless=[Depends(bypass_git)])
async def handle_auto_merge(
bot: GitHubBot,
event: PullRequestReviewSubmitted,
installation_id: int = Depends(get_installation_id),
repo_info: RepoInfo = Depends(get_repo_info),
handler: GithubHandler = Depends(get_github_handler),
) -> None:
async with bot.as_installation(installation_id):
pull_request = (
await bot.rest.pulls.async_get(
**repo_info.model_dump(), pull_number=event.payload.pull_request.number
)
).parsed_data

if not pull_request.mergeable:
# 尝试处理冲突
await resolve_conflict_pull_requests(handler, [pull_request])

await bot.rest.pulls.async_merge(
**repo_info.model_dump(),
pull_number=event.payload.pull_request.number,
merge_method="rebase",
)
logger.info(f"已自动合并 #{event.payload.pull_request.number}")
pull_number = event.payload.pull_request.number

await handler.merge_pull_request(pull_number, "rebase")

logger.info(f"已自动合并 #{pull_number}")
Loading

0 comments on commit c4848fd

Please sign in to comment.