summaryrefslogtreecommitdiffstats
path: root/tests/fix_encoding_pragma_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/fix_encoding_pragma_test.py')
-rw-r--r--tests/fix_encoding_pragma_test.py161
1 files changed, 161 insertions, 0 deletions
diff --git a/tests/fix_encoding_pragma_test.py b/tests/fix_encoding_pragma_test.py
new file mode 100644
index 0000000..98557e9
--- /dev/null
+++ b/tests/fix_encoding_pragma_test.py
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+import io
+
+import pytest
+
+from pre_commit_hooks.fix_encoding_pragma import _normalize_pragma
+from pre_commit_hooks.fix_encoding_pragma import fix_encoding_pragma
+from pre_commit_hooks.fix_encoding_pragma import main
+
+
+def test_integration_inserting_pragma(tmpdir):
+ path = tmpdir.join('foo.py')
+ path.write_binary(b'import httplib\n')
+
+ assert main((str(path),)) == 1
+
+ assert path.read_binary() == (
+ b'# -*- coding: utf-8 -*-\n'
+ b'import httplib\n'
+ )
+
+
+def test_integration_ok(tmpdir):
+ path = tmpdir.join('foo.py')
+ path.write_binary(b'# -*- coding: utf-8 -*-\nx = 1\n')
+ assert main((str(path),)) == 0
+
+
+def test_integration_remove(tmpdir):
+ path = tmpdir.join('foo.py')
+ path.write_binary(b'# -*- coding: utf-8 -*-\nx = 1\n')
+
+ assert main((str(path), '--remove')) == 1
+
+ assert path.read_binary() == b'x = 1\n'
+
+
+def test_integration_remove_ok(tmpdir):
+ path = tmpdir.join('foo.py')
+ path.write_binary(b'x = 1\n')
+ assert main((str(path), '--remove')) == 0
+
+
+@pytest.mark.parametrize(
+ 'input_str',
+ (
+ b'',
+ (
+ b'# -*- coding: utf-8 -*-\n'
+ b'x = 1\n'
+ ),
+ (
+ b'#!/usr/bin/env python\n'
+ b'# -*- coding: utf-8 -*-\n'
+ b'foo = "bar"\n'
+ ),
+ ),
+)
+def test_ok_inputs(input_str):
+ bytesio = io.BytesIO(input_str)
+ assert fix_encoding_pragma(bytesio) == 0
+ bytesio.seek(0)
+ assert bytesio.read() == input_str
+
+
+@pytest.mark.parametrize(
+ ('input_str', 'output'),
+ (
+ (
+ b'import httplib\n',
+ b'# -*- coding: utf-8 -*-\n'
+ b'import httplib\n',
+ ),
+ (
+ b'#!/usr/bin/env python\n'
+ b'x = 1\n',
+ b'#!/usr/bin/env python\n'
+ b'# -*- coding: utf-8 -*-\n'
+ b'x = 1\n',
+ ),
+ (
+ b'#coding=utf-8\n'
+ b'x = 1\n',
+ b'# -*- coding: utf-8 -*-\n'
+ b'x = 1\n',
+ ),
+ (
+ b'#!/usr/bin/env python\n'
+ b'#coding=utf8\n'
+ b'x = 1\n',
+ b'#!/usr/bin/env python\n'
+ b'# -*- coding: utf-8 -*-\n'
+ b'x = 1\n',
+ ),
+ # These should each get truncated
+ (b'#coding: utf-8\n', b''),
+ (b'# -*- coding: utf-8 -*-\n', b''),
+ (b'#!/usr/bin/env python\n', b''),
+ (b'#!/usr/bin/env python\n#coding: utf8\n', b''),
+ (b'#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n', b''),
+ ),
+)
+def test_not_ok_inputs(input_str, output):
+ bytesio = io.BytesIO(input_str)
+ assert fix_encoding_pragma(bytesio) == 1
+ bytesio.seek(0)
+ assert bytesio.read() == output
+
+
+def test_ok_input_alternate_pragma():
+ input_s = b'# coding: utf-8\nx = 1\n'
+ bytesio = io.BytesIO(input_s)
+ ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8')
+ assert ret == 0
+ bytesio.seek(0)
+ assert bytesio.read() == input_s
+
+
+def test_not_ok_input_alternate_pragma():
+ bytesio = io.BytesIO(b'x = 1\n')
+ ret = fix_encoding_pragma(bytesio, expected_pragma=b'# coding: utf-8')
+ assert ret == 1
+ bytesio.seek(0)
+ assert bytesio.read() == b'# coding: utf-8\nx = 1\n'
+
+
+@pytest.mark.parametrize(
+ ('input_s', 'expected'),
+ (
+ ('# coding: utf-8', b'# coding: utf-8'),
+ # trailing whitespace
+ ('# coding: utf-8\n', b'# coding: utf-8'),
+ ),
+)
+def test_normalize_pragma(input_s, expected):
+ assert _normalize_pragma(input_s) == expected
+
+
+def test_integration_alternate_pragma(tmpdir, capsys):
+ f = tmpdir.join('f.py')
+ f.write('x = 1\n')
+
+ pragma = '# coding: utf-8'
+ assert main((str(f), '--pragma', pragma)) == 1
+ assert f.read() == '# coding: utf-8\nx = 1\n'
+ out, _ = capsys.readouterr()
+ assert out == f'Added `# coding: utf-8` to {str(f)}\n'
+
+
+def test_crlf_ok(tmpdir):
+ f = tmpdir.join('f.py')
+ f.write_binary(b'# -*- coding: utf-8 -*-\r\nx = 1\r\n')
+ assert not main((str(f),))
+
+
+def test_crfl_adds(tmpdir):
+ f = tmpdir.join('f.py')
+ f.write_binary(b'x = 1\r\n')
+ assert main((str(f),))
+ assert f.read_binary() == b'# -*- coding: utf-8 -*-\r\nx = 1\r\n'