diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4c9ae7e..df5f4a4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -39,7 +39,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install coverage docopt yarg requests + pip install coverage docopt yarg requests nbconvert - name: Calculate coverage run: coverage run --source=pipreqs -m unittest discover diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index a84f39b..74da720 100644 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -47,17 +47,18 @@ import requests from yarg import json2package from yarg.exceptions import HTTPError +from nbconvert import PythonExporter from pipreqs import __version__ REGEXP = [ - re.compile(r'^import (.+)$'), - re.compile(r'^from ((?!\.+).*?) import (?:.*)$') + re.compile(r"^import (.+)$"), + re.compile(r"^from ((?!\.+).*?) import (?:.*)$"), ] @contextmanager -def _open(filename=None, mode='r'): +def _open(filename=None, mode="r"): """Open a file or ``sys.stdout`` depending on the provided filename. Args: @@ -70,13 +71,13 @@ def _open(filename=None, mode='r'): A file handle. """ - if not filename or filename == '-': - if not mode or 'r' in mode: + if not filename or filename == "-": + if not mode or "r" in mode: file = sys.stdin - elif 'w' in mode: + elif "w" in mode: file = sys.stdout else: - raise ValueError('Invalid mode for file: {}'.format(mode)) + raise ValueError("Invalid mode for file: {}".format(mode)) else: file = open(filename, mode) @@ -87,13 +88,21 @@ def _open(filename=None, mode='r'): file.close() -def get_all_imports( - path, encoding=None, extra_ignore_dirs=None, follow_links=True): +def get_all_imports(path, encoding=None, extra_ignore_dirs=None, follow_links=True): imports = set() raw_imports = set() candidates = [] ignore_errors = False - ignore_dirs = [".hg", ".svn", ".git", ".tox", "__pycache__", "env", "venv"] + ignore_dirs = [ + ".hg", + ".svn", + ".git", + ".tox", + "__pycache__", + "env", + "venv", + ".ipynb_checkpoints", + ] if extra_ignore_dirs: ignore_dirs_parsed = [] @@ -106,13 +115,23 @@ def get_all_imports( dirs[:] = [d for d in dirs if d not in ignore_dirs] candidates.append(os.path.basename(root)) - files = [fn for fn in files if os.path.splitext(fn)[1] == ".py"] + files = [fn for fn in files if filter_ext(fn, [".py", ".ipynb"])] + + candidates = list( + map( + lambda fn: os.path.splitext(fn)[0], + filter(lambda fn: filter_ext(fn, [".py"]), files), + ) + ) - candidates += [os.path.splitext(fn)[0] for fn in files] for file_name in files: file_name = os.path.join(root, file_name) - with open(file_name, "r", encoding=encoding) as f: - contents = f.read() + contents = "" + if filter_ext(file_name, [".py"]): + with open(file_name, "r", encoding=encoding) as f: + contents = f.read() + elif filter_ext(file_name, [".ipynb"]): + contents = ipynb_2_py(file_name, encoding=encoding) try: tree = ast.parse(contents) for node in ast.walk(tree): @@ -128,6 +147,8 @@ def get_all_imports( continue else: logging.error("Failed on file: %s" % file_name) + if filter_ext(file_name, [".ipynb"]): + logging.error("Magic command without % might be failed") raise exc # Clean up imports @@ -137,11 +158,11 @@ def get_all_imports( # Cleanup: We only want to first part of the import. # Ex: from django.conf --> django.conf. But we only want django # as an import. - cleaned_name, _, _ = name.partition('.') + cleaned_name, _, _ = name.partition(".") imports.add(cleaned_name) packages = imports - (set(candidates) & imports) - logging.debug('Found packages: {0}'.format(packages)) + logging.debug("Found packages: {0}".format(packages)) with open(join("stdlib"), "r") as f: data = {x.strip() for x in f} @@ -149,58 +170,81 @@ def get_all_imports( return list(packages - data) +def filter_line(line): + return len(line) > 0 and line[0] != "#" + + +def filter_ext(file_name, acceptable): + return os.path.splitext(file_name)[1] in acceptable + + +def ipynb_2_py(file_name, encoding=None): + """ + + Args: + file_name (str): notebook file path to parse as python script + encoding (str): encoding of file + + Returns: + str: parsed string + + """ + + exporter = PythonExporter() + (body, _) = exporter.from_filename(file_name) + + return body.encode(encoding if encoding is not None else "utf-8") + + def generate_requirements_file(path, imports, symbol): with _open(path, "w") as out_file: - logging.debug('Writing {num} requirements: {imports} to {file}'.format( - num=len(imports), - file=path, - imports=", ".join([x['name'] for x in imports]) - )) - fmt = '{name}' + symbol + '{version}' - out_file.write('\n'.join( - fmt.format(**item) if item['version'] else '{name}'.format(**item) - for item in imports) + '\n') + logging.debug( + "Writing {num} requirements: {imports} to {file}".format( + num=len(imports), + file=path, + imports=", ".join([x["name"] for x in imports]), + ) + ) + fmt = "{name}" + symbol + "{version}" + out_file.write( + "\n".join(fmt.format(**item) if item["version"] else "{name}".format(**item) for item in imports) + "\n" + ) def output_requirements(imports, symbol): - generate_requirements_file('-', imports, symbol) + generate_requirements_file("-", imports, symbol) -def get_imports_info( - imports, pypi_server="https://pypi.python.org/pypi/", proxy=None): +def get_imports_info(imports, pypi_server="https://pypi.python.org/pypi/", proxy=None): result = [] for item in imports: try: logging.warning( - 'Import named "%s" not found locally. ' - 'Trying to resolve it at the PyPI server.', - item + 'Import named "%s" not found locally. ' "Trying to resolve it at the PyPI server.", + item, ) - response = requests.get( - "{0}{1}/json".format(pypi_server, item), proxies=proxy) + response = requests.get("{0}{1}/json".format(pypi_server, item), proxies=proxy) if response.status_code == 200: - if hasattr(response.content, 'decode'): + if hasattr(response.content, "decode"): data = json2package(response.content.decode()) else: data = json2package(response.content) elif response.status_code >= 300: - raise HTTPError(status_code=response.status_code, - reason=response.reason) + raise HTTPError(status_code=response.status_code, reason=response.reason) except HTTPError: - logging.warning( - 'Package "%s" does not exist or network problems', item) + logging.warning('Package "%s" does not exist or network problems', item) continue logging.warning( 'Import named "%s" was resolved to "%s:%s" package (%s).\n' - 'Please, verify manually the final list of requirements.txt ' - 'to avoid possible dependency confusions.', + "Please, verify manually the final list of requirements.txt " + "to avoid possible dependency confusions.", item, data.name, data.latest_release_id, - data.pypi_url + data.pypi_url, ) - result.append({'name': item, 'version': data.latest_release_id}) + result.append({"name": item, "version": data.latest_release_id}) return result @@ -225,25 +269,23 @@ def get_locally_installed_packages(encoding=None): filtered_top_level_modules = list() for module in top_level_modules: - if ( - (module not in ignore) and - (package[0] not in ignore) - ): + if (module not in ignore) and (package[0] not in ignore): # append exported top level modules to the list filtered_top_level_modules.append(module) version = None if len(package) > 1: - version = package[1].replace( - ".dist", "").replace(".egg", "") + version = package[1].replace(".dist", "").replace(".egg", "") # append package: top_level_modules pairs # instead of top_level_module: package pairs - packages.append({ - 'name': package[0], - 'version': version, - 'exports': filtered_top_level_modules - }) + packages.append( + { + "name": package[0], + "version": version, + "exports": filtered_top_level_modules, + } + ) return packages @@ -256,14 +298,14 @@ def get_import_local(imports, encoding=None): # if candidate import name matches export name # or candidate import name equals to the package name # append it to the result - if item in package['exports'] or item == package['name']: + if item in package["exports"] or item == package["name"]: result.append(package) # removing duplicates of package/version # had to use second method instead of the previous one, # because we have a list in the 'exports' field # https://stackoverflow.com/questions/9427163/remove-duplicate-dict-in-list-in-python - result_unique = [i for n, i in enumerate(result) if i not in result[n+1:]] + result_unique = [i for n, i in enumerate(result) if i not in result[n + 1 :]] return result_unique @@ -294,7 +336,7 @@ def get_name_without_alias(name): match = REGEXP[0].match(name.strip()) if match: name = match.groups(0)[0] - return name.partition(' as ')[0].partition('.')[0].strip() + return name.partition(" as ")[0].partition(".")[0].strip() def join(f): @@ -353,6 +395,7 @@ def parse_requirements(file_): return modules + def compare_modules(file_, imports): """Compare modules in a file to imported modules in a project. @@ -379,7 +422,8 @@ def diff(file_, imports): logging.info( "The following modules are in {} but do not seem to be imported: " - "{}".format(file_, ", ".join(x for x in modules_not_imported))) + "{}".format(file_, ", ".join(x for x in modules_not_imported)) + ) def clean(file_, imports): @@ -427,30 +471,27 @@ def dynamic_versioning(scheme, imports): def init(args): - encoding = args.get('--encoding') - extra_ignore_dirs = args.get('--ignore') - follow_links = not args.get('--no-follow-links') - input_path = args[''] + encoding = args.get("--encoding") + extra_ignore_dirs = args.get("--ignore") + follow_links = not args.get("--no-follow-links") + input_path = args[""] if input_path is None: input_path = os.path.abspath(os.curdir) if extra_ignore_dirs: - extra_ignore_dirs = extra_ignore_dirs.split(',') - - path = (args["--savepath"] if args["--savepath"] else - os.path.join(input_path, "requirements.txt")) - if (not args["--print"] - and not args["--savepath"] - and not args["--force"] - and os.path.exists(path)): - logging.warning("requirements.txt already exists, " - "use --force to overwrite it") + extra_ignore_dirs = extra_ignore_dirs.split(",") + + path = args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt") + if not args["--print"] and not args["--savepath"] and not args["--force"] and os.path.exists(path): + logging.warning("requirements.txt already exists, " "use --force to overwrite it") return - candidates = get_all_imports(input_path, - encoding=encoding, - extra_ignore_dirs=extra_ignore_dirs, - follow_links=follow_links) + candidates = get_all_imports( + input_path, + encoding=encoding, + extra_ignore_dirs=extra_ignore_dirs, + follow_links=follow_links, + ) candidates = get_pkg_names(candidates) logging.debug("Found imports: " + ", ".join(candidates)) pypi_server = "https://pypi.python.org/pypi/" @@ -459,11 +500,10 @@ def init(args): pypi_server = args["--pypi-server"] if args["--proxy"]: - proxy = {'http': args["--proxy"], 'https': args["--proxy"]} + proxy = {"http": args["--proxy"], "https": args["--proxy"]} if args["--use-local"]: - logging.debug( - "Getting package information ONLY from local installation.") + logging.debug("Getting package information ONLY from local installation.") imports = get_import_local(candidates, encoding=encoding) else: logging.debug("Getting packages information from Local/PyPI") @@ -473,20 +513,21 @@ def init(args): # the list of exported modules, installed locally # and the package name is not in the list of local module names # it add to difference - difference = [x for x in candidates if - # aggregate all export lists into one - # flatten the list - # check if candidate is in exports - x.lower() not in [y for x in local for y in x['exports']] - and - # check if candidate is package names - x.lower() not in [x['name'] for x in local]] - - imports = local + get_imports_info(difference, - proxy=proxy, - pypi_server=pypi_server) + difference = [ + x + for x in candidates + if + # aggregate all export lists into one + # flatten the list + # check if candidate is in exports + x.lower() not in [y for x in local for y in x["exports"]] and + # check if candidate is package names + x.lower() not in [x["name"] for x in local] + ] + + imports = local + get_imports_info(difference, proxy=proxy, pypi_server=pypi_server) # sort imports based on lowercase name of package, similar to `pip freeze`. - imports = sorted(imports, key=lambda x: x['name'].lower()) + imports = sorted(imports, key=lambda x: x["name"].lower()) if args["--diff"]: diff(args["--diff"], imports) @@ -501,8 +542,7 @@ def init(args): if scheme in ["compat", "gt", "no-pin"]: imports, symbol = dynamic_versioning(scheme, imports) else: - raise ValueError("Invalid argument for mode flag, " - "use 'compat', 'gt' or 'no-pin' instead") + raise ValueError("Invalid argument for mode flag, " "use 'compat', 'gt' or 'no-pin' instead") else: symbol = "==" @@ -516,8 +556,8 @@ def init(args): def main(): # pragma: no cover args = docopt(__doc__, version=__version__) - log_level = logging.DEBUG if args['--debug'] else logging.INFO - logging.basicConfig(level=log_level, format='%(levelname)s: %(message)s') + log_level = logging.DEBUG if args["--debug"] else logging.INFO + logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s") try: init(args) @@ -525,5 +565,5 @@ def main(): # pragma: no cover sys.exit(0) -if __name__ == '__main__': +if __name__ == "__main__": main() # pragma: no cover diff --git a/requirements.txt b/requirements.txt index 1dbf6ab..ba7a159 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ wheel==0.38.1 Yarg==0.1.9 docopt==0.6.2 +nbconvert==7.9.2 + diff --git a/setup.py b/setup.py index f21dc6b..1058477 100755 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ history = history_file.read().replace('.. :changelog:', '') requirements = [ - 'docopt', 'yarg' + 'docopt', 'yarg', 'nbconvert', 'ipython' ] setup( diff --git a/tests/_data_notebook/markdown_test.ipynb b/tests/_data_notebook/markdown_test.ipynb new file mode 100644 index 0000000..54712d3 --- /dev/null +++ b/tests/_data_notebook/markdown_test.ipynb @@ -0,0 +1,37 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Markdown test\n", + "import sklearn\n", + "\n", + "```python\n", + "import FastAPI\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/_data_notebook/models.py b/tests/_data_notebook/models.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/_data_notebook/test.ipynb b/tests/_data_notebook/test.ipynb new file mode 100644 index 0000000..16c07d9 --- /dev/null +++ b/tests/_data_notebook/test.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"unused import\"\"\"\n", + "# pylint: disable=undefined-all-variable, import-error, no-absolute-import, too-few-public-methods, missing-docstring\n", + "import xml.etree # [unused-import]\n", + "import xml.sax # [unused-import]\n", + "import os.path as test # [unused-import]\n", + "from sys import argv as test2 # [unused-import]\n", + "from sys import flags # [unused-import]\n", + "# +1:[unused-import,unused-import]\n", + "from collections import deque, OrderedDict, Counter\n", + "# All imports above should be ignored\n", + "import requests # [unused-import]\n", + "\n", + "# setuptools\n", + "import zipimport # command/easy_install.py\n", + "\n", + "# twisted\n", + "from importlib import invalidate_caches # python/test/test_deprecate.py\n", + "\n", + "# astroid\n", + "import zipimport # manager.py\n", + "# IPython\n", + "from importlib.machinery import all_suffixes # core/completerlib.py\n", + "import importlib # html/notebookapp.py\n", + "\n", + "from IPython.utils.importstring import import_item # Many files\n", + "\n", + "# pyflakes\n", + "# test/test_doctests.py\n", + "from pyflakes.test.test_imports import Test as TestImports\n", + "\n", + "# Nose\n", + "from nose.importer import Importer, add_path, remove_path # loader.py\n", + "\n", + "import atexit\n", + "from __future__ import print_function\n", + "from docopt import docopt\n", + "import curses, logging, sqlite3\n", + "import logging\n", + "import os\n", + "import sqlite3\n", + "import time\n", + "import sys\n", + "import signal\n", + "import bs4\n", + "import nonexistendmodule\n", + "import boto as b, peewee as p\n", + "# import django\n", + "import flask.ext.somext # # #\n", + "from sqlalchemy import model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " import ujson as json\n", + "except ImportError:\n", + " import json\n", + "\n", + "import models\n", + "\n", + "\n", + "def main():\n", + " pass\n", + "\n", + "import after_method_is_valid_even_if_not_pep8" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/_invalid_data_notebook/invalid.ipynb b/tests/_invalid_data_notebook/invalid.ipynb new file mode 100644 index 0000000..cacff3f --- /dev/null +++ b/tests/_invalid_data_notebook/invalid.ipynb @@ -0,0 +1,34 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cd ." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index f239f07..a2cbbbe 100644 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -37,7 +37,7 @@ def setUp(self): "after_method_is_valid_even_if_not_pep8", ] self.modules2 = ["beautifulsoup4"] - self.local = ["docopt", "requests", "nose", "pyflakes"] + self.local = ["docopt", "requests", "nose", "pyflakes", "ipython"] self.project = os.path.join(os.path.dirname(__file__), "_data") self.empty_filepath = os.path.join(self.project, "empty.txt") self.imports_filepath = os.path.join(self.project, "imports.txt") @@ -66,19 +66,20 @@ def setUp(self): self.project_clean = os.path.join(os.path.dirname(__file__), "_data_clean") self.project_invalid = os.path.join(os.path.dirname(__file__), "_invalid_data") - self.parsed_packages = [ - {"name": "pandas", "version": "2.0.0"}, - {"name": "numpy", "version": "1.2.3"}, - {"name": "torch", "version": "4.0.0"}, - ] - self.empty_filepath = os.path.join(self.project, "empty.txt") - self.imports_filepath = os.path.join(self.project, "imports.txt") - - self.project_with_ignore_directory = os.path.join(os.path.dirname(__file__), "_data_ignore") - self.project_with_duplicated_deps = os.path.join(os.path.dirname(__file__), "_data_duplicated_deps") - + self.project_with_ignore_directory = os.path.join( + os.path.dirname(__file__), "_data_ignore" + ) + self.project_with_duplicated_deps = os.path.join( + os.path.dirname(__file__), "_data_duplicated_deps" + ) self.requirements_path = os.path.join(self.project, "requirements.txt") self.alt_requirement_path = os.path.join(self.project, "requirements2.txt") + self.project_with_notebooks = os.path.join(os.path.dirname(__file__), "_data_notebook") + self.project_with_invalid_notebooks = os.path.join(os.path.dirname(__file__), "_invalid_data_notebook") + self.compatible_files = { + "original": os.path.join(os.path.dirname(__file__), "_data/test.py"), + "notebook": os.path.join(os.path.dirname(__file__), "_data_notebook/test.ipynb"), + } def test_get_all_imports(self): imports = pipreqs.get_all_imports(self.project) @@ -471,7 +472,7 @@ def test_compare_modules(self): modules_not_imported = pipreqs.compare_modules(filename, imports) self.assertSetEqual(modules_not_imported, expected_modules_not_imported) - + def test_output_requirements(self): """ Test --print parameter @@ -515,6 +516,48 @@ def test_output_requirements(self): stdout_content = capturedOutput.getvalue().lower() self.assertTrue(file_content == stdout_content) + def test_import_notebooks(self): + """ + Test the function get_all_imports() using .ipynb file + """ + imports = pipreqs.get_all_imports(self.project_with_notebooks, encoding="utf-8") + self.assertEqual(len(imports), 13) + for item in imports: + self.assertTrue(item.lower() in self.modules, "Import is missing: " + item) + self.assertFalse("time" in imports) + self.assertFalse("logging" in imports) + self.assertFalse("curses" in imports) + self.assertFalse("__future__" in imports) + self.assertFalse("django" in imports) + self.assertFalse("models" in imports) + self.assertFalse("FastAPI" in imports) + self.assertFalse("sklearn" in imports) + + def test_invalid_notebook(self): + """ + Test that invalid notebook files cannot be imported. + """ + self.assertRaises(SyntaxError, pipreqs.get_all_imports, self.project_with_invalid_notebooks) + + def test_ipynb_2_py(self): + """ + Test the function ipynb_2_py() which converts .ipynb file to .py format + """ + expected = pipreqs.get_all_imports(self.compatible_files["original"]) + parsed = pipreqs.get_all_imports(self.compatible_files["notebook"]) + self.assertEqual(expected, parsed) + + parsed = pipreqs.get_all_imports(self.compatible_files["notebook"], encoding="utf-8") + self.assertEqual(expected, parsed) + + def test_filter_ext(self): + """ + Test the function filter_ext() + """ + self.assertTrue(pipreqs.filter_ext("main.py", [".py"])) + self.assertTrue(pipreqs.filter_ext("main.py", [".py", ".ipynb"])) + self.assertFalse(pipreqs.filter_ext("main.py", [".ipynb"])) + def test_parse_requirements(self): """ Test parse_requirements function