diff --git a/custom_context.xsd b/custom_context.xsd
new file mode 100644
index 0000000..d3d0ee5
--- /dev/null
+++ b/custom_context.xsd
@@ -0,0 +1,190 @@
diff --git a/src/autofic_core/cli.py b/src/autofic_core/cli.py
index aea77af..1c58b20 100644
--- a/src/autofic_core/cli.py
+++ b/src/autofic_core/cli.py
@@ -1,35 +1,122 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional
+
 import click
-from autofic_core.app import AutoFiCApp
-SAST_TOOL_CHOICES = ['semgrep', 'codeql', 'snykcode']
+from autofic_core.pipeline import AutoFiCPipeline
+
-@click.command()
-@click.option('--explain', is_flag=True, help="Print AutoFiC usage guide.")
-@click.option('--repo', required=False, help="Target GitHub repository URL to analyze (required).")
-@click.option('--save-dir', required=False, default="artifacts/downloaded_repo", help="Directory to save analysis results.")
+def _echo_cfg(repo_url, save_dir, sast_tool, run_llm, llm_retry, run_patch, run_pr, xsd_path):
+    click.echo("AutoFiC - launching pipeline with the following options:\n")
+    click.echo(f" Repo URL : {repo_url}")
+    click.echo(f" Save dir : {save_dir}")
+    click.echo(f" SAST tool : {sast_tool}")
+    click.echo(f" LLM enabled : {run_llm} (retry={llm_retry})")
+    click.echo(f" Patch enabled : {run_patch}")
+    click.echo(f" PR enabled : {run_pr}")
+    click.echo(f" XSD path : {xsd_path if xsd_path else '(none)'}")
+    click.echo("")
+
+
+@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
+@click.option(
+    "--repo",
+    "repo_url",
+    required=True,
+    help="Target GitHub repository URL. e.g., https://github.com/org/project",
+)
+@click.option(
+    "--save-dir",
+    type=click.Path(path_type=Path, file_okay=False, dir_okay=True, writable=True),
+    default=Path("./artifacts"),
+    show_default=True,
+    help="Directory to store outputs (snippets, XML, LLM responses, patches).",
+)
 @click.option(
-    '--sast',
-    type=click.Choice(SAST_TOOL_CHOICES, case_sensitive=False),
-    required=False,
-    help='Select SAST tool to use (choose one of: semgrep, codeql, snykcode).'
+ "--sast", + "sast_tool", + type=click.Choice(["semgrep", "codeql", "snykcode"], case_sensitive=False), + default="semgrep", + show_default=True, + help="Choose SAST tool (legacy style): --sast .", ) -@click.option('--llm', is_flag=True, help="Run LLM to fix vulnerable code and save responses.") -@click.option('--llm-retry', is_flag=True, help="Re-run LLM for final verification and fixes.") -@click.option('--patch', is_flag=True, help="Generate diffs and apply patches using git.") -@click.option('--pr', is_flag=True, help="Automatically create a pull request.") +@click.option( + "--llm/--no-llm", + "run_llm", + default=True, + show_default=True, + help="Run LLM stage after SAST & XML.", +) +@click.option( + "--retry", + "llm_retry", + is_flag=True, + default=False, + help="Use retry directories (retry_llm/retry_parsed/retry_patch).", +) +@click.option( + "--patch/--no-patch", + "run_patch", + default=True, + show_default=True, + help="Generate unified diffs and attempt to apply patches.", +) +@click.option( + "--pr/--no-pr", + "run_pr", + default=False, + show_default=True, + help="(Reserved) Create a PR automatically. (Not wired in this CLI)", +) +@click.option( + "--xsd", + "xsd_path", + type=click.Path(path_type=Path, exists=True, dir_okay=False), + default=None, + help="Path to custom_context.xsd for XML validation (optional). If omitted, CLI will look for ./custom_context.xsd.", +) +def main( + repo_url: str, + save_dir: Path, + sast_tool: str, + run_llm: bool, + llm_retry: bool, + run_patch: bool, + run_pr: bool, + xsd_path: Optional[Path], +) -> None: + """ + AutoFiC — run pipeline end-to-end (legacy CLI). + + Examples: + python -m autofic_core.cli --repo https://github.com/org/project --sast semgrep + python -m autofic_core.cli --repo https://github.com/org/project --sast codeql --no-llm + """ + # Normalize paths + save_dir = save_dir.expanduser().resolve() + + # Default XSD at project root (optional) + if xsd_path is None: + local_xsd = Path("custom_context.xsd").resolve() + if local_xsd.exists(): + xsd_path = local_xsd -def main(explain, repo, save_dir, sast, llm, llm_retry, patch, pr): - app = AutoFiCApp( - explain=explain, - repo=repo, + _echo_cfg(repo_url, save_dir, sast_tool, run_llm, llm_retry, run_patch, run_pr, xsd_path) + + pipe = AutoFiCPipeline( + repo_url=repo_url, save_dir=save_dir, - sast=sast, - llm=llm, + sast=True, + sast_tool=sast_tool.lower(), + llm=run_llm, llm_retry=llm_retry, - patch=patch, - pr=pr + patch=run_patch, + pr=run_pr, ) - app.run() + pipe.run() + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/autofic_core/llm/llm_runner.py b/src/autofic_core/llm/llm_runner.py index 2fb734a..b6bf883 100644 --- a/src/autofic_core/llm/llm_runner.py +++ b/src/autofic_core/llm/llm_runner.py @@ -1,107 +1,132 @@ -# ============================================================================= -# Copyright 2025 AutoFiC Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================= +from __future__ import annotations import os -import click +import re +import time +from dataclasses import dataclass from pathlib import Path -from openai import OpenAI -from typing import Any -from dotenv import load_dotenv +from typing import Any, List, Optional, Tuple + +from openai import OpenAI, APIConnectionError, RateLimitError +from openai.types.chat import ChatCompletion + from autofic_core.errors import LLMExecutionError -from autofic_core.sast.merger import merge_snippets_by_file -from autofic_core.llm.prompt_generator import PromptGenerator -load_dotenv() -client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) +DEFAULT_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") +DEFAULT_TEMPERATURE = float(os.getenv("OPENAI_TEMPERATURE", "0.2")) +MAX_RETRIES = int(os.getenv("OPENAI_MAX_RETRIES", "3")) +RETRY_BACKOFF = float(os.getenv("OPENAI_RETRY_BACKOFF", "2.0")) -class LLMRunner: - """ - Run LLM with a given prompt. - """ - def __init__(self, model="gpt-4o"): - self.model = model - - def run(self, prompt: str) -> str: - """ - Run prompt and return response. - Raises: - LLMExecutionError: On OpenAI error - """ - try: - response = client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "You are a security code fixer."}, - {"role": "user", "content": prompt} - ], - temperature=0.3 - ) - return response.choices[0].message.content.strip() - except Exception as e: - raise LLMExecutionError(str(e)) - - -def save_md_response(content: str, prompt_obj: Any, output_dir: Path) -> str: - """ - Save response to a markdown file. - Returns: - Path: Saved file path - """ - output_dir.mkdir(parents=True, exist_ok=True) +@dataclass +class Prompt: + prompt: str + file_path: str + prompt_id: Optional[str] = None + + +def _client() -> OpenAI: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise LLMExecutionError("[LLM] OPENAI_API_KEY is not set") + return OpenAI(api_key=api_key) + + +def _mk_messages(user_content: str) -> List[dict]: + sys = os.getenv( + "OPENAI_SYSTEM_PROMPT", + "You are a helpful assistant that writes minimal, correct code patches.", + ) + return [ + {"role": "system", "content": sys}, + {"role": "user", "content": user_content}, + ] + + +def _extract_text(resp: ChatCompletion) -> str: try: - path = Path(prompt_obj.snippet.path if hasattr(prompt_obj, "snippet") else prompt_obj.path) - except Exception as e: - raise RuntimeError(f"[ERROR] Failed to resolve output path: {e}") - - parts = [p for p in path.parts if p not in ("artifacts", "downloaded_repo")] - flat_path = "_".join(parts) - output_path = output_dir / f"response_{flat_path}.md" - - output_path.write_text(content, encoding="utf-8") - return output_path - - -def run_llm_for_semgrep_results( - semgrep_json_path: str, - output_dir: Path, - tool: str = "semgrep", - model: str = "gpt-4o", -) -> None: - """ - Run LLM for all prompts from a SAST result. 
- """ - if tool == "semgrep": - from autofic_core.sast.semgrep.preprocessor import SemgrepPreprocessor as Preprocessor - elif tool == "codeql": - from autofic_core.sast.codeql.preprocessor import CodeQLPreprocessor as Preprocessor - elif tool == "snykcode": - from autofic_core.sast.snykcode.preprocessor import SnykCodePreprocessor as Preprocessor - else: - raise ValueError(f"Unsupported SAST tool: {tool}") - - raw_snippets = Preprocessor.preprocess(semgrep_json_path) - merged_snippets = merge_snippets_by_file(raw_snippets) - prompts = PromptGenerator().generate_prompts(merged_snippets) - runner = LLMRunner(model=model) - - for prompt in prompts: - try: - result = runner.run(prompt.prompt) - save_md_response(result, prompt, output_dir) - except LLMExecutionError: - continue \ No newline at end of file + return resp.choices[0].message.content or "" + except Exception: + return "" + + +_filename_sanitize_re = re.compile(r"[^a-zA-Z0-9_.-]+") + + +def _safe_filename(s: str) -> str: + s = s.replace("/", "_").replace("\\", "_").replace(":", "_") + s = _filename_sanitize_re.sub("_", s) + return s.strip("_") or "unknown" + + +def _get_prompt_meta(prompt_obj: Any) -> Tuple[str, str]: + file_path = getattr(prompt_obj, "file_path", None) + if file_path is None and isinstance(prompt_obj, dict): + file_path = prompt_obj.get("file_path") + file_path = file_path or "unknown" + + pid = getattr(prompt_obj, "prompt_id", None) + if pid is None: + pid = getattr(prompt_obj, "id", None) + if pid is None: + pid = getattr(prompt_obj, "uid", None) + if pid is None and isinstance(prompt_obj, dict): + pid = prompt_obj.get("prompt_id") or prompt_obj.get("id") or prompt_obj.get("uid") + if pid is None: + pid = str(time.time_ns()) + + return str(file_path), str(pid) + + +def save_md_response(text: str, prompt_obj: Any, output_dir: Path) -> Path: + output_dir.mkdir(parents=True, exist_ok=True) + + file_path, _ = _get_prompt_meta(prompt_obj) + # ex) routes/app.js -> app + base_stem = _safe_filename(Path(file_path).stem) or "response" + name = f"{base_stem}.md" + path = output_dir / name + + if path.exists(): + i = 1 + while True: + candidate = output_dir / f"{base_stem}_{i}.md" + if not candidate.exists(): + path = candidate + break + i += 1 + + with path.open("w", encoding="utf-8") as f: + f.write(text) + return path + + +class LLMRunner: + def __init__(self, model: Optional[str] = None, temperature: Optional[float] = None): + self.model = model or DEFAULT_MODEL + self.temperature = DEFAULT_TEMPERATURE if temperature is None else temperature + self.client = _client() + + def run(self, user_prompt: str) -> str: + last_err: Optional[Exception] = None + for attempt in range(1, MAX_RETRIES + 1): + try: + resp = self.client.chat.completions.create( + model=self.model, + messages=_mk_messages(user_prompt), + temperature=self.temperature, + ) + text = _extract_text(resp) + if not text.strip(): + raise LLMExecutionError("[LLM] Empty response") + return text + except (RateLimitError, APIConnectionError) as e: + last_err = e + if attempt >= MAX_RETRIES: + break + time.sleep(RETRY_BACKOFF * attempt) + except Exception as e: + last_err = e + break + raise LLMExecutionError(f"[LLM] call failed: {last_err}") diff --git a/src/autofic_core/llm/prompt_generator.py b/src/autofic_core/llm/prompt_generator.py index c9125bd..4e7998b 100644 --- a/src/autofic_core/llm/prompt_generator.py +++ b/src/autofic_core/llm/prompt_generator.py @@ -1,126 +1,119 @@ -# ============================================================================= -# 
Copyright 2025 AutoFiC Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= - -from typing import List -from pydantic import BaseModel -from autofic_core.sast.snippet import BaseSnippet -from autofic_core.errors import ( - PromptGenerationException, - PromptGeneratorErrorCodes, - PromptGeneratorErrorMessages, -) - - -class PromptTemplate(BaseModel): - title: str - content: str - - def render(self, file_snippet: BaseSnippet) -> str: - """Render a prompt based on the provided code snippet.""" - if not file_snippet.input.strip(): - raise PromptGenerationException( - PromptGeneratorErrorCodes.EMPTY_SNIPPET, - PromptGeneratorErrorMessages.EMPTY_SNIPPET, - ) +from __future__ import annotations - vulnerabilities_str = ( - f"Type: {', '.join(file_snippet.vulnerability_class) or 'Unknown'}\n" - f"CWE: {', '.join(file_snippet.cwe) or 'N/A'}\n" - f"Description: {file_snippet.message or 'None'}\n" - f"Severity: {file_snippet.severity or 'Unknown'}\n" - f"Location: {file_snippet.start_line} ~ {file_snippet.end_line} (Only modify this code range)\n\n" - ) - - try: - return self.content.format( - input=file_snippet.input, - vulnerabilities=vulnerabilities_str, - ) - except Exception: - raise PromptGenerationException( - PromptGeneratorErrorCodes.TEMPLATE_RENDER_ERROR, - PromptGeneratorErrorMessages.TEMPLATE_RENDER_ERROR, - ) +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Optional +import textwrap +import json +from autofic_core.sast.snippet import BaseSnippet -class GeneratedPrompt(BaseModel): - title: str + +@dataclass +class PromptItem: + """Container passed to LLMRunner and save_md_response.""" + file_path: str prompt: str - snippet: BaseSnippet + meta: Dict[str, str] class PromptGenerator: - def __init__(self): - self.template = PromptTemplate( - title="Refactoring Vulnerable Code Snippet (File Level)", - content=( - "The following is a JavaScript source file that contains security vulnerabilities.\n\n" - "```javascript\n" - "{input}\n" - "```\n\n" - "Detected vulnerabilities:\n\n" - "{vulnerabilities}" - "Please strictly follow the guidelines below when modifying the cozde:\n" - "- Modify **only the vulnerable parts** of the file with **minimal changes**.\n" - "- Preserve the **original line numbers, indentation, and code formatting** exactly.\n" - "- **Do not modify any part of the file that is unrelated to the vulnerabilities.**\n" - "- Output the **entire file**, not just the changed lines.\n" - "- This code will be used for diff-based automatic patching, so structural changes may cause the patch to fail.\n\n" - "Output format example:\n" - "1. Vulnerability Description: ...\n" - "2. Potential Risk: ...\n" - "3. Recommended Fix: ...\n" - "4. Final Modified Code:\n" - "```javascript\n" - "// Entire file content, but only vulnerable parts should be modified minimally\n" - "...entire code...\n" - "```\n" - "5. 
Additional Notes: (optional)\n" - ), - ) - - def generate_prompt(self, file_snippet: BaseSnippet) -> GeneratedPrompt: - """Generate a single prompt from one code snippet.""" - if not isinstance(file_snippet, BaseSnippet): - raise TypeError(f"[ ERROR ] generate_prompt: Invalid input type: {type(file_snippet)}") - rendered_prompt = self.template.render(file_snippet) - return GeneratedPrompt( - title=self.template.title, - prompt=rendered_prompt, - snippet=file_snippet, - ) - - def generate_prompts(self, file_snippets: List[BaseSnippet]) -> List[GeneratedPrompt]: - """Generate prompts from multiple snippets.""" - prompts = [] - for idx, snippet in enumerate(file_snippets): - if isinstance(snippet, dict): - snippet = BaseSnippet(**snippet) - elif not isinstance(snippet, BaseSnippet): - raise TypeError(f"[ ERROR ] generate_prompts: Invalid type at index {idx}: {type(snippet)}") - prompts.append(self.generate_prompt(snippet)) + """ + Generate prompts grouped by file from merged SAST snippets. + - Respects Team-Atlanta XML if present: `/sast/CUSTOM_CONTEXT.xml` + - Produces one PromptItem per *file* (aggregating all issues in that file) + """ + + def __init__(self, save_dir: Optional[Path] = None) -> None: + self.save_dir = Path(save_dir) if save_dir else Path(".") + + # ----------------------------- public API ----------------------------- + def generate_prompts(self, file_snippets: List[BaseSnippet]) -> List[PromptItem]: + """Group snippets by file and generate prompts.""" + by_file = self._group_by_file(file_snippets) + xml_path = self._find_custom_context_xml() + + prompts: List[PromptItem] = [] + for file_path, items in by_file.items(): + content = self._render_prompt_for_file(file_path, items, xml_path) + prompts.append( + PromptItem( + file_path=file_path, + prompt=content, + meta={"xml_path": str(xml_path) if xml_path else "", "issues_count": str(len(items))}, + ) + ) return prompts def get_unique_file_paths(self, file_snippets: List[BaseSnippet]) -> List[str]: - """Extract unique paths from list of snippets.""" - paths = set() - for idx, snippet in enumerate(file_snippets): - if isinstance(snippet, dict): - snippet = BaseSnippet(**snippet) - elif not isinstance(snippet, BaseSnippet): - raise TypeError(f"[ ERROR ] get_unique_file_paths: Type error at index {idx}: {type(snippet)}") - paths.add(snippet.path) - return sorted(paths) \ No newline at end of file + """Used by pipeline for summary display.""" + return sorted({sn.path for sn in file_snippets if getattr(sn, "path", None)}) + + # ---------------------------- helpers -------------------------------- + def _group_by_file(self, snippets: Iterable[BaseSnippet]) -> Dict[str, List[BaseSnippet]]: + grouped: Dict[str, List[BaseSnippet]] = {} + for sn in snippets: + path = getattr(sn, "path", "") or "" + grouped.setdefault(path, []).append(sn) + # stable order by start_line + for k in grouped: + grouped[k].sort(key=lambda s: (s.start_line or 0, s.end_line or 0)) + return grouped + + def _find_custom_context_xml(self) -> Optional[Path]: + """Look for `/sast/CUSTOM_CONTEXT.xml`.""" + candidate = self.save_dir / "sast" / "CUSTOM_CONTEXT.xml" + return candidate if candidate.exists() else None + + def _issues_section(self, items: List[BaseSnippet]) -> str: + lines: List[str] = [] + for sn in items: + sev = sn.bit_severity or sn.severity or "INFO" + rng = f"{sn.start_line}–{sn.end_line}" + trig = (sn.bit_trigger or sn.message or "").strip() + cwe = ", ".join(sn.cwe) if sn.cwe else "" + cwe_part = f" (CWE: {cwe})" if cwe else "" + lines.append(f"- 
[{sev}] lines {rng} — {trig}{cwe_part}") + return "\n".join(lines) + + def _render_prompt_for_file(self, file_path: str, items: List[BaseSnippet], xml_path: Optional[Path]) -> str: + """ + Compose the final prompt text for a single file. + Note: Keep instructions crisp for diff-only outputs. + """ + issues = self._issues_section(items) + xml_hint = f"\nTeam-Atlanta context XML is available at: {xml_path}\nUse it to confirm BIT (Trigger, Steps, Reproduction, Severity)." if xml_path else "" + + policy = textwrap.dedent(""" + Output policy: + - Output ONLY unified diff for the repository file(s). + - Do NOT include code fences, prose, file headers, or explanations. + - Keep changes minimal, targeted to fix the vulnerabilities. + - Do NOT change behavior beyond necessary security fixes. + - Never replace 'http' with 'https' unless explicitly required by the issue. + """).strip() + + file_intro = textwrap.dedent(f""" + You are given merged SAST findings for a single file. + + Target file: {file_path} + Issues: + {issues} + """).strip() + + # Optional: provide a tiny JSON with ranges to help a model focus. + focus_json = { + "file": file_path, + "ranges": [{"start": s.start_line, "end": s.end_line, "severity": (s.bit_severity or s.severity or "INFO")} + for s in items] + } + + prompt = f"""{file_intro} +{xml_hint} + +Focus hints (JSON): +{json.dumps(focus_json, ensure_ascii=False)} + +{policy} +""" + return prompt diff --git a/src/autofic_core/llm/response_parser.py b/src/autofic_core/llm/response_parser.py index 98ba035..8f9248d 100644 --- a/src/autofic_core/llm/response_parser.py +++ b/src/autofic_core/llm/response_parser.py @@ -106,4 +106,4 @@ def extract_and_save_all(self) -> bool: success = True except ResponseParseError: continue - return success + return success \ No newline at end of file diff --git a/src/autofic_core/pipeline.py b/src/autofic_core/pipeline.py index 0b9e0d6..f4ccb06 100644 --- a/src/autofic_core/pipeline.py +++ b/src/autofic_core/pipeline.py @@ -22,6 +22,8 @@ from autofic_core.sast.snykcode.runner import SnykCodeRunner from autofic_core.sast.snykcode.preprocessor import SnykCodePreprocessor from autofic_core.sast.merger import merge_snippets_by_file +# ⬇️ 여기만 변경 +from autofic_core.sast.xml_generator import render_custom_context, RenderOptions from autofic_core.llm.prompt_generator import PromptGenerator from autofic_core.llm.llm_runner import LLMRunner, save_md_response @@ -58,13 +60,14 @@ def clone(self): console.print("[ SUCCESS ] Fork completed\n", style="bold green") self.clone_path = Path( - self.handler.clone_repo(save_dir=str(self.save_dir), use_forked=self.handler.needs_fork)) + self.handler.clone_repo(save_dir=str(self.save_dir), use_forked=self.handler.needs_fork) + ) console.print(f"[ SUCCESS ] Repository cloned successfully: {self.clone_path}", style="bold green") - except ForkFailedError as e: + except ForkFailedError: sys.exit(1) - except RepoAccessError as e: + except RepoAccessError: raise except (PermissionError, OSError) as e: @@ -77,7 +80,7 @@ def __init__(self, repo_path: Path, save_dir: Path): self.save_dir = save_dir def run(self): - description = "Running Semgrep...".ljust(28) + description = "Running Semgrep...".ljust(28) with create_progress() as progress: task = progress.add_task(description, total=100) @@ -103,17 +106,35 @@ def _post_process(self, data): sast_dir.mkdir(parents=True, exist_ok=True) before_path = sast_dir / "before.json" SemgrepPreprocessor.save_json_file(data, before_path) + snippets = 
SemgrepPreprocessor.preprocess(str(before_path), str(self.repo_path)) merged = merge_snippets_by_file(snippets) + merged_path = sast_dir / "merged_snippets.json" with open(merged_path, "w", encoding="utf-8") as f: json.dump([s.model_dump() for s in merged], f, indent=2, ensure_ascii=False) + # Team-Atlanta XML 렌더링 + (선택) XSD 검증 + try: + xml_out = sast_dir / "CUSTOM_CONTEXT.xml" + schema_path = Path("schemas/custom_context.xsd").resolve() + os.environ["AUTOFIC_LAST_SAST_TOOL"] = "semgrep" + render_custom_context( + merged_snippets=merged, + output_path=xml_out, + schema_path=schema_path if schema_path.exists() else None, + options=RenderOptions(schema_location="schemas/custom_context.xsd") + ) + console.print(f"[ SUCCESS ] CUSTOM_CONTEXT.xml generated → {xml_out}", style="bold green") + except Exception as e: + console.print(f"[ WARN ] CUSTOM_CONTEXT.xml generation/validation skipped: {e}", style="yellow") + if not merged: console.print("\n[ INFO ] No vulnerabilities found.\n", style="yellow") console.print( "AutoFiC automation has been halted.--llm, --patch, and --pr stages will not be executed.\n", - style="yellow") + style="yellow", + ) return None return merged_path @@ -125,10 +146,10 @@ def __init__(self, repo_path: Path, save_dir: Path): self.save_dir = save_dir def run(self): - description = "Running CodeQL...".ljust(28) + description = "Running CodeQL...".ljust(28) with create_progress() as progress: task = progress.add_task(description, total=100) - + start = time.time() runner = CodeQLRunner(repo_path=str(self.repo_path)) result_path = runner.run_codeql() @@ -150,17 +171,34 @@ def _post_process(self, data): sast_dir.mkdir(parents=True, exist_ok=True) before_path = sast_dir / "before.json" CodeQLPreprocessor.save_json_file(data, before_path) + snippets = CodeQLPreprocessor.preprocess(str(before_path), str(self.repo_path)) merged = merge_snippets_by_file(snippets) + merged_path = sast_dir / "merged_snippets.json" with open(merged_path, "w", encoding="utf-8") as f: json.dump([s.model_dump() for s in merged], f, indent=2, ensure_ascii=False) + try: + xml_out = sast_dir / "CUSTOM_CONTEXT.xml" + schema_path = Path("schemas/custom_context.xsd").resolve() + os.environ["AUTOFIC_LAST_SAST_TOOL"] = "codeql" + render_custom_context( + merged_snippets=merged, + output_path=xml_out, + schema_path=schema_path if schema_path.exists() else None, + options=RenderOptions(schema_location="schemas/custom_context.xsd") + ) + console.print(f"[ SUCCESS ] CUSTOM_CONTEXT.xml generated → {xml_out}", style="bold green") + except Exception as e: + console.print(f"[ WARN ] CUSTOM_CONTEXT.xml generation/validation skipped: {e}", style="yellow") + if not merged: console.print("\n[ INFO ] No vulnerabilities found.\n", style="yellow") console.print( "AutoFiC automation has been halted.--llm, --patch, and --pr stages will not be executed.\n", - style="yellow") + style="yellow", + ) return None return merged_path @@ -170,12 +208,12 @@ class SnykCodeHandler: def __init__(self, repo_path: Path, save_dir: Path): self.repo_path = repo_path self.save_dir = save_dir - + def run(self): - description = "Running SnykCode...".ljust(28) + description = "Running SnykCode...".ljust(28) with create_progress() as progress: task = progress.add_task(description, total=100) - + start = time.time() runner = SnykCodeRunner(repo_path=str(self.repo_path)) result = runner.run_snykcode() @@ -195,17 +233,34 @@ def _post_process(self, data): sast_dir.mkdir(parents=True, exist_ok=True) before_path = sast_dir / "before.json" 
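+        # Keep the raw Snyk Code report (before.json) before it is normalized into BaseSnippet objects.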
SnykCodePreprocessor.save_json_file(data, before_path) + snippets = SnykCodePreprocessor.preprocess(str(before_path), str(self.repo_path)) merged = merge_snippets_by_file(snippets) + merged_path = sast_dir / "merged_snippets.json" with open(merged_path, "w", encoding="utf-8") as f: json.dump([s.model_dump() for s in merged], f, indent=2, ensure_ascii=False) + try: + xml_out = sast_dir / "CUSTOM_CONTEXT.xml" + schema_path = Path("schemas/custom_context.xsd").resolve() + os.environ["AUTOFIC_LAST_SAST_TOOL"] = "snykcode" + render_custom_context( + merged_snippets=merged, + output_path=xml_out, + schema_path=schema_path if schema_path.exists() else None, + options=RenderOptions(schema_location="schemas/custom_context.xsd") + ) + console.print(f"[ SUCCESS ] CUSTOM_CONTEXT.xml generated → {xml_out}", style="bold green") + except Exception as e: + console.print(f"[ WARN ] CUSTOM_CONTEXT.xml generation/validation skipped: {e}", style="yellow") + if not merged: console.print("\n[ INFO ] No vulnerabilities found.\n", style="yellow") console.print( "AutoFiC automation has been halted.--llm, --patch, and --pr stages will not be executed.\n", - style="yellow") + style="yellow", + ) return None return merged_path @@ -289,14 +344,13 @@ def run(self): llm = LLMRunner() self.llm_output_dir.mkdir(parents=True, exist_ok=True) - - description = "Generating LLM responses... \n".ljust(28) + + description = "Generating LLM responses... \n".ljust(28) with create_progress() as progress: task = progress.add_task(description, total=len(prompts)) for p in prompts: response = llm.run(p.prompt) - save_md_response(response, p, output_dir=self.llm_output_dir) progress.update(task, advance=1) @@ -381,7 +435,8 @@ def run(self): class AutoFiCPipeline: - def __init__(self, repo_url: str, save_dir: Path, sast: bool, sast_tool: str, llm: bool, llm_retry: bool, patch: bool, pr: bool): + def __init__(self, repo_url: str, save_dir: Path, sast: bool, sast_tool: str, + llm: bool, llm_retry: bool, patch: bool, pr: bool): self.repo_url = repo_url self.save_dir = save_dir.expanduser().resolve() self.sast = sast @@ -424,8 +479,9 @@ def run(self): with open(merged_path, "r", encoding="utf-8") as f: merged_data = json.load(f) - self.llm_processor = LLMProcessor(sast_result_path, self.repo_manager.clone_path, self.save_dir, - self.sast_tool) + self.llm_processor = LLMProcessor( + sast_result_path, self.repo_manager.clone_path, self.save_dir, self.sast_tool + ) try: prompts, file_snippets = self.llm_processor.run() @@ -434,8 +490,10 @@ def run(self): sys.exit(1) if not prompts: - console.print("[ INFO ] No valid prompts returned from LLM processor. Exiting pipeline early.\n", - style="cyan") + console.print( + "[ INFO ] No valid prompts returned from LLM processor. 
Exiting pipeline early.\n", + style="cyan", + ) sys.exit(0) self.llm_processor.extract_and_save_parsed_code() @@ -450,12 +508,12 @@ def run(self): repo_url=self.repo_url, detected_issues_count=len(unique_file_paths), output_dir=str(llm_output_dir), - response_files=response_files + response_files=response_files, ) if self.patch: parsed_dir = self.save_dir / ("retry_parsed" if self.llm_retry else "parsed") patch_dir = self.save_dir / ("retry_patch" if self.llm_retry else "patch") - + patch_manager = PatchManager(parsed_dir, patch_dir, self.repo_manager.clone_path) - patch_manager.run() \ No newline at end of file + patch_manager.run() diff --git a/src/autofic_core/sast/codeql/preprocessor.py b/src/autofic_core/sast/codeql/preprocessor.py index 5d27bc3..7515d1b 100644 --- a/src/autofic_core/sast/codeql/preprocessor.py +++ b/src/autofic_core/sast/codeql/preprocessor.py @@ -1,147 +1,212 @@ -# ============================================================================= -# Copyright 2025 Autofic Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= - -""" -CodeQLPreprocessor: Extracts vulnerability snippets from CodeQL SARIF results. - -- Parses SARIF JSON results -- Matches vulnerabilities to code regions -- Generates BaseSnippet objects for downstream processing -""" +from __future__ import annotations import json -import os -from typing import List +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + from autofic_core.sast.snippet import BaseSnippet +@dataclass +class _Loc: + path: str + start: int + end: int + + class CodeQLPreprocessor: """ - Processes SARIF output from CodeQL and extracts vulnerability information - into a uniform BaseSnippet format. + Normalize CodeQL SARIF-like JSON into BaseSnippet list. + + - Coerces message/bit_trigger to *string* even if input is dict({"text": ...}). + - Fills BIT fields (trigger/steps/reproduction/severity) heuristically when absent. + - Extracts source snippet from repository file based on (start,end) lines. 
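+    - Maps SARIF result levels (error/warning/note) to HIGH/MEDIUM/LOW; a missing level falls back to INFO.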
""" + # ---------------------- public helpers ---------------------- @staticmethod - def read_json_file(path: str) -> dict: - """Reads JSON content from the given file path.""" - with open(path, 'r', encoding='utf-8') as f: - return json.load(f) + def save_json_file(data: Dict[str, Any], path: Path | str) -> None: + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) @staticmethod - def save_json_file(data: dict, path: str) -> None: - """Saves the given dictionary as JSON to the specified path.""" - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, 'w', encoding='utf-8') as f: - json.dump(data, f, indent=4, ensure_ascii=False) + def preprocess(json_path: str, repo_root: str) -> List[BaseSnippet]: + with open(json_path, "r", encoding="utf-8") as f: + data = json.load(f) + return CodeQLPreprocessor._parse(data, Path(repo_root)) + # ---------------------- internal utils ---------------------- @staticmethod - def preprocess(input_json_path: str, base_dir: str = ".") -> List[BaseSnippet]: + def _as_text(val: Any) -> str: """ - Parses CodeQL SARIF results and extracts code snippets for each finding. + Normalize CodeQL 'message' objects and other variants to plain string. + Accepts str / dict / list / None. + """ + if val is None: + return "" + if isinstance(val, str): + return val + if isinstance(val, dict): + # common SARIF shape: {"text": "..."} or {"markdown": "..."} + for k in ("text", "markdown", "message", "value"): + if k in val and isinstance(val[k], str): + return val[k] + try: + return json.dumps(val, ensure_ascii=False) + except Exception: + return str(val) + if isinstance(val, (list, tuple)): + try: + return " ".join(CodeQLPreprocessor._as_text(x) for x in val) + except Exception: + return str(val) + return str(val) + + @staticmethod + def _pick_severity(result: Dict[str, Any], rule: Optional[Dict[str, Any]]) -> str: + # CodeQL SARIF often uses result.level: "error" | "warning" | "note" + sev = CodeQLPreprocessor._as_text(result.get("level")).upper() + if not sev: + # sometimes rules have severity in properties + sev = CodeQLPreprocessor._as_text( + (rule or {}).get("properties", {}).get("problem.severity") + ).upper() + # map to common levels + mapping = {"ERROR": "HIGH", "WARNING": "MEDIUM", "NOTE": "LOW"} + return mapping.get(sev, sev or "INFO") - Args: - input_json_path (str): Path to the SARIF result file. - base_dir (str): Base path to resolve relative file URIs. + @staticmethod + def _locations(result: Dict[str, Any]) -> List[_Loc]: + locs: List[_Loc] = [] + for loc in result.get("locations", []) or []: + phys = (loc.get("physicalLocation") or {}) + art = (phys.get("artifactLocation") or {}) + uri = art.get("uri") or art.get("uriBaseId") or "" + region = (phys.get("region") or {}) + start = int(region.get("startLine") or 1) + end = int(region.get("endLine") or start) + if uri: + locs.append(_Loc(path=uri, start=start, end=end)) + # fallback: relatedLocations + if not locs: + for loc in result.get("relatedLocations", []) or []: + phys = (loc.get("physicalLocation") or {}) + art = (phys.get("artifactLocation") or {}) + uri = art.get("uri") or "" + region = (phys.get("region") or {}) + start = int(region.get("startLine") or 1) + end = int(region.get("endLine") or start) + if uri: + locs.append(_Loc(path=uri, start=start, end=end)) + return locs - Returns: - List[BaseSnippet]: Extracted and structured vulnerability snippets. 
+ @staticmethod + def _resolve_rule(result: Dict[str, Any], rules_by_id: Dict[str, Dict[str, Any]]) -> Optional[Dict[str, Any]]: + rule_id = result.get("ruleId") + if rule_id and rule_id in rules_by_id: + return rules_by_id[rule_id] + # sometimes index-based rule + ridx = result.get("ruleIndex") + if isinstance(ridx, int): + for rid, rule in rules_by_id.items(): + if rule.get("_index") == ridx: + return rule + return None + + @staticmethod + def _build_rules_index(run: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: + idx: Dict[str, Dict[str, Any]] = {} + for i, rule in enumerate(run.get("tool", {}).get("driver", {}).get("rules", []) or []): + rid = rule.get("id") or f"rule_{i}" + rule["_index"] = i + idx[str(rid)] = rule + return idx + + @staticmethod + def _read_snippet(repo_root: Path, loc: _Loc) -> Tuple[str, int, int]: + """ + Return (snippet_text, start, end). If file cannot be read, snippet becomes empty string. """ - results = CodeQLPreprocessor.read_json_file(input_json_path) - - # Build rule metadata lookup from SARIF tool section - rule_metadata = {} - for run in results.get("runs", []): - for rule in run.get("tool", {}).get("driver", {}).get("rules", []): - rule_id = rule.get("id") - if rule_id: - rule_metadata[rule_id] = { - "cwe": [ - tag.split("/")[-1].replace("cwe-", "CWE-") - for tag in rule.get("properties", {}).get("tags", []) - if "cwe-" in tag - ], - "references": [rule.get("helpUri")] if rule.get("helpUri") else [], - "level": ( - rule.get("defaultConfiguration", {}).get("level") - or rule.get("properties", {}).get("problem.severity", "UNKNOWN") - ) - } - - processed: List[BaseSnippet] = [] - snippet_idx = 0 - - for run in results.get("runs", []): - for res in run.get("results", []): - location = res.get("locations", [{}])[0].get("physicalLocation", {}) - artifact = location.get("artifactLocation", {}) - region = location.get("region", {}) - - file_uri = artifact.get("uri", "Unknown") - full_path = os.path.join(base_dir, file_uri) - start_line = region.get("startLine", 0) - end_line = region.get("endLine") or start_line - - rule_id = res.get("ruleId") - meta = rule_metadata.get(rule_id, {}) if rule_id else {} - - # Normalize severity level - level = res.get("level") or meta.get("level", "UNKNOWN") - if isinstance(level, list): - level = level[0] if level else "UNKNOWN" - severity = str(level).upper() - - snippet = "" - lines = [] - - try: - if os.path.exists(full_path): - with open(full_path, "r", encoding="utf-8") as code_file: - lines = code_file.readlines() - - # Defensive check on line bounds - if start_line > len(lines) or start_line < 1: - continue - - raw_snippet = ( - lines[start_line - 1:end_line] - if end_line > start_line - else [lines[start_line - 1]] - ) - - if all(not line.strip() for line in raw_snippet): - continue - - snippet = "".join(raw_snippet) - - except Exception: - continue # Skip problematic entries silently - - processed.append(BaseSnippet( - input="".join(lines), - snippet=snippet.strip(), - path=file_uri, - idx=snippet_idx, - start_line=start_line, - end_line=end_line, - message=res.get("message", {}).get("text", ""), - severity=severity, - vulnerability_class=[rule_id.split("/", 1)[-1]] if rule_id else [], - cwe=meta.get("cwe", []), - references=meta.get("references", []) - )) - snippet_idx += 1 - - return processed \ No newline at end of file + file_path = (repo_root / loc.path).resolve() + if not file_path.exists(): + # Sometimes URIs are relative to repo root without normalization + file_path = (repo_root / 
Path(loc.path.strip("/"))).resolve() + text = "" + try: + with open(file_path, "r", encoding="utf-8", errors="ignore") as f: + lines = f.readlines() + start = max(1, loc.start) + end = min(len(lines), max(loc.end, start)) + text = "".join(lines[start - 1 : end]) + return text, start, end + except Exception: + return "", loc.start, loc.end + + # ---------------------- main parse ---------------------- + @staticmethod + def _parse(data: Dict[str, Any], repo_root: Path) -> List[BaseSnippet]: + snippets: List[BaseSnippet] = [] + + runs = data.get("runs") or [] + for run in runs: + rules_idx = CodeQLPreprocessor._build_rules_index(run) + results = run.get("results") or [] + for res in results: + locs = CodeQLPreprocessor._locations(res) + if not locs: + continue + + rule = CodeQLPreprocessor._resolve_rule(res, rules_idx) + message_str = CodeQLPreprocessor._as_text(res.get("message")) + + # severity + severity = CodeQLPreprocessor._pick_severity(res, rule) + + # CWE tags if present (very vendor-specific; best-effort) + cwe: List[str] = [] + rule_tags = (rule or {}).get("properties", {}).get("tags", []) or [] + for t in rule_tags: + tstr = CodeQLPreprocessor._as_text(t) + if tstr.upper().startswith("CWE-"): + cwe.append(tstr) + + # For each location, create a snippet + for loc in locs: + code, start, end = CodeQLPreprocessor._read_snippet(repo_root, loc) + + # ---- BIT heuristics (string-only) ---- + bit_trigger = message_str # use CodeQL message as trigger by default + bit_steps: List[str] = [] + if start == end: + bit_steps.append(f"Review line {start} in {loc.path}") + else: + bit_steps.append(f"Review lines {start}-{end} in {loc.path}") + bit_repro = "Inspect the indicated code region and verify unsafe data flow / pattern reported by CodeQL." + bit_sev = severity + + # ---- Build BaseSnippet (all strings for message/BIT) ---- + sn = BaseSnippet( + path=str(loc.path), + start_line=int(start), + end_line=int(end), + severity=str(severity), + message=str(message_str), + snippet=str(code), + cwe=cwe, + # BIT + bit_trigger=str(bit_trigger), + bit_steps=[str(s) for s in bit_steps], + bit_reproduction=str(bit_repro), + bit_severity=str(bit_sev), + tool="codeql", + classes=[], + references=[], + tags=[], + sources=["codeql"], + ) + snippets.append(sn) + + return snippets diff --git a/src/autofic_core/sast/merger.py b/src/autofic_core/sast/merger.py index 3809ea8..6a8f9c8 100644 --- a/src/autofic_core/sast/merger.py +++ b/src/autofic_core/sast/merger.py @@ -1,56 +1,167 @@ -# ============================================================================= -# Copyright 2025 AutoFiC Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================= - -from collections import defaultdict -from typing import List -from autofic_core.sast.snippet import BaseSnippet +from __future__ import annotations +from collections import defaultdict, OrderedDict +from dataclasses import dataclass, asdict, field +from typing import List, Dict, Any, Optional, Iterable, Tuple, Set -def merge_snippets_by_file(snippets: List[BaseSnippet]) -> List[BaseSnippet]: - grouped = defaultdict(list) - - for snippet in snippets: - grouped[snippet.path].append(snippet) - - merged_snippets = [] - - for path, group in grouped.items(): - base = group[0] - start_line = min(s.start_line for s in group) - end_line = max(s.end_line for s in group) - - snippet_lines_set = set() - for s in group: - if s.snippet: - snippet_lines_set.update(s.snippet.splitlines()) - merged_snippet_text = "\n".join(sorted(snippet_lines_set)) - - merged_message = " | ".join(sorted(set(s.message for s in group if s.message))) - merged_vuln_class = sorted({vc for s in group for vc in s.vulnerability_class}) - merged_cwe = sorted({c for s in group for c in s.cwe}) - merged_references = sorted({r for s in group for r in s.references}) - - severity_order = {"INFO": 0, "WARNING": 1, "ERROR": 2} - severity = max( - (str(s.severity).upper() for s in group if s.severity), - key=lambda x: severity_order.get(x, -1), - default="" - ) - merged_snippets.append(BaseSnippet( - input=base.input, +# try to import project BaseSnippet; otherwise define a lightweight fallback +try: + from snippet import BaseSnippet # type: ignore +except Exception: + @dataclass + class BaseSnippet: + input: str + idx: Optional[int] = None + start_line: int = 0 + end_line: int = 0 + snippet: Optional[str] = None + message: str = "" + vulnerability_class: List[str] = field(default_factory=list) + cwe: List[str] = field(default_factory=list) + severity: str = "" + references: List[str] = field(default_factory=list) + path: str = "" + # optional BIT/extension fields + bit_trigger: Optional[str] = None + bit_steps: List[str] = field(default_factory=list) + bit_reproduction: Optional[str] = None + bit_severity: Optional[str] = None + constraints: Dict[str, Any] = field(default_factory=dict) + +# severity ranking helper (higher index => more severe) +_SEVERITY_ORDER = ["INFO", "LOW", "MEDIUM", "HIGH", "CRITICAL"] +_SEVERITY_MAP = {s: i for i, s in enumerate(_SEVERITY_ORDER)} + +def _pick_worst_severity(severities: Iterable[str]) -> str: + """Return the worst (highest) severity among given severities. Unknown are treated as INFO.""" + worst_idx = -1 + worst = "" + for s in severities: + if not s: + idx = 0 + else: + idx = _SEVERITY_MAP.get(s.upper(), 0) + if idx > worst_idx: + worst_idx = idx + worst = s + return worst or "" + +def are_ranges_overlapping(a_start: int, a_end: int, b_start: int, b_end: int) -> bool: + """Return True if two inclusive ranges [a_start,a_end] and [b_start,b_end] overlap or touch.""" + return not (a_end < b_start - 1 or b_end < a_start - 1) + +def _unique_preserve_order(items: Iterable[Any]) -> List[Any]: + seen = set() + out = [] + for it in items: + if it is None: + continue + if it not in seen: + seen.add(it) + out.append(it) + return out + +def _merge_constraints(constraints_list: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Merge a list of constraints dicts. In case of key collision, namespace keys with a suffix: + original -> original__1, original__2, ... 
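+    With the current counter the first duplicate of a key is suffixed __2, e.g. two "auth" entries become "auth" and "auth__2".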
+ """ + merged: Dict[str, Any] = {} + counts: Dict[str, int] = {} + for cdict in constraints_list: + if not cdict: + continue + for k, v in cdict.items(): + if k not in merged: + merged[k] = v + counts[k] = 1 + else: + # collision: create a namespaced key + counts[k] += 1 + merged[f"{k}__{counts[k]}"] = v + return merged + +def merge_group(snippets: List[BaseSnippet]) -> BaseSnippet: + """ + Merge a list of BaseSnippet-like objects that belong to the same file and overlapping range. + Returns a new BaseSnippet (instance of the same class if possible). + """ + if not snippets: + raise ValueError("merge_group called with empty list") + + # sort by start_line for deterministic output + snippets_sorted = sorted(snippets, key=lambda s: (getattr(s, "start_line", 0), getattr(s, "end_line", 0))) + base = snippets_sorted[0] + + path = getattr(base, "path", "") + start_line = min(getattr(s, "start_line", 0) for s in snippets_sorted) + end_line = max(getattr(s, "end_line", 0) for s in snippets_sorted) + + # merge textual snippets (preserve order, unique) + snippet_texts = [getattr(s, "snippet", "") or "" for s in snippets_sorted] + merged_snippet_text = "\n".join(_unique_preserve_order(snippet_texts)).strip() + + # messages + messages = [getattr(s, "message", "") or "" for s in snippets_sorted] + merged_message = " | ".join(_unique_preserve_order(messages)) + + # vuln classes, cwe, references + vuln_classes = [] + cwes = [] + references = [] + inputs = [] + constraints_list = [] + bit_triggers = [] + bit_steps_acc: List[str] = [] + bit_repros = [] + bit_severity_candidates = [] + for s in snippets_sorted: + vuln_classes.extend(getattr(s, "vulnerability_class", []) or []) + cwes.extend(getattr(s, "cwe", []) or []) + references.extend(getattr(s, "references", []) or []) + inputs.append(getattr(s, "input", "")) + constraints_list.append(getattr(s, "constraints", {}) or {}) + # BIT fields (may not exist) + if hasattr(s, "bit_trigger"): + t = getattr(s, "bit_trigger") + if t: + bit_triggers.append(t) + if hasattr(s, "bit_steps"): + steps = getattr(s, "bit_steps") or [] + bit_steps_acc.extend(steps) + if hasattr(s, "bit_reproduction"): + r = getattr(s, "bit_reproduction") + if r: + bit_repros.append(r) + if hasattr(s, "bit_severity"): + bs = getattr(s, "bit_severity") + if bs: + bit_severity_candidates.append(bs) + # severity + merged_vuln_class = _unique_preserve_order(vuln_classes) + merged_cwe = _unique_preserve_order(cwes) + merged_references = _unique_preserve_order(references) + merged_inputs = _unique_preserve_order(inputs) + + merged_constraints = _merge_constraints(constraints_list) + merged_bit_trigger = " | ".join(_unique_preserve_order(bit_triggers)) if bit_triggers else None + merged_bit_steps = _unique_preserve_order(bit_steps_acc) + merged_bit_reproduction = None + if bit_repros: + merged_bit_reproduction = " | ".join(_unique_preserve_order(bit_repros)) + merged_bit_severity = _pick_worst_severity(bit_severity_candidates) if bit_severity_candidates else None + + # severity: pick worst among snippet.severity and bit_severity if provided + severity_candidates = [getattr(s, "severity", "") or "" for s in snippets_sorted] + # include bit severity candidates too (string) + severity_candidates.extend([bs for bs in bit_severity_candidates if bs]) + merged_severity = _pick_worst_severity(severity_candidates) + + # build return instance; prefer using the same class as input if possible + SnippetCls = type(base) + try: + merged = SnippetCls( + input=";".join(merged_inputs), idx=None, 
start_line=start_line, end_line=end_line, @@ -58,9 +169,123 @@ def merge_snippets_by_file(snippets: List[BaseSnippet]) -> List[BaseSnippet]: message=merged_message, vulnerability_class=merged_vuln_class, cwe=merged_cwe, - severity=severity, + severity=merged_severity, references=merged_references, path=path - )) + ) + # set optional/extension attributes if available + if hasattr(merged, "bit_trigger"): + setattr(merged, "bit_trigger", merged_bit_trigger) + else: + setattr(merged, "bit_trigger", merged_bit_trigger) + if hasattr(merged, "bit_steps"): + setattr(merged, "bit_steps", merged_bit_steps) + else: + setattr(merged, "bit_steps", merged_bit_steps) + if hasattr(merged, "bit_reproduction"): + setattr(merged, "bit_reproduction", merged_bit_reproduction) + else: + setattr(merged, "bit_reproduction", merged_bit_reproduction) + if hasattr(merged, "bit_severity"): + setattr(merged, "bit_severity", merged_bit_severity) + else: + setattr(merged, "bit_severity", merged_bit_severity) + if hasattr(merged, "constraints"): + setattr(merged, "constraints", merged_constraints) + else: + setattr(merged, "constraints", merged_constraints) + except Exception: + # fallback to the dataclass defined earlier + merged = BaseSnippet( + input=";".join(merged_inputs), + idx=None, + start_line=start_line, + end_line=end_line, + snippet=merged_snippet_text, + message=merged_message, + vulnerability_class=merged_vuln_class, + cwe=merged_cwe, + severity=merged_severity, + references=merged_references, + path=path, + bit_trigger=merged_bit_trigger, + bit_steps=merged_bit_steps, + bit_reproduction=merged_bit_reproduction, + bit_severity=merged_bit_severity, + constraints=merged_constraints + ) + + # attach provenance metadata for debugging + try: + setattr(merged, "_merged_from_count", len(snippets_sorted)) + setattr(merged, "_merged_sources", merged_inputs) + except Exception: + pass + + return merged + +def merge_snippets_by_file(snippets: List[BaseSnippet]) -> List[BaseSnippet]: + """ + Merge a list of BaseSnippet objects across files. Snippets are grouped by path and + overlapping line ranges are merged into a single snippet. + + Returns a list of merged BaseSnippet objects. 
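+    Ranges that overlap or are directly adjacent are grouped, e.g. findings on lines 10-12 and
+    11-15 of the same file merge into one snippet spanning lines 10-15 with the worst severity.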
+ """ + if not snippets: + return [] + + grouped: Dict[str, List[BaseSnippet]] = defaultdict(list) + for s in snippets: + p = getattr(s, "path", "") or "" + grouped[p].append(s) + + merged_results: List[BaseSnippet] = [] + for path, slist in grouped.items(): + # sort by start_line + ssorted = sorted(slist, key=lambda x: (getattr(x, "start_line", 0), getattr(x, "end_line", 0))) + current_group: List[BaseSnippet] = [] + cur_start = None + cur_end = None + for s in ssorted: + s_start = getattr(s, "start_line", 0) + s_end = getattr(s, "end_line", 0) + if not current_group: + current_group = [s] + cur_start, cur_end = s_start, s_end + continue + if are_ranges_overlapping(cur_start, cur_end, s_start, s_end): + # expand current group range + cur_end = max(cur_end, s_end) + cur_start = min(cur_start, s_start) + current_group.append(s) + else: + # flush group + merged_results.append(merge_group(current_group)) + # start new group + current_group = [s] + cur_start, cur_end = s_start, s_end + # flush last group + if current_group: + merged_results.append(merge_group(current_group)) + + # stable sort by path and start_line for deterministic output + merged_results_sorted = sorted(merged_results, key=lambda x: (getattr(x, "path", ""), getattr(x, "start_line", 0))) + return merged_results_sorted - return merged_snippets \ No newline at end of file +# If executed as script, provide simple demo (no external I/O) +if __name__ == "__main__": + import json + # quick smoke test + a = BaseSnippet(input="a", start_line=10, end_line=12, snippet="foo()", message="vuln A", path="core/app.py", severity="HIGH") + b = BaseSnippet(input="b", start_line=11, end_line=15, snippet="bar()", message="vuln B", path="core/app.py", severity="MEDIUM", bit_trigger="user input", bit_steps=["1. call foo"], constraints={"auth":"none"}) + c = BaseSnippet(input="c", start_line=200, end_line=210, snippet="baz()", message="other", path="utils.py", severity="LOW") + merged = merge_snippets_by_file([a,b,c]) + print(json.dumps([{ + "path": m.path, + "start": m.start_line, + "end": m.end_line, + "severity": getattr(m, "severity", ""), + "bit_trigger": getattr(m, "bit_trigger", None), + "bit_steps": getattr(m, "bit_steps", None), + "constraints": getattr(m, "constraints", None), + } for m in merged], indent=2, ensure_ascii=False)) diff --git a/src/autofic_core/sast/semgrep/preprocessor.py b/src/autofic_core/sast/semgrep/preprocessor.py index 42e8aa6..d88a954 100644 --- a/src/autofic_core/sast/semgrep/preprocessor.py +++ b/src/autofic_core/sast/semgrep/preprocessor.py @@ -1,91 +1,126 @@ -# ============================================================================= -# Copyright 2025 AutoFiC Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================= +from __future__ import annotations -import json from pathlib import Path -from typing import List +from typing import Any, Dict, List, Optional, Union +import json + from autofic_core.sast.snippet import BaseSnippet -class SemgrepPreprocessor: +# Semgrep severity → our canonical +_SEMGR_SEV_MAP = { + "ERROR": "HIGH", + "WARNING": "MEDIUM", + "INFO": "LOW", +} + +def _normalize_severity(s: Optional[str]) -> Optional[str]: + if not s: + return None + s = s.strip().upper() + return _SEMGR_SEV_MAP.get(s, s) + +def _safe_get(d: Dict[str, Any], path: List[Union[str, int]], default=None): + cur: Any = d + for key in path: + if isinstance(key, int): + if not isinstance(cur, list) or key >= len(cur): + return default + cur = cur[key] + else: + if not isinstance(cur, dict) or key not in cur: + return default + cur = cur[key] + return cur - @staticmethod - def ensure_list(value): - if value is None: - return [] - if isinstance(value, list): - return value - return [value] +class SemgrepPreprocessor: @staticmethod - def read_json_file(path: str) -> dict: - with open(path, 'r', encoding='utf-8') as f: - return json.load(f) + def save_json_file(data: Dict[str, Any], path: Union[str, Path]) -> None: + p = Path(path) + p.parent.mkdir(parents=True, exist_ok=True) + with p.open("w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) @staticmethod - def save_json_file(data: dict, path: Path) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - with open(path, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2, ensure_ascii=False) + def preprocess(json_path: Union[str, Path], repo_root: Union[str, Path]) -> List[BaseSnippet]: + with Path(json_path).open("r", encoding="utf-8") as f: + data = json.load(f) + return SemgrepPreprocessor._parse(data) @staticmethod - def preprocess(input_json_path: str, base_dir: str = ".") -> List[BaseSnippet]: - results = SemgrepPreprocessor.read_json_file(input_json_path) - base_dir_path = Path(base_dir).resolve() - processed: List[BaseSnippet] = [] - - items = results.get("results") if isinstance(results, dict) else results - - for idx, result in enumerate(items): - raw_path = result.get("path", "").strip().replace("\\", "/") - base_dir_str = str(base_dir_path).replace("\\", "/") - - rel_path = raw_path[len(base_dir_str):].lstrip("/") if raw_path.startswith(base_dir_str) else raw_path - - file_path = (base_dir_path / rel_path).resolve() - if not file_path.exists(): - raise FileNotFoundError(f"[ERROR] File not found: {file_path}") - - full_code = file_path.read_text(encoding='utf-8') - - if "start" in result and "line" in result["start"]: - start_line = result["start"]["line"] - end_line = result["end"]["line"] - else: - start_line = result.get("start_line", 0) - end_line = result.get("end_line", 0) - - lines = full_code.splitlines() - snippet_lines = lines[start_line - 1:end_line] if 0 < start_line <= end_line <= len(lines) else [] - snippet = "\n".join(snippet_lines) - - extra = result.get("extra", {}) - meta = extra.get("metadata", {}) - - processed.append(BaseSnippet( - input=full_code, - idx=idx, - start_line=start_line, - end_line=end_line, - snippet=snippet, - message=extra.get("message", ""), - vulnerability_class=SemgrepPreprocessor.ensure_list(meta.get("vulnerability_class")), - cwe=SemgrepPreprocessor.ensure_list(meta.get("cwe")), - severity=extra.get("severity", ""), - references=SemgrepPreprocessor.ensure_list(meta.get("references")), - 
path=rel_path - )) - - return processed \ No newline at end of file + def _parse(data: Dict[str, Any]) -> List[BaseSnippet]: + results = data.get("results") or [] + out: List[BaseSnippet] = [] + + for idx, r in enumerate(results): + path = r.get("path") or _safe_get(r, ["extra", "path"]) or "" + start_line = _safe_get(r, ["start", "line"], 0) or 0 + end_line = _safe_get(r, ["end", "line"], start_line) or start_line + + message = _safe_get(r, ["extra", "message"]) or "" + severity_raw = _safe_get(r, ["extra", "severity"]) or r.get("severity") + severity = _normalize_severity(severity_raw) + + snippet_text = _safe_get(r, ["extra", "lines"]) or _safe_get( + r, ["extra", "metavars", "metavar", "abstract_content"] + ) + + rule_id = r.get("check_id") or _safe_get(r, ["extra", "engine_kind"]) + vuln_class: List[str] = [] + if rule_id: + parts = str(rule_id).split(".") + vuln_class.append(parts[-1] if len(parts) >= 2 else str(rule_id)) + + # CWE / references + cwe = [] + meta_cwe = _safe_get(r, ["extra", "metadata", "cwe"]) or _safe_get( + r, ["extra", "metadata", "cwe_ids"] + ) or [] + if isinstance(meta_cwe, list): + cwe = [str(x) for x in meta_cwe] + elif isinstance(meta_cwe, str): + cwe = [meta_cwe] + + references: List[str] = [] + meta_refs = _safe_get(r, ["extra", "metadata", "references"]) or [] + if isinstance(meta_refs, list): + references = [str(x) for x in meta_refs] + elif isinstance(meta_refs, str): + references = [meta_refs] + + # BIT heuristic + bit_trigger = message or (vuln_class[0] if vuln_class else None) + steps = [ + f"Open file `{path}`.", + f"Go to lines {start_line}–{end_line}.", + ] + if message: + steps.append(f"Observe: {message}") + if cwe: + steps.append(f"Related CWE: {', '.join(cwe)}") + bit_steps = steps + bit_reproduction = " / ".join(steps) + bit_severity = severity + + out.append( + BaseSnippet( + input=str(rule_id) if rule_id else "semgrep", + idx=idx, + path=path, + start_line=start_line, + end_line=end_line, + snippet=(snippet_text or "").strip() or None, + message=message, + vulnerability_class=vuln_class, + cwe=cwe, + severity=severity, + references=references, + bit_trigger=bit_trigger, + bit_steps=bit_steps, + bit_reproduction=bit_reproduction, + bit_severity=bit_severity, + constraints={}, + ) + ) + + return out diff --git a/src/autofic_core/sast/snippet.py b/src/autofic_core/sast/snippet.py index 89f15b7..a0ba9fd 100644 --- a/src/autofic_core/sast/snippet.py +++ b/src/autofic_core/sast/snippet.py @@ -1,50 +1,228 @@ -# ============================================================================= -# Copyright 2025 Autofic Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================= - -"""Defines a unified BaseSnippet model for all SAST tool outputs.""" - -from pydantic import BaseModel, Field -from typing import List, Optional +from __future__ import annotations + +from typing import List, Dict, Optional, Any +from pydantic import BaseModel, Field, validator +from enum import Enum + + +class SeverityLevel(str, Enum): + INFO = "INFO" + LOW = "LOW" + MEDIUM = "MEDIUM" + HIGH = "HIGH" + CRITICAL = "CRITICAL" + +# canonical severity order for helpers +_SEVERITY_ORDER = { + "INFO": 0, + "LOW": 1, + "MEDIUM": 2, + "HIGH": 3, + "CRITICAL": 4, +} + + +def _normalize_severity(s: Optional[str]) -> Optional[str]: + if s is None: + return None + s_up = str(s).strip().upper() + # try to map common variants + if s_up in _SEVERITY_ORDER: + return s_up + # some scanners use numbers or words - attempt best-effort + if s_up.isdigit(): + n = int(s_up) + if n >= 9: + return "CRITICAL" + if n >= 7: + return "HIGH" + if n >= 4: + return "MEDIUM" + return "LOW" + # fallback heuristics + if "CRIT" in s_up: + return "CRITICAL" + if "HIGH" in s_up or "URGENT" in s_up: + return "HIGH" + if "MED" in s_up: + return "MEDIUM" + if "LOW" in s_up: + return "LOW" + if "INFO" in s_up: + return "INFO" + return s_up # unknown but return normalized-case string class BaseSnippet(BaseModel): """ - Unified structure for all vulnerability snippets from Semgrep, CodeQL, SnykCode, etc. - - Attributes: - input (str): Full source code of the file. - idx (int): Unique index of the snippet within the file. - start_line (int): Start line number of the vulnerable code. - end_line (int): End line number of the vulnerable code. - snippet (str): Vulnerable code snippet. - message (str): Description of the vulnerability. - severity (str): Severity level (e.g., HIGH, MEDIUM, LOW). - path (str): File path relative to the repository root. - vulnerability_class (List[str]): List of vulnerability types (e.g., SQL Injection). - cwe (List[str]): List of CWE identifiers. - references (List[str]): List of external reference links. + Standard snippet model for AutoFiC pipeline. 
+
+    Key fields:
+    - input: identifier of the original input (e.g., semgrep rule id or source filename)
+    - idx: internal index (optional)
+    - path: target file path (relative to the repo root)
+    - start_line, end_line: 1-based line range, inclusive
+    - snippet: code/context containing the vulnerability
+    - message: detection message / rule description
+    - vulnerability_class: vulnerability classification (e.g., 'XSS', 'SQLi')
+    - cwe: list of CWE identifiers (e.g., ['CWE-89'])
+    - severity: overall severity (automatically normalized)
+    - references: external reference URLs, etc.
+    - constraints: key/value metadata for future extension
+    - BIT-related fields: bit_trigger, bit_steps, bit_reproduction, bit_severity
+    - kb_template: external knowledge/template reference (optional)
+    - context_tags: simple tag list (e.g., ['input-sanitization', 'auth'])
     """
-    input: str
-    idx: Optional[int] = None
-    start_line: int
-    end_line: int
-    snippet: Optional[str] = None
-    message: str = ""
-    vulnerability_class: List[str] = Field(default_factory=list)
-    cwe: List[str] = Field(default_factory=list)
-    severity: str = ""
-    references: List[str] = Field(default_factory=list)
-    path: str
+
+    # provenance / identifiers
+    input: Optional[str] = Field(None, description="Identifier of the original input (e.g., tool + rule id)")
+    idx: Optional[int] = Field(None, description="Internal index (optional)")
+
+    # location & code
+    path: str = Field("", description="File path relative to the repo root")
+    start_line: int = Field(0, description="Start line (1-based, inclusive)")
+    end_line: int = Field(0, description="End line (1-based, inclusive)")
+    snippet: Optional[str] = Field(None, description="Code or code context containing the vulnerability")
+
+    # basic vuln metadata
+    message: Optional[str] = Field(None, description="Detection message / rule description")
+    vulnerability_class: List[str] = Field(default_factory=list, description="Vulnerability classification (e.g., XSS, SQLi)")
+    cwe: List[str] = Field(default_factory=list, description="List of CWE identifiers")
+    severity: Optional[str] = Field(None, description="Normalized severity (INFO/LOW/MEDIUM/HIGH/CRITICAL)")
+    references: List[str] = Field(default_factory=list, description="List of reference URLs or documents")
+
+    # BIT (Team-Atlanta style) fields
+    bit_trigger: Optional[str] = Field(None, description="Trigger: summary of the vulnerability trigger/cause")
+    bit_steps: List[str] = Field(default_factory=list, description="Steps: reproduction steps (ordered list)")
+    bit_reproduction: Optional[str] = Field(None, description="Reproduction: free-text reproduction description")
+    bit_severity: Optional[str] = Field(None, description="Separate severity within the BIT block (optional)")
+
+    # extensibility
+    constraints: Dict[str, Any] = Field(default_factory=dict, description="Extension constraints/metadata (key-value)")
+    kb_template: Optional[str] = Field(None, description="External KB template identifier or template content")
+    context_tags: List[str] = Field(default_factory=list, description="Simple list of context tags")
+
+    class Config:
+        # allow population by field name and arbitrary extra attributes (backwards compatibility)
+        allow_population_by_field_name = True
+        extra = "allow"
+        validate_assignment = True
+        arbitrary_types_allowed = True
+
+    @validator("severity", pre=True, always=True)
+    def _validate_severity(cls, v):
+        if v is None:
+            return None
+        return _normalize_severity(v)
+
+    @validator("bit_severity", pre=True, always=True)
+    def _validate_bit_severity(cls, v):
+        if v is None:
+            return None
+        return _normalize_severity(v)
+
+    def worst_severity(self) -> Optional[str]:
+        """
+        Return the worst (highest) severity between `severity` and `bit_severity`.
+        If both are None, returns None.
+ """ + sev_vals = [s for s in (self.severity, self.bit_severity) if s] + if not sev_vals: + return None + # use _SEVERITY_ORDER mapping; unknown strings treated as INFO (0) + worst_val = None + worst_idx = -1 + for s in sev_vals: + idx = _SEVERITY_ORDER.get(s.upper(), 0) + if idx > worst_idx: + worst_idx = idx + worst_val = s + return worst_val + + def to_dict(self, include_none: bool = False) -> Dict[str, Any]: + """ + Serialize snippet to dict. By default omits None fields unless include_none=True. + """ + d = self.dict(by_alias=True, exclude_none=not include_none) + return d + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "BaseSnippet": + """ + Construct BaseSnippet from a plain dict. Useful for deserializing preprocessor outputs. + """ + return cls(**data) + + def merge_with(self, other: "BaseSnippet") -> "BaseSnippet": + """ + Simple helper that merges another snippet into this one. + NOTE: merger.py implements a more robust merge. This is a convenience function for + quick combination: it will: + - extend vulnerability_class, cwe, references, context_tags + - append bit_steps (unique-preserve-order) + - choose worst severity + - expand start/end line and snippet text + """ + # extend lists with uniqueness while preserving order + def _uniq_extend(base: List[Any], add: List[Any]): + seen = set(base) + for a in add: + if a not in seen: + base.append(a) + seen.add(a) + + _uniq_extend(self.vulnerability_class, other.vulnerability_class or []) + _uniq_extend(self.cwe, other.cwe or []) + _uniq_extend(self.references, other.references or []) + _uniq_extend(self.context_tags, other.context_tags or []) + + if other.bit_steps: + _uniq_extend(self.bit_steps, other.bit_steps) + + # merge snippet text + parts = [] + if self.snippet: + parts.append(self.snippet) + if other.snippet and other.snippet not in parts: + parts.append(other.snippet) + self.snippet = "\n".join(parts).strip() + + # choose worst severity + worst = BaseSnippet._choose_worst_severity(self.severity, other.severity, self.bit_severity, other.bit_severity) + self.severity = worst + + # bit fields: prefer existing, append if missing + if not self.bit_trigger and other.bit_trigger: + self.bit_trigger = other.bit_trigger + if not self.bit_reproduction and other.bit_reproduction: + self.bit_reproduction = other.bit_reproduction + + # expand range + self.start_line = min(self.start_line or other.start_line or 0, other.start_line or self.start_line or 0) + self.end_line = max(self.end_line or other.end_line or 0, other.end_line or self.end_line or 0) + + # merge constraints (simple shallow merge; callers may use merger._merge_constraints for collision handling) + if other.constraints: + self.constraints = {**self.constraints, **other.constraints} + + # inputs concat + if self.input and other.input: + self.input = f"{self.input};{other.input}" + elif other.input: + self.input = other.input + + return self + + @staticmethod + def _choose_worst_severity(*sevs: Optional[str]) -> Optional[str]: + worst = None + worst_idx = -1 + for s in filter(None, map(_normalize_severity, [s for s in (sevs or []) if s is not None])): + idx = _SEVERITY_ORDER.get(s.upper(), 0) + if idx > worst_idx: + worst_idx = idx + worst = s + return worst + +# convenience export +__all__ = ["BaseSnippet", "SeverityLevel", "_normalize_severity"] diff --git a/src/autofic_core/sast/snykcode/preprocessor.py b/src/autofic_core/sast/snykcode/preprocessor.py index d10fa2e..32dc82a 100644 --- a/src/autofic_core/sast/snykcode/preprocessor.py +++ 
b/src/autofic_core/sast/snykcode/preprocessor.py @@ -1,128 +1,166 @@ -# ============================================================================= -# Copyright 2025 AutoFiC Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= - -""" -SnykCodePreprocessor extracts and normalizes SARIF-based Snyk Code scan results -into BaseSnippet objects for downstream LLM and patching workflows. -""" +from __future__ import annotations -import json -import os from pathlib import Path -from typing import List, Any +from typing import Any, Dict, List, Optional, Union +import json +import re from autofic_core.sast.snippet import BaseSnippet +def _safe_get(d: Dict[str, Any], path: List[Union[str, int]], default=None): + cur: Any = d + for key in path: + if isinstance(key, int): + if not isinstance(cur, list) or key >= len(cur): + return default + cur = cur[key] + else: + if not isinstance(cur, dict) or key not in cur: + return default + cur = cur[key] + return cur + +def _map_severity_from_properties(props: Dict[str, Any]) -> Optional[str]: + cand = None + for key in ("severity", "problem.severity", "security-severity"): + if key in (props or {}): + cand = props[key] + break + if cand is None: + return None + s = str(cand).strip().upper() + if s.isdigit(): + n = int(s) + if n >= 9: + return "CRITICAL" + if n >= 7: + return "HIGH" + if n >= 4: + return "MEDIUM" + if n > 0: + return "LOW" + return "INFO" + if "CRIT" in s: + return "CRITICAL" + if "HIGH" in s: + return "HIGH" + if "MED" in s: + return "MEDIUM" + if "LOW" in s: + return "LOW" + if "INFO" in s: + return "INFO" + return s + +def _extract_cwe_from_tags(tags: List[str]) -> List[str]: + out: List[str] = [] + for t in tags or []: + m = re.search(r"cwe[-_/ ]?(\d+)", t, flags=re.I) + if m: + out.append(f"CWE-{m.group(1)}") + # unique preserve order + seen = set() + ret = [] + for x in out: + if x not in seen: + seen.add(x) + ret.append(x) + return ret -class SnykCodePreprocessor: - """ - Preprocesses Snyk Code SARIF results into structured BaseSnippet objects. - """ +class SnykCodePreprocessor: @staticmethod - def read_json_file(path: str) -> dict: - """ - Load JSON file from given path. - - Args: - path (str): Path to SARIF result file. - - Returns: - dict: Parsed JSON content. - """ - with open(path, "r", encoding="utf-8") as f: - return json.load(f) + def save_json_file(data: Dict[str, Any], path: Union[str, Path]) -> None: + p = Path(path) + p.parent.mkdir(parents=True, exist_ok=True) + with p.open("w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) @staticmethod - def save_json_file(data: Any, path: str) -> None: - """ - Save data to a JSON file with UTF-8 encoding. - - Args: - data (Any): Serializable Python object. - path (str): Destination file path. 
- """ - os.makedirs(Path(path).parent, exist_ok=True) - with open(path, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2, ensure_ascii=False) + def preprocess(json_path: Union[str, Path], repo_root: Union[str, Path]) -> List[BaseSnippet]: + with Path(json_path).open("r", encoding="utf-8") as f: + data = json.load(f) + return SnykCodePreprocessor._parse(data) @staticmethod - def preprocess(input_json_path: str, base_dir: str = ".") -> List[BaseSnippet]: - """ - Convert Snyk SARIF JSON into BaseSnippet objects. - - Args: - input_json_path (str): Path to Snyk SARIF output file. - base_dir (str): Base path of the source repo. - - Returns: - List[BaseSnippet]: Parsed list of code vulnerability snippets. - """ - sarif = SnykCodePreprocessor.read_json_file(input_json_path) - base_path = Path(base_dir).resolve() - snippets: List[BaseSnippet] = [] - - for run in sarif.get("runs", []): - rules_map = { - rule.get("id"): rule - for rule in run.get("tool", {}).get("driver", {}).get("rules", []) - } - - for idx, result in enumerate(run.get("results", [])): - location = result.get("locations", [{}])[0].get("physicalLocation", {}) - region = location.get("region", {}) - file_uri = location.get("artifactLocation", {}).get("uri", "") - file_path = (base_path / file_uri).resolve() - - if not file_path.exists(): - continue # Skip non-existent files - - try: - lines = file_path.read_text(encoding="utf-8").splitlines() - except Exception as e: - continue # Skip unreadable files - - full_code = "\n".join(lines) - start_line = region.get("startLine", 0) - end_line = region.get("endLine", start_line) - snippet = "\n".join(lines[start_line - 1:end_line]) - - rule_id = result.get("ruleId", "") - rule = rules_map.get(rule_id, {}) - help_uri = rule.get("helpUri", "") - cwe_tags = rule.get("properties", {}).get("tags", []) - cwe = [ - t.split("/")[-1].replace("cwe-", "CWE-") - for t in cwe_tags - if "cwe" in t.lower() + def _parse(data: Dict[str, Any]) -> List[BaseSnippet]: + runs = data.get("runs") or [] + # (optional) build rule index + rule_index: Dict[str, Dict[str, Any]] = {} + for run in runs: + rules = _safe_get(run, ["tool", "driver", "rules"], []) + for rd in rules or []: + rid = rd.get("id") + if rid: + rule_index[rid] = rd + + out: List[BaseSnippet] = [] + for run in runs: + results = run.get("results") or [] + for idx, res in enumerate(results): + rule_id = res.get("ruleId") + msg = _safe_get(res, ["message", "text"]) or _safe_get(res, ["message", "markdown"]) or "" + + loc = _safe_get(res, ["locations", 0, "physicalLocation"], {}) or {} + path = _safe_get(loc, ["artifactLocation", "uri"]) or "" + region = _safe_get(loc, ["region"], {}) or {} + start_line = int(region.get("startLine") or 0) + end_line = int(region.get("endLine") or start_line) + snippet_text = _safe_get(region, ["snippet", "text"]) or "" + + rule_meta = rule_index.get(rule_id or "", {}) + props_rule = rule_meta.get("properties") or {} + tags = props_rule.get("tags") or [] + + props_res = res.get("properties") or {} + severity = _map_severity_from_properties(props_res) or _map_severity_from_properties(props_rule) + + help_uri = rule_meta.get("helpUri") + references: List[str] = [] + if help_uri: + references.append(str(help_uri)) + for t in tags: + if isinstance(t, str) and t.startswith("external/"): + references.append(t) + + cwe = _extract_cwe_from_tags(tags) + vuln_class = [rule_id] if rule_id else [] + + message = msg or _safe_get(rule_meta, ["shortDescription", "text"]) or "" + + # BIT + bit_trigger = message or (vuln_class[0] 
if vuln_class else None)
+                steps = [
+                    f"Open file `{path}`.",
+                    f"Go to lines {start_line}–{end_line}.",
                 ]
-            references = [help_uri] if help_uri else []
-
-            snippets.append(BaseSnippet(
-                input=full_code.strip(),
-                idx=idx,
-                start_line=start_line,
-                end_line=end_line,
-                snippet=snippet.strip(),
-                message=result.get("message", {}).get("text", ""),
-                severity=result.get("level", "").upper(),
-                path=file_uri,
-                vulnerability_class=[rule_id.split("/", 1)[-1]] if rule_id else [],
-                cwe=cwe,
-                references=references
-            ))
-
-        return snippets
\ No newline at end of file
+                if message:
+                    steps.append(f"Observe: {message}")
+                if cwe:
+                    steps.append(f"Related CWE: {', '.join(cwe)}")
+                bit_steps = steps
+                bit_reproduction = " / ".join(steps)
+                bit_severity = severity
+
+                out.append(
+                    BaseSnippet(
+                        input=f"snyk:{rule_id}" if rule_id else "snyk",
+                        idx=idx,
+                        path=path,
+                        start_line=start_line,
+                        end_line=end_line,
+                        snippet=(snippet_text or "").strip() or None,
+                        message=message,
+                        vulnerability_class=vuln_class,
+                        cwe=cwe,
+                        severity=severity,
+                        references=references,
+                        bit_trigger=bit_trigger,
+                        bit_steps=bit_steps,
+                        bit_reproduction=bit_reproduction,
+                        bit_severity=bit_severity,
+                        constraints={},
+                    )
+                )
+
+        return out
diff --git a/src/autofic_core/sast/xml_generator.py b/src/autofic_core/sast/xml_generator.py
new file mode 100644
index 0000000..a631b85
--- /dev/null
+++ b/src/autofic_core/sast/xml_generator.py
@@ -0,0 +1,177 @@
+from __future__ import annotations
+
+import os
+import datetime
+import xml.etree.ElementTree as ET
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Iterable, List, Optional, Union
+
+from autofic_core.sast.snippet import BaseSnippet
+
+XML_NS = "urn:autofic:custom-context"
+XSI_NS = "http://www.w3.org/2001/XMLSchema-instance"
+ET.register_namespace("", XML_NS)
+ET.register_namespace("xsi", XSI_NS)
+
+
+@dataclass
+class RenderOptions:
+    tool_name: str = "AutoFiC"
+    schema_location: str = "schemas/custom_context.xsd"
+    include_env: bool = True
+    include_tracking: bool = True
+    include_mitigations: bool = True
+    context_lines_before: int = 0
+    context_lines_after: int = 0
+
+
+def _severity_pair(sev: Optional[str]) -> tuple[str, str]:
+    v = (sev or "").upper() or "UNKNOWN"
+    return v, v
+
+
+def _as_snippets(items: Iterable[Union[BaseSnippet, dict]]) -> List[BaseSnippet]:
+    out: List[BaseSnippet] = []
+    for s in items:
+        if isinstance(s, BaseSnippet):
+            out.append(s)
+        elif isinstance(s, dict):
+            out.append(BaseSnippet(**s))
+        else:
+            raise TypeError(f"Unsupported snippet type: {type(s)}")
+    return out
+
+
+def generate_custom_context(
+    merged_snippets: Iterable[Union[BaseSnippet, dict]],
+    output_path: Path,
+    schema_path: Optional[Path] = None,
+    options: Optional[RenderOptions] = None,
+) -> Path:
+    """
+    Generate a Team-Atlanta style CUSTOM_CONTEXT.xml from the merged snippets.
+    """
+    opts = options or RenderOptions()
+    snippets = _as_snippets(merged_snippets)
+
+    root = ET.Element(f"{{{XML_NS}}}CUSTOM_CONTEXT")
+    root.set("version", "1.1")
+    root.set(f"{{{XSI_NS}}}schemaLocation", f"{XML_NS} {opts.schema_location}")
+
+    meta = ET.SubElement(root, f"{{{XML_NS}}}META")
+    meta.set("generatedAt", datetime.datetime.now(datetime.timezone.utc).isoformat(timespec="seconds"))
+    meta.set("tool", opts.tool_name)
+    meta.set("count", str(len(snippets)))
+
+    for s in snippets:
+        v = ET.SubElement(root, f"{{{XML_NS}}}VULNERABILITY")
+        v.set("id", f"{s.path}:{s.start_line}-{s.end_line}")
+
+        f_el = ET.SubElement(v, f"{{{XML_NS}}}FILE")
+        f_el.set("path", s.path)
+
+        r_el = 
ET.SubElement(v, f"{{{XML_NS}}}RANGE") + r_el.set("start", str(s.start_line)) + r_el.set("end", str(s.end_line)) + + overall, bit = _severity_pair(s.severity) + sev_el = ET.SubElement(v, f"{{{XML_NS}}}SEVERITY") + sev_el.set("overall", overall) + sev_el.set("bit", bit) + + msg_el = ET.SubElement(v, f"{{{XML_NS}}}MESSAGE") + message_text = s.message or "" + if " | " in message_text: + messages_el = ET.SubElement(msg_el, f"{{{XML_NS}}}MESSAGES") + for piece in [m.strip() for m in message_text.split("|") if m.strip()]: + item = ET.SubElement(messages_el, f"{{{XML_NS}}}ITEM") + item.text = piece + else: + msg_el.text = message_text + + snip_el = ET.SubElement(v, f"{{{XML_NS}}}SNIPPET") + snip_el.text = s.snippet or "" + + bit_el = ET.SubElement(v, f"{{{XML_NS}}}BIT") + trig_el = ET.SubElement(bit_el, f"{{{XML_NS}}}TRIGGER") + trig_el.text = message_text or "Vulnerability detected." + + steps_el = ET.SubElement(bit_el, f"{{{XML_NS}}}STEPS") + step = ET.SubElement(steps_el, f"{{{XML_NS}}}STEP") + if s.start_line == s.end_line: + step.text = f"Review line {s.start_line} in {s.path}" + else: + step.text = f"Review lines {s.start_line}-{s.end_line} in {s.path}" + + repro_el = ET.SubElement(bit_el, f"{{{XML_NS}}}REPRODUCTION") + repro_el.text = "Inspect the indicated code region and verify unsafe data flow or pattern." + + bit_sev = ET.SubElement(bit_el, f"{{{XML_NS}}}BIT_SEVERITY") + bit_sev.text = bit + + if s.vulnerability_class: + classes = ET.SubElement(v, f"{{{XML_NS}}}CLASSES") + for c in sorted(set(s.vulnerability_class)): + ce = ET.SubElement(classes, f"{{{XML_NS}}}CLASS") + ce.text = c + + if s.cwe: + we = ET.SubElement(v, f"{{{XML_NS}}}WEAKNESSES") + for cwe in sorted(set(s.cwe)): + ce = ET.SubElement(we, f"{{{XML_NS}}}CWE") + ce.set("id", cwe) + + if s.references: + refs = ET.SubElement(v, f"{{{XML_NS}}}REFERENCES") + for href in sorted(set(s.references)): + re = ET.SubElement(refs, f"{{{XML_NS}}}REF") + re.set("href", href) + + pre = ET.SubElement(v, f"{{{XML_NS}}}PRECONDITIONS") + it = ET.SubElement(pre, f"{{{XML_NS}}}ITEM") + it.text = "Authenticated user may be required depending on route." + + if opts.include_env: + env = ET.SubElement(v, f"{{{XML_NS}}}ENV") + runtime = ET.SubElement(env, f"{{{XML_NS}}}RUNTIME") + runtime.set("node", os.getenv("NODE_MAJOR", "unknown")) + runtime.set("os", os.name) + + if opts.include_mitigations: + mit = ET.SubElement(v, f"{{{XML_NS}}}MITIGATION") + summary = ET.SubElement(mit, f"{{{XML_NS}}}SUMMARY") + summary.text = "Apply minimal changes: sanitize inputs, use parameterized APIs, and enforce allowlists." + + if opts.include_tracking: + ET.SubElement(v, f"{{{XML_NS}}}TRACKING") + + if opts.context_lines_before or opts.context_lines_after: + ctx = ET.SubElement(v, f"{{{XML_NS}}}CONTEXT") + ctx.set("before", str(opts.context_lines_before)) + ctx.set("after", str(opts.context_lines_after)) + + output_path.parent.mkdir(parents=True, exist_ok=True) + ET.ElementTree(root).write(output_path, encoding="utf-8", xml_declaration=True) + return output_path + + +def render_custom_context( + merged_snippets: Iterable[Union[BaseSnippet, dict]], + output_path: Path, + schema_path: Optional[Path] = None, + options: Optional[RenderOptions] = None, +) -> Path: + return generate_custom_context( + merged_snippets=merged_snippets, + output_path=output_path, + schema_path=schema_path, + options=options, + ) + + +__all__ = [ + "RenderOptions", + "generate_custom_context", + "render_custom_context", +]
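
For reference, a minimal end-to-end sketch of how the modules introduced in this diff are intended to compose: preprocess a Semgrep JSON report into BaseSnippet objects, merge overlapping findings per file, and render the merged set to CUSTOM_CONTEXT.xml. The import path for merge_snippets_by_file is assumed (the merger module's header is not visible in this hunk), and the report/output paths are placeholders.

# Sketch only: the merger module path is assumed; input/output paths are placeholders.
from pathlib import Path

from autofic_core.sast.semgrep.preprocessor import SemgrepPreprocessor
from autofic_core.sast.merger import merge_snippets_by_file  # assumed module path
from autofic_core.sast.xml_generator import RenderOptions, generate_custom_context

# 1) Normalize the raw Semgrep output into BaseSnippet objects.
snippets = SemgrepPreprocessor.preprocess("artifacts/semgrep.json", repo_root=".")

# 2) Merge findings that overlap within the same file.
merged = merge_snippets_by_file(snippets)

# 3) Render the merged findings to a CUSTOM_CONTEXT.xml document.
out = generate_custom_context(
    merged,
    output_path=Path("artifacts/CUSTOM_CONTEXT.xml"),
    options=RenderOptions(tool_name="AutoFiC"),
)
print(f"Wrote {out}")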