Skip to content

Commit

Permalink
handle FileNotFoundError in parse_requirements function
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuslatrova committed Oct 9, 2023
1 parent 3c786e3 commit 4bb7f3e
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 125 deletions.
176 changes: 78 additions & 98 deletions pipreqs/pipreqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,11 @@

from pipreqs import __version__

REGEXP = [
re.compile(r'^import (.+)$'),
re.compile(r'^from ((?!\.+).*?) import (?:.*)$')
]
REGEXP = [re.compile(r"^import (.+)$"), re.compile(r"^from ((?!\.+).*?) import (?:.*)$")]


@contextmanager
def _open(filename=None, mode='r'):
def _open(filename=None, mode="r"):
"""Open a file or ``sys.stdout`` depending on the provided filename.
Args:
Expand All @@ -70,13 +67,13 @@ def _open(filename=None, mode='r'):
A file handle.
"""
if not filename or filename == '-':
if not mode or 'r' in mode:
if not filename or filename == "-":
if not mode or "r" in mode:
file = sys.stdin
elif 'w' in mode:
elif "w" in mode:
file = sys.stdout
else:
raise ValueError('Invalid mode for file: {}'.format(mode))
raise ValueError("Invalid mode for file: {}".format(mode))
else:
file = open(filename, mode)

Expand All @@ -87,8 +84,7 @@ def _open(filename=None, mode='r'):
file.close()


