From ed82776c3bf9546078c89e6b273158a7b5a8c3e9 Mon Sep 17 00:00:00 2001 From: Mateus Latrova Date: Tue, 7 Nov 2023 20:32:38 -0300 Subject: [PATCH] wip --- pipreqs/pipreqs.py | 69 ++++++++++++++++++++++++------------------- tests/test_pipreqs.py | 12 ++++++-- 2 files changed, 47 insertions(+), 34 deletions(-) diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index 4c94a23..4f3183b 100644 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -35,7 +35,7 @@ | e.g. Flask~=1.1.2 | e.g. Flask>=1.1.2 | e.g. Flask - --ignore-notebooks Ignore jupyter notebook files. + --scan-notebooks Look for imports in jupyter notebook files. """ from contextlib import contextmanager import os @@ -49,13 +49,6 @@ from yarg import json2package from yarg.exceptions import HTTPError -try: - PythonExporter = None - ignore_notebooks = False - from nbconvert import PythonExporter -except ImportError: - pass - from pipreqs import __version__ REGEXP = [ @@ -63,6 +56,11 @@ re.compile(r"^from ((?!\.+).*?) import (?:.*)$"), ] +scan_noteboooks = False + + +class NbconvertNotInstalled(ImportError): + pass @contextmanager def _open(filename=None, mode="r"): @@ -94,7 +92,6 @@ def _open(filename=None, mode="r"): if file not in (sys.stdin, sys.stdout): file.close() - def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links=True): imports = set() raw_imports = set() @@ -117,31 +114,22 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links ignore_dirs_parsed.append(os.path.basename(os.path.realpath(e))) ignore_dirs.extend(ignore_dirs_parsed) + extensions = get_file_extensions() + walk = os.walk(path, followlinks=follow_links) for root, dirs, files in walk: dirs[:] = [d for d in dirs if d not in ignore_dirs] candidates.append(os.path.basename(root)) - if notebooks_are_enabled(): - files = [fn for fn in files if file_ext_is_allowed(fn, [".py", ".ipynb"])] - else: - files = [fn for fn in files if file_ext_is_allowed(fn, [".py"])] + py_files = [file for file in files if file_ext_is_allowed(file, [".py"])] + candidates.extend([os.path.splitext(filename)[0] for filename in py_files]) - candidates = list( - map( - lambda fn: os.path.splitext(fn)[0], - filter(lambda fn: file_ext_is_allowed(fn, [".py"]), files), - ) - ) + files = [fn for fn in files if file_ext_is_allowed(fn, extensions)] for file_name in files: file_name = os.path.join(root, file_name) - contents = "" - if file_ext_is_allowed(file_name, [".py"]): - with open(file_name, "r", encoding=encoding) as f: - contents = f.read() - elif file_ext_is_allowed(file_name, [".ipynb"]) and notebooks_are_enabled(): - contents = ipynb_2_py(file_name, encoding=encoding) + contents = read_file_content(file_name, encoding) + try: tree = ast.parse(contents) for node in ast.walk(tree): @@ -177,10 +165,16 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links return list(packages - data) +def get_file_extensions(): + return [".py", ".ipynb"] if scan_noteboooks else [".py"] -def notebooks_are_enabled(): - return PythonExporter and not ignore_notebooks - +def read_file_content(file_name: str, encoding="utf-8"): + if file_ext_is_allowed(file_name, [".py"]): + with open(file_name, "r", encoding=encoding) as f: + contents = f.read() + elif file_ext_is_allowed(file_name, [".ipynb"]) and scan_noteboooks: + contents = ipynb_2_py(file_name, encoding=encoding) + return contents def file_ext_is_allowed(file_name, acceptable): return os.path.splitext(file_name)[1] in acceptable @@ -197,7 +191,6 @@ def ipynb_2_py(file_name, encoding="utf-8"): str: parsed string """ - exporter = PythonExporter() (body, _) = exporter.from_filename(file_name) @@ -487,13 +480,27 @@ def dynamic_versioning(scheme, imports): symbol = "~=" return imports, symbol +def handle_scan_noteboooks(): + if not scan_noteboooks: + logging.info("Not scanning for jupyter notebooks.") + return + + try: + global PythonExporter + from nbconvert import PythonExporter + except ImportError: + error = NbconvertNotInstalled("In order to scan jupyter notebooks, please install the nbconvert and ipython libraries") + raise error def init(args): - global ignore_notebooks + global scan_noteboooks encoding = args.get("--encoding") extra_ignore_dirs = args.get("--ignore") follow_links = not args.get("--no-follow-links") - ignore_notebooks = args.get("--ignore-notebooks") + + scan_noteboooks = args.get("--scan-notebooks", False) + handle_scan_noteboooks() + input_path = args[""] if encoding is None: diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index 02faf1d..6c273bf 100644 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -9,7 +9,7 @@ """ from io import StringIO -from unittest.mock import patch +from unittest.mock import patch, Mock import unittest import os import requests @@ -519,6 +519,7 @@ def test_import_notebooks(self): """ Test the function get_all_imports() using .ipynb file """ + self.mock_scan_notebooks() imports = pipreqs.get_all_imports(self.project_with_notebooks, encoding="utf-8") self.assertEqual(len(imports), 13) for item in imports: @@ -530,6 +531,7 @@ def test_invalid_notebook(self): """ Test that invalid notebook files cannot be imported. """ + self.mock_scan_notebooks() self.assertRaises(SyntaxError, pipreqs.get_all_imports, self.project_with_invalid_notebooks) def test_ipynb_2_py(self): @@ -595,7 +597,7 @@ def test_parse_requirements_handles_file_not_found(self, exit_mock): def test_ignore_notebooks(self): """ - Test the --ignore-notebooks parameter + Test if notebooks are ignored when the scan-notebooks parameter is False """ pipreqs.init( { @@ -609,12 +611,16 @@ def test_ignore_notebooks(self): "--diff": None, "--clean": None, "--mode": None, - "--ignore-notebooks": True, + "--scan-notebooks": False, } ) assert os.path.exists(self.requirements_notebook_path) == 1 assert os.path.getsize(self.requirements_notebook_path) <= 1 + def mock_scan_notebooks(self): + pipreqs.scan_noteboooks = Mock(return_value=True) + pipreqs.handle_scan_noteboooks() + def tearDown(self): """ Remove requiremnts.txt files that were written