diff --git a/pyformat.py b/pyformat.py index de5e33b..c916891 100755 --- a/pyformat.py +++ b/pyformat.py @@ -31,6 +31,7 @@ import io import signal import sys +from typing import Tuple import autoflake import autopep8 @@ -76,16 +77,43 @@ def format_code(source, aggressive=False, apply_config=False, filename='', return formatted_source +def detect_io_encoding(input_file: io.BytesIO, limit_byte_check=-1): + """Return file encoding.""" + try: + from lib2to3.pgen2 import tokenize as lib2to3_tokenize + encoding: str = lib2to3_tokenize.detect_encoding(input_file.readline)[ + 0] + + input_file.read(limit_byte_check).decode(encoding) + + return encoding + except (LookupError, SyntaxError, UnicodeDecodeError): + return 'latin-1' + + +def read_file(filename: str) -> Tuple[str, str]: + """Read file from filesystem or from stdin when `-` is given.""" + + if is_stdin(filename): + data = sys.stdin.buffer.read() + else: + with open(filename, 'rb') as fp: + data = fp.read() + input_file = io.BytesIO(data) + encoding = detect_io_encoding(input_file) + return data.decode(encoding), encoding + + +def is_stdin(filename: str): + return filename == '-' + + def format_file(filename, args, standard_out): """Run format_code() on a file. Return True if the new formatting differs from the original. - """ - encoding = autopep8.detect_encoding(filename) - with autopep8.open_with_encoding(filename, - encoding=encoding) as input_file: - source = input_file.read() + source, encoding = read_file(filename) if not source: return False @@ -98,6 +126,12 @@ def format_file(filename, args, standard_out): remove_all_unused_imports=args.remove_all_unused_imports, remove_unused_variables=args.remove_unused_variables) + # Always write to stdout (even when no changes were made) when working with + # in-place stdin. This is what most tools (editors) expect. + if args.in_place and is_stdin(filename): + standard_out.write(formatted_source) + return True + if source != formatted_source: if args.in_place: with autopep8.open_with_encoding(filename, mode='w', @@ -142,7 +176,6 @@ def format_multiple_files(filenames, args, standard_out, standard_error): """Format files and return booleans (any_changes, any_errors). Optionally format files recursively. - """ filenames = autopep8.find_files(list(filenames), args.recursive, @@ -211,7 +244,6 @@ def _main(argv, standard_out, standard_error): """Internal main entry point. Return exit status. 0 means no error. - """ args = parse_args(argv) diff --git a/test_pyformat.py b/test_pyformat.py index dbe5fdb..6337db9 100755 --- a/test_pyformat.py +++ b/test_pyformat.py @@ -13,6 +13,7 @@ import sys import tempfile import unittest +from unittest import mock import pyformat @@ -144,6 +145,58 @@ def test_format_multiple_files_with_nonexistent_file_and_verbose(self): class TestSystem(unittest.TestCase): + def test_diff_stdin(self): + input_b = b'''\ +import os +x = "abc" +''' + stdin = io.TextIOWrapper(io.BytesIO(input_b), 'UTF-8') + with mock.patch.object(sys, 'stdin', stdin): + output_file = io.StringIO() + pyformat._main(argv=['my_fake_program', '-'], + standard_out=output_file, + standard_error=None) + self.assertEqual('''\ +@@ -1,2 +1,2 @@ + import os +-x = "abc" ++x = 'abc' +''', '\n'.join(output_file.getvalue().split('\n')[2:])) + + def test_in_place_stdin(self): + input_b = b'''\ +import os +x = "abc" +''' + stdin = io.TextIOWrapper(io.BytesIO(input_b), 'UTF-8') + with mock.patch.object(sys, 'stdin', stdin): + output_file = io.StringIO() + pyformat._main(argv=['my_fake_program', '--in-place', '-'], + standard_out=output_file, + standard_error=None) + self.assertEqual('''\ +import os +x = 'abc' +''', '\n'.join(output_file.getvalue().split('\n'))) + + + def test_in_place_stdin_no_change(self): + input_b = b'''\ +import os +x = 'abc' +''' + stdin = io.TextIOWrapper(io.BytesIO(input_b), 'UTF-8') + with mock.patch.object(sys, 'stdin', stdin): + output_file = io.StringIO() + pyformat._main(argv=['my_fake_program', '--in-place', '-'], + standard_out=output_file, + standard_error=None) + self.assertEqual('''\ +import os +x = 'abc' +''', '\n'.join(output_file.getvalue().split('\n'))) + + def test_diff(self): with temporary_file('''\ import os