def get_all_imports(
path, encoding=None, extra_ignore_dirs=None, follow_links=True):
def get_all_imports(path, encoding=None, extra_ignore_dirs=None, follow_links=True):
imports = set()
raw_imports = set()
candidates = []
Expand Down Expand Up @@ -137,11 +133,11 @@ def get_all_imports(
# Cleanup: We only want the first part of the import.
# Ex: from django.conf --> django.conf. But we only want django
# as an import.
cleaned_name, _, _ = name.partition('.')
cleaned_name, _, _ = name.partition(".")
imports.add(cleaned_name)

packages = imports - (set(candidates) & imports)
logging.debug('Found packages: {0}'.format(packages))
logging.debug("Found packages: {0}".format(packages))

with open(join("stdlib"), "r") as f:
data = {x.strip() for x in f}
Expand All @@ -155,56 +151,48 @@ def filter_line(line):

def generate_requirements_file(path, imports, symbol):
with _open(path, "w") as out_file:
logging.debug('Writing {num} requirements: {imports} to {file}'.format(
num=len(imports),
file=path,
imports=", ".join([x['name'] for x in imports])
))
fmt = '{name}' + symbol + '{version}'
out_file.write('\n'.join(
fmt.format(**item) if item['version'] else '{name}'.format(**item)
for item in imports) + '\n')
logging.debug(
"Writing {num} requirements: {imports} to {file}".format(
num=len(imports), file=path, imports=", ".join([x["name"] for x in imports])
)
)
fmt = "{name}" + symbol + "{version}"
out_file.write(
"\n".join(fmt.format(**item) if item["version"] else "{name}".format(**item) for item in imports) + "\n"
)


def output_requirements(imports, symbol):
generate_requirements_file('-', imports, symbol)
generate_requirements_file("-", imports, symbol)


def get_imports_info(
imports, pypi_server="https://pypi.python.org/pypi/", proxy=None):
def get_imports_info(imports, pypi_server="https://pypi.python.org/pypi/", proxy=None):
result = []

for item in imports:
try:
logging.warning(
'Import named "%s" not found locally. '
'Trying to resolve it at the PyPI server.',
item
)
response = requests.get(
"{0}{1}/json".format(pypi_server, item), proxies=proxy)
logging.warning('Import named "%s" not found locally. ' "Trying to resolve it at the PyPI server.", item)
response = requests.get("{0}{1}/json".format(pypi_server, item), proxies=proxy)
if response.status_code == 200:
if hasattr(response.content, 'decode'):
if hasattr(response.content, "decode"):
data = json2package(response.content.decode())
else:
data = json2package(response.content)
elif response.status_code >= 300:
raise HTTPError(status_code=response.status_code,
reason=response.reason)
raise HTTPError(status_code=response.status_code, reason=response.reason)
except HTTPError:
logging.warning(
'Package "%s" does not exist or network problems', item)
logging.warning('Package "%s" does not exist or network problems', item)
continue
logging.warning(
'Import named "%s" was resolved to "%s:%s" package (%s).\n'
'Please, verify manually the final list of requirements.txt '
'to avoid possible dependency confusions.',
"Please, verify manually the final list of requirements.txt "
"to avoid possible dependency confusions.",
item,
data.name,
data.latest_release_id,
data.pypi_url
data.pypi_url,
)
result.append({'name': item, 'version': data.latest_release_id})
result.append({"name": item, "version": data.latest_release_id})
return result


Expand All @@ -229,25 +217,17 @@ def get_locally_installed_packages(encoding=None):
filtered_top_level_modules = list()

for module in top_level_modules:
if (
(module not in ignore) and
(package[0] not in ignore)
):
if (module not in ignore) and (package[0] not in ignore):
# append exported top level modules to the list
filtered_top_level_modules.append(module)

version = None
if len(package) > 1:
version = package[1].replace(
".dist", "").replace(".egg", "")
version = package[1].replace(".dist", "").replace(".egg", "")

# append package: top_level_modules pairs
# instead of top_level_module: package pairs
packages.append({
'name': package[0],
'version': version,
'exports': filtered_top_level_modules
})
packages.append({"name": package[0], "version": version, "exports": filtered_top_level_modules})
return packages


Expand All @@ -260,14 +240,14 @@ def get_import_local(imports, encoding=None):
# if candidate import name matches export name
# or candidate import name equals to the package name
# append it to the result
if item in package['exports'] or item == package['name']:
if item in package["exports"] or item == package["name"]:
result.append(package)

# removing duplicates of package/version
# had to use second method instead of the previous one,
# because we have a list in the 'exports' field
# https://stackoverflow.com/questions/9427163/remove-duplicate-dict-in-list-in-python
result_unique = [i for n, i in enumerate(result) if i not in result[n+1:]]
result_unique = [i for n, i in enumerate(result) if i not in result[n + 1:]]

return result_unique

Expand Down Expand Up @@ -298,7 +278,7 @@ def get_name_without_alias(name):
match = REGEXP[0].match(name.strip())
if match:
name = match.groups(0)[0]
return name.partition(' as ')[0].partition('.')[0].strip()
return name.partition(" as ")[0].partition(".")[0].strip()


def join(f):
Expand All @@ -312,6 +292,9 @@ def parse_requirements(file_):
delimiter, get module name by element index, create a dict consisting of
module:version, and add dict to list of parsed modules.
If the file `file_` is not found in the system, the program will print a
helpful message and end its execution immediately.
Args:
file_: File to parse.
Expand All @@ -328,9 +311,12 @@ def parse_requirements(file_):

try:
f = open(file_, "r")
except OSError:
logging.error("Failed on file: {}".format(file_))
raise
except FileNotFoundError:
print(f"File {file_} was not found. Please, fix it and run again.")
sys.exit(1)
except OSError as error:
logging.error(f"There was an error opening the file {file_}: {str(error)}")
raise error
else:
try:
data = [x.strip() for x in f.readlines() if x != "\n"]
Expand Down Expand Up @@ -384,7 +370,8 @@ def diff(file_, imports):

logging.info(
"The following modules are in {} but do not seem to be imported: "
"{}".format(file_, ", ".join(x for x in modules_not_imported)))
"{}".format(file_, ", ".join(x for x in modules_not_imported))
)


def clean(file_, imports):
Expand Down Expand Up @@ -432,30 +419,24 @@ def dynamic_versioning(scheme, imports):


def init(args):
encoding = args.get('--encoding')
extra_ignore_dirs = args.get('--ignore')
follow_links = not args.get('--no-follow-links')
input_path = args['<path>']
encoding = args.get("--encoding")
extra_ignore_dirs = args.get("--ignore")
follow_links = not args.get("--no-follow-links")
input_path = args["<path>"]
if input_path is None:
input_path = os.path.abspath(os.curdir)

if extra_ignore_dirs:
extra_ignore_dirs = extra_ignore_dirs.split(',')

path = (args["--savepath"] if args["--savepath"] else
os.path.join(input_path, "requirements.txt"))
if (not args["--print"]
and not args["--savepath"]
and not args["--force"]
and os.path.exists(path)):
logging.warning("requirements.txt already exists, "
"use --force to overwrite it")
extra_ignore_dirs = extra_ignore_dirs.split(",")

path = args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt")
if not args["--print"] and not args["--savepath"] and not args["--force"] and os.path.exists(path):
logging.warning("requirements.txt already exists, " "use --force to overwrite it")
return

candidates = get_all_imports(input_path,
encoding=encoding,
extra_ignore_dirs=extra_ignore_dirs,
follow_links=follow_links)
candidates = get_all_imports(
input_path, encoding=encoding, extra_ignore_dirs=extra_ignore_dirs, follow_links=follow_links
)
candidates = get_pkg_names(candidates)
logging.debug("Found imports: " + ", ".join(candidates))
pypi_server = "https://pypi.python.org/pypi/"
Expand All @@ -464,11 +445,10 @@ def init(args):
pypi_server = args["--pypi-server"]

if args["--proxy"]:
proxy = {'http': args["--proxy"], 'https': args["--proxy"]}
proxy = {"http": args["--proxy"], "https": args["--proxy"]}

if args["--use-local"]:
logging.debug(
"Getting package information ONLY from local installation.")
logging.debug("Getting package information ONLY from local installation.")
imports = get_import_local(candidates, encoding=encoding)
else:
logging.debug("Getting packages information from Local/PyPI")
Expand All @@ -478,20 +458,21 @@ def init(args):
# the list of exported modules, installed locally
# and the package name is not in the list of local module names
# it is added to the difference list
difference = [x for x in candidates if
# aggregate all export lists into one
# flatten the list
# check if candidate is in exports
x.lower() not in [y for x in local for y in x['exports']]
and
# check if candidate is package names
x.lower() not in [x['name'] for x in local]]

imports = local + get_imports_info(difference,
proxy=proxy,
pypi_server=pypi_server)
difference = [
x
for x in candidates
if
# aggregate all export lists into one
# flatten the list
# check if candidate is in exports
x.lower() not in [y for x in local for y in x["exports"]] and
# check if candidate is package names
x.lower() not in [x["name"] for x in local]
]

imports = local + get_imports_info(difference, proxy=proxy, pypi_server=pypi_server)
# sort imports based on lowercase name of package, similar to `pip freeze`.
imports = sorted(imports, key=lambda x: x['name'].lower())
imports = sorted(imports, key=lambda x: x["name"].lower())

if args["--diff"]:
diff(args["--diff"], imports)
Expand All @@ -506,8 +487,7 @@ def init(args):
if scheme in ["compat", "gt", "no-pin"]:
imports, symbol = dynamic_versioning(scheme, imports)
else:
raise ValueError("Invalid argument for mode flag, "
"use 'compat', 'gt' or 'no-pin' instead")
raise ValueError("Invalid argument for mode flag, " "use 'compat', 'gt' or 'no-pin' instead")
else:
symbol = "=="

Expand All @@ -521,14 +501,14 @@ def init(args):

def main(): # pragma: no cover
args = docopt(__doc__, version=__version__)
log_level = logging.DEBUG if args['--debug'] else logging.INFO
logging.basicConfig(level=log_level, format='%(levelname)s: %(message)s')
log_level = logging.DEBUG if args["--debug"] else logging.INFO
logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")

try:
init(args)
except KeyboardInterrupt:
sys.exit(0)


if __name__ == '__main__':
if __name__ == "__main__":
main() # pragma: no cover
Loading

0 comments on commit 4bb7f3e

Please sign in to comment.