diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index 518aded6..979d11a8 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -14,13 +14,15 @@ from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number -from codeflash.code_utils.git_utils import get_repo_owner_and_name +from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name, git_root_dir +from codeflash.github.PrComment import FileDiffContent, PrComment from codeflash.version import __version__ if TYPE_CHECKING: from requests import Response - from codeflash.github.PrComment import FileDiffContent, PrComment + from codeflash.result.explanation import Explanation + from packaging import version if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local": @@ -175,6 +177,57 @@ def create_pr( return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload) +def create_staging( + original_code: str, + new_code: str, + explanation: Explanation, + existing_tests_source: str, + generated_original_test_source: str, + function_trace_id: str, + coverage_message: str, +) -> Response: + """Create a staging pull request, targeting the specified branch. (usually 'staging'). + + :param owner: The owner of the repository. + :param repo: The name of the repository. + :param base_branch: The base branch to target. + :param file_changes: A dictionary of file changes. + :param pr_comment: The pull request comment object, containing the optimization explanation, best runtime, etc. + :param generated_tests: The generated tests. + :return: The response object. + """ + # convert Path objects to strings + relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix() + + build_file_changes = { + Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent( + oldContent=original_code[p], newContent=new_code[p] + ) + for p in original_code + } + payload = { + "baseBranch": get_current_branch(), + "diffContents": build_file_changes, + "prCommentFields": PrComment( + optimization_explanation=explanation.explanation_message(), + best_runtime=explanation.best_runtime_ns, + original_runtime=explanation.original_runtime_ns, + function_name=explanation.function_name, + relative_file_path=relative_path, + speedup_x=explanation.speedup_x, + speedup_pct=explanation.speedup_pct, + winning_behavioral_test_results=explanation.winning_behavioral_test_results, + winning_benchmarking_test_results=explanation.winning_benchmarking_test_results, + benchmark_details=explanation.benchmark_details, + ).to_json(), + "existingTests": existing_tests_source, + "generatedTests": generated_original_test_source, + "traceId": function_trace_id, + "coverage_message": coverage_message, + } + return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload) + + def is_github_app_installed_on_repo(owner: str, repo: str) -> bool: """Check if the Codeflash GitHub App is installed on the specified repository. diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index c6aaebfe..8fdadafc 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -47,6 +47,7 @@ def parse_args() -> Namespace: parser.add_argument( "--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally." ) + parser.add_argument("--staging-review", action="store_true", help="Upload optimizations to staging for review") parser.add_argument( "--verify-setup", action="store_true", diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 82cf4bc5..e7ed2430 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -19,7 +19,7 @@ from rich.tree import Tree from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient -from codeflash.api.cfapi import add_code_context_hash, mark_optimization_success +from codeflash.api.cfapi import add_code_context_hash, create_staging, mark_optimization_success from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils @@ -997,64 +997,99 @@ def find_and_process_best_optimization( original_code_combined[explanation.file_path] = self.function_to_optimize_source_code new_code_combined = new_helper_code.copy() new_code_combined[explanation.file_path] = new_code - if not self.args.no_pr: - coverage_message = ( - original_code_baseline.coverage_results.build_message() - if original_code_baseline.coverage_results - else "Coverage data not available" - ) - generated_tests = remove_functions_from_generated_tests( - generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove - ) - original_runtime_by_test = ( - original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case() - ) - optimized_runtime_by_test = ( - best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() - ) - # Add runtime comments to generated tests before creating the PR - generated_tests = add_runtime_comments_to_generated_tests( - self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test - ) - generated_tests_str = "\n\n".join( - [test.generated_original_test_source for test in generated_tests.generated_tests] - ) - existing_tests = existing_tests_source_for( - self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), - function_to_all_tests, - test_cfg=self.test_cfg, - original_runtimes_all=original_runtime_by_test, - optimized_runtimes_all=optimized_runtime_by_test, - ) - if concolic_test_str: - generated_tests_str += "\n\n" + concolic_test_str - - check_create_pr( - original_code=original_code_combined, - new_code=new_code_combined, - explanation=explanation, - existing_tests_source=existing_tests, - generated_original_test_source=generated_tests_str, - function_trace_id=self.function_trace_id[:-4] + exp_type - if self.experiment_id - else self.function_trace_id, - coverage_message=coverage_message, - git_remote=self.args.git_remote, - ) - if self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function): - self.write_code_and_helpers( - self.function_to_optimize_source_code, - original_helper_code, - self.function_to_optimize.file_path, - ) - else: - # Mark optimization success since no PR will be created - mark_optimization_success( - trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None - ) + self.process_review( + original_code_baseline, + best_optimization, + generated_tests, + test_functions_to_remove, + concolic_test_str, + original_code_combined, + new_code_combined, + explanation, + function_to_all_tests, + exp_type, + original_helper_code, + ) self.log_successful_optimization(explanation, generated_tests, exp_type) return best_optimization + def process_review( + self, + original_code_baseline: OriginalCodeBaseline, + best_optimization: BestOptimization, + generated_tests: GeneratedTestsList, + test_functions_to_remove: list[str], + concolic_test_str: str | None, + original_code_combined: dict[Path, str], + new_code_combined: dict[Path, str], + explanation: Explanation, + function_to_all_tests: dict[str, set[FunctionCalledInTest]], + exp_type: str, + original_helper_code: dict[Path, str], + ) -> None: + coverage_message = ( + original_code_baseline.coverage_results.build_message() + if original_code_baseline.coverage_results + else "Coverage data not available" + ) + + generated_tests = remove_functions_from_generated_tests( + generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove + ) + + original_runtime_by_test = original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case() + optimized_runtime_by_test = ( + best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() + ) + + generated_tests = add_runtime_comments_to_generated_tests( + self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test + ) + + generated_tests_str = "\n\n".join( + [test.generated_original_test_source for test in generated_tests.generated_tests] + ) + if concolic_test_str: + generated_tests_str += "\n\n" + concolic_test_str + + existing_tests = existing_tests_source_for( + self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), + function_to_all_tests, + test_cfg=self.test_cfg, + original_runtimes_all=original_runtime_by_test, + optimized_runtimes_all=optimized_runtime_by_test, + ) + + data = { + "original_code": original_code_combined, + "new_code": new_code_combined, + "explanation": explanation, + "existing_tests_source": existing_tests, + "generated_original_test_source": generated_tests_str, + "function_trace_id": self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, + "coverage_message": coverage_message, + } + + if not self.args.no_pr and not self.args.staging_review: + data["git_remote"] = self.args.git_remote + check_create_pr(**data) + elif self.args.staging_review: + create_staging(**data) + else: + # Mark optimization success since no PR will be created + mark_optimization_success( + trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None + ) + + if ((not self.args.no_pr) or not self.args.staging_review) and ( + self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function) + ): + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + def establish_original_code_baseline( self, code_context: CodeOptimizationContext,