From 472be9e4e782b89b7f216c91fa9941c2b441eeeb Mon Sep 17 00:00:00 2001 From: Filip Christiansen <22807962+filipchristiansen@users.noreply.github.com> Date: Wed, 8 Jan 2025 23:46:24 +0100 Subject: [PATCH] feat: add support for improved handling of jupyter notebooks (#105) --- src/gitingest/__init__.py | 4 +- src/gitingest/notebook_utils.py | 66 +++++++++ src/gitingest/query_ingestion.py | 8 +- src/gitingest/repository_ingest.py | 4 +- src/query_processor.py | 4 +- tests/conftest.py | 16 +++ tests/test_notebook_utils.py | 206 +++++++++++++++++++++++++++++ tests/test_query_ingestion.py | 22 ++- 8 files changed, 321 insertions(+), 9 deletions(-) create mode 100644 src/gitingest/notebook_utils.py create mode 100644 tests/test_notebook_utils.py diff --git a/src/gitingest/__init__.py b/src/gitingest/__init__.py index afccf41..c592350 100644 --- a/src/gitingest/__init__.py +++ b/src/gitingest/__init__.py @@ -1,8 +1,8 @@ """ Gitingest: A package for ingesting data from git repositories. """ -from gitingest.query_ingestion import ingest_from_query +from gitingest.query_ingestion import run_ingest_query from gitingest.query_parser import parse_query from gitingest.repository_clone import clone_repo from gitingest.repository_ingest import ingest -__all__ = ["ingest_from_query", "clone_repo", "parse_query", "ingest"] +__all__ = ["run_ingest_query", "clone_repo", "parse_query", "ingest"] diff --git a/src/gitingest/notebook_utils.py b/src/gitingest/notebook_utils.py new file mode 100644 index 0000000..c559034 --- /dev/null +++ b/src/gitingest/notebook_utils.py @@ -0,0 +1,66 @@ +""" Utilities for processing Jupyter notebooks. """ + +import json +import warnings +from pathlib import Path +from typing import Any + + +def process_notebook(file: Path) -> str: + """ + Process a Jupyter notebook file and return an executable Python script as a string. + + Parameters + ---------- + file : Path + The path to the Jupyter notebook file. + + Returns + ------- + str + The executable Python script as a string. + + Raises + ------ + ValueError + If an unexpected cell type is encountered. + """ + with file.open(encoding="utf-8") as f: + notebook: dict[str, Any] = json.load(f) + + # Check if the notebook contains worksheets + if worksheets := notebook.get("worksheets"): + # https://github.com/ipython/ipython/wiki/IPEP-17:-Notebook-Format-4#remove-multiple-worksheets + # "The `worksheets` field is a list, but we have no UI to support multiple worksheets. + # Our design has since shifted to heading-cell based structure, so we never intend to + # support the multiple worksheet model. The worksheets list of lists shall be replaced + # with a single list, called `cells`." + warnings.warn("Worksheets are deprecated as of IPEP-17.", DeprecationWarning) + + if len(worksheets) > 1: + warnings.warn( + "Multiple worksheets are not supported. Only the first worksheet will be processed.", UserWarning + ) + + notebook = worksheets[0] + + result = [] + + for cell in notebook["cells"]: + cell_type = cell.get("cell_type") + + # Validate cell type and handle unexpected types + if cell_type not in ("markdown", "code", "raw"): + raise ValueError(f"Unknown cell type: {cell_type}") + + str_ = "".join(cell.get("source", [])) + if not str_: + continue + + # Convert Markdown and raw cells to multi-line comments + if cell_type in ("markdown", "raw"): + str_ = f'"""\n{str_}\n"""' + + result.append(str_) + + return "\n\n".join(result) diff --git a/src/gitingest/query_ingestion.py b/src/gitingest/query_ingestion.py index ff4e483..c58ea81 100644 --- a/src/gitingest/query_ingestion.py +++ b/src/gitingest/query_ingestion.py @@ -7,6 +7,7 @@ import tiktoken from gitingest.exceptions import AlreadyVisitedError, MaxFileSizeReachedError, MaxFilesReachedError +from gitingest.notebook_utils import process_notebook MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB MAX_DIRECTORY_DEPTH = 20 # Maximum depth of directory traversal @@ -158,7 +159,10 @@ def _read_file_content(file_path: Path) -> str: The content of the file, or an error message if the file could not be read. """ try: - with file_path.open(encoding="utf-8", errors="ignore") as f: + if file_path.suffix == ".ipynb": + return process_notebook(file_path) + + with open(file_path, encoding="utf-8", errors="ignore") as f: return f.read() except OSError as e: return f"Error reading file: {e}" @@ -819,7 +823,7 @@ def _ingest_directory(path: Path, query: dict[str, Any]) -> tuple[str, str, str] return summary, tree, files_content -def ingest_from_query(query: dict[str, Any]) -> tuple[str, str, str]: +def run_ingest_query(query: dict[str, Any]) -> tuple[str, str, str]: """ Main entry point for analyzing a codebase directory or single file. diff --git a/src/gitingest/repository_ingest.py b/src/gitingest/repository_ingest.py index 11f58eb..e2cecaa 100644 --- a/src/gitingest/repository_ingest.py +++ b/src/gitingest/repository_ingest.py @@ -5,7 +5,7 @@ import shutil from config import TMP_BASE_PATH -from gitingest.query_ingestion import ingest_from_query +from gitingest.query_ingestion import run_ingest_query from gitingest.query_parser import parse_query from gitingest.repository_clone import CloneConfig, clone_repo @@ -75,7 +75,7 @@ def ingest( else: raise TypeError("clone_repo did not return a coroutine as expected.") - summary, tree, content = ingest_from_query(query) + summary, tree, content = run_ingest_query(query) if output is not None: with open(output, "w", encoding="utf-8") as f: diff --git a/src/query_processor.py b/src/query_processor.py index 2e12909..f6c7df8 100644 --- a/src/query_processor.py +++ b/src/query_processor.py @@ -7,7 +7,7 @@ from starlette.templating import _TemplateResponse from config import EXAMPLE_REPOS, MAX_DISPLAY_SIZE -from gitingest.query_ingestion import ingest_from_query +from gitingest.query_ingestion import run_ingest_query from gitingest.query_parser import parse_query from gitingest.repository_clone import CloneConfig, clone_repo from server_utils import Colors, log_slider_to_size @@ -91,7 +91,7 @@ async def process_query( branch=query.get("branch"), ) await clone_repo(clone_config) - summary, tree, content = ingest_from_query(query) + summary, tree, content = run_ingest_query(query) with open(f"{clone_config.local_path}.txt", "w", encoding="utf-8") as f: f.write(tree + "\n" + content) except Exception as e: diff --git a/tests/conftest.py b/tests/conftest.py index c05ebcc..87b8a4e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ """ This module contains fixtures for the tests. """ +import json from pathlib import Path from typing import Any @@ -72,3 +73,18 @@ def temp_directory(tmp_path: Path) -> Path: (dir2 / "file_dir2.txt").write_text("Hello from dir2") return test_dir + + +@pytest.fixture +def write_notebook(tmp_path: Path): + """ + A fixture that returns a helper function to write a .ipynb notebook file at runtime with given content. + """ + + def _write_notebook(name: str, content: dict[str, Any]) -> Path: + notebook_path = tmp_path / name + with notebook_path.open(mode="w", encoding="utf-8") as f: + json.dump(content, f) + return notebook_path + + return _write_notebook diff --git a/tests/test_notebook_utils.py b/tests/test_notebook_utils.py new file mode 100644 index 0000000..a0da1b1 --- /dev/null +++ b/tests/test_notebook_utils.py @@ -0,0 +1,206 @@ +""" Tests for the notebook_utils module. """ + +import pytest + +from gitingest.notebook_utils import process_notebook + + +def test_process_notebook_all_cells(write_notebook): + """ + Test a notebook containing markdown, code, and raw cells. + + - Markdown/raw cells => triple-quoted + - Code cells => remain normal code + - For 1 markdown + 1 raw => 2 triple-quoted blocks => 4 occurrences of triple-quotes. + """ + notebook_content = { + "cells": [ + {"cell_type": "markdown", "source": ["# Markdown cell"]}, + {"cell_type": "code", "source": ['print("Hello Code")']}, + {"cell_type": "raw", "source": [""]}, + ] + } + nb_path = write_notebook("all_cells.ipynb", notebook_content) + result = process_notebook(nb_path) + + assert result.count('"""') == 4, "Expected 4 triple-quote occurrences for 2 blocks." + + # Check that markdown and raw content are inside triple-quoted blocks + assert "# Markdown cell" in result + assert "" in result + + # Check code cell is present and not wrapped in triple quotes + assert 'print("Hello Code")' in result + assert '"""\nprint("Hello Code")\n"""' not in result + + +def test_process_notebook_with_worksheets(write_notebook): + """ + Test a notebook containing the 'worksheets' key (deprecated as of IPEP-17). + + - Should raise a DeprecationWarning. + - We process only the first (and only) worksheet's cells. + - The resulting content matches an equivalent notebook with "cells" at top level. + """ + with_worksheets = { + "worksheets": [ + { + "cells": [ + {"cell_type": "markdown", "source": ["# Markdown cell"]}, + {"cell_type": "code", "source": ['print("Hello Code")']}, + {"cell_type": "raw", "source": [""]}, + ] + } + ] + } + without_worksheets = with_worksheets["worksheets"][0] # same, but no 'worksheets' key at top + + nb_with = write_notebook("with_worksheets.ipynb", with_worksheets) + nb_without = write_notebook("without_worksheets.ipynb", without_worksheets) + + with pytest.warns(DeprecationWarning, match="Worksheets are deprecated as of IPEP-17."): + result_with = process_notebook(nb_with) + + # No warnings here + result_without = process_notebook(nb_without) + + assert result_with == result_without, "Both notebooks should produce identical content." + + +def test_process_notebook_multiple_worksheets(write_notebook): + """ + Test a notebook containing multiple 'worksheets'. + + If multiple worksheets are present: + - Only process the first sheet's cells. + - DeprecationWarning for worksheets + - UserWarning for ignoring extra worksheets + """ + multi_worksheets = { + "worksheets": [ + {"cells": [{"cell_type": "markdown", "source": ["# First Worksheet"]}]}, + {"cells": [{"cell_type": "code", "source": ['print("Ignored Worksheet")']}]}, + ] + } + + # Single-worksheet version (only the first) + single_worksheet = { + "worksheets": [ + {"cells": [{"cell_type": "markdown", "source": ["# First Worksheet"]}]}, + ] + } + + nb_multi = write_notebook("multiple_worksheets.ipynb", multi_worksheets) + nb_single = write_notebook("single_worksheet.ipynb", single_worksheet) + + with pytest.warns(DeprecationWarning, match="Worksheets are deprecated as of IPEP-17."): + with pytest.warns(UserWarning, match="Multiple worksheets are not supported."): + result_multi = process_notebook(nb_multi) + + with pytest.warns(DeprecationWarning, match="Worksheets are deprecated as of IPEP-17."): + result_single = process_notebook(nb_single) + + # The second worksheet (with code) should have been ignored + assert result_multi == result_single, "Second worksheet was ignored, results match." + + +def test_process_notebook_code_only(write_notebook): + """ + Test a notebook containing only code cells. + + No triple quotes should appear. + """ + notebook_content = { + "cells": [ + {"cell_type": "code", "source": ["print('Code Cell 1')"]}, + {"cell_type": "code", "source": ["x = 42"]}, + ] + } + nb_path = write_notebook("code_only.ipynb", notebook_content) + result = process_notebook(nb_path) + + # No triple quotes + assert '"""' not in result + assert "print('Code Cell 1')" in result + assert "x = 42" in result + + +def test_process_notebook_markdown_only(write_notebook): + """ + Test a notebook with 2 markdown cells. + + 2 markdown cells => each becomes 1 triple-quoted block => 2 blocks => 4 triple quotes. + """ + notebook_content = { + "cells": [ + {"cell_type": "markdown", "source": ["# Markdown Header"]}, + {"cell_type": "markdown", "source": ["Some more markdown."]}, + ] + } + nb_path = write_notebook("markdown_only.ipynb", notebook_content) + result = process_notebook(nb_path) + + assert result.count('"""') == 4, "Two markdown cells => two triple-quoted blocks => 4 triple quotes total." + assert "# Markdown Header" in result + assert "Some more markdown." in result + + +def test_process_notebook_raw_only(write_notebook): + """ + Test a notebook with 2 raw cells. + + 2 raw cells => 2 blocks => 4 triple quotes. + """ + notebook_content = { + "cells": [ + {"cell_type": "raw", "source": ["Raw content line 1"]}, + {"cell_type": "raw", "source": ["Raw content line 2"]}, + ] + } + nb_path = write_notebook("raw_only.ipynb", notebook_content) + result = process_notebook(nb_path) + + # 2 raw cells => 2 triple-quoted blocks => 4 occurrences + assert result.count('"""') == 4 + assert "Raw content line 1" in result + assert "Raw content line 2" in result + + +def test_process_notebook_empty_cells(write_notebook): + """ + Test that cells with an empty 'source' are skipped entirely. + + 4 cells but 3 are empty => only 1 non-empty cell => 1 triple-quoted block => 2 quotes. + """ + notebook_content = { + "cells": [ + {"cell_type": "markdown", "source": []}, + {"cell_type": "code", "source": []}, + {"cell_type": "raw", "source": []}, + {"cell_type": "markdown", "source": ["# Non-empty markdown"]}, + ] + } + nb_path = write_notebook("empty_cells.ipynb", notebook_content) + result = process_notebook(nb_path) + + # Only one non-empty markdown cell => 1 block => 2 triple quotes + assert result.count('"""') == 2 + assert "# Non-empty markdown" in result + + +def test_process_notebook_invalid_cell_type(write_notebook): + """ + Test a notebook with an unknown cell type. + + Should raise a ValueError. + """ + notebook_content = { + "cells": [ + {"cell_type": "markdown", "source": ["# Valid markdown"]}, + {"cell_type": "unknown", "source": ["Unrecognized cell type"]}, + ] + } + nb_path = write_notebook("invalid_cell_type.ipynb", notebook_content) + + with pytest.raises(ValueError, match="Unknown cell type: unknown"): + process_notebook(nb_path) diff --git a/tests/test_query_ingestion.py b/tests/test_query_ingestion.py index 886dafc..48edbc2 100644 --- a/tests/test_query_ingestion.py +++ b/tests/test_query_ingestion.py @@ -2,8 +2,9 @@ from pathlib import Path from typing import Any +from unittest.mock import patch -from gitingest.query_ingestion import _extract_files_content, _scan_directory +from gitingest.query_ingestion import _extract_files_content, _read_file_content, _scan_directory def test_scan_directory(temp_directory: Path, sample_query: dict[str, Any]) -> None: @@ -37,6 +38,25 @@ def test_extract_files_content(temp_directory: Path, sample_query: dict[str, Any assert any("file_dir2.txt" in p for p in paths) +def test_read_file_content_with_notebook(tmp_path: Path): + notebook_path = tmp_path / "dummy_notebook.ipynb" + notebook_path.write_text("{}", encoding="utf-8") # minimal JSON + + # Patch the symbol as it is used in ingest_from_query + with patch("gitingest.ingest_from_query.process_notebook") as mock_process: + _read_file_content(notebook_path) + mock_process.assert_called_once_with(notebook_path) + + +def test_read_file_content_with_non_notebook(tmp_path: Path): + py_file_path = tmp_path / "dummy_file.py" + py_file_path.write_text("print('Hello')", encoding="utf-8") + + with patch("gitingest.ingest_from_query.process_notebook") as mock_process: + _read_file_content(py_file_path) + mock_process.assert_not_called() + + # TODO: test with include patterns: ['*.txt'] # TODO: test with wrong include patterns: ['*.qwerty']