summaryrefslogtreecommitdiffstats
path: root/tests/fix_encoding_pragma_test.py
blob: 98557e9813be0cfeba42d52f46a3165a05d1217b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from __future__ import annotations

import io

import pytest

from pre_commit_hooks.fix_encoding_pragma import _normalize_pragma
from pre_commit_hooks.fix_encoding_pragma import fix_encoding_pragma
from pre_commit_hooks.fix_encoding_pragma import main


def test_integration_inserting_pragma(tmpdir):
    """A file without a pragma gets the default one prepended (exit code 1)."""
    src = tmpdir.join('foo.py')
    src.write_binary(b'import httplib\n')

    ret = main((str(src),))
    assert ret == 1

    expected = (
        b'# -*- coding: utf-8 -*-\n'
        b'import httplib\n'
    )
    assert src.read_binary() == expected


def test_integration_ok(tmpdir):
    """A file already carrying the default pragma is reported clean (exit 0)."""
    target = tmpdir.join('foo.py')
    target.write_binary(b'# -*- coding: utf-8 -*-\nx = 1\n')
    ret = main((str(target),))
    assert ret == 0


def test_integration_remove(tmpdir):
    """``--remove`` strips an existing pragma and reports a change (exit 1)."""
    target = tmpdir.join('foo.py')
    target.write_binary(b'# -*- coding: utf-8 -*-\nx = 1\n')

    argv = (str(target), '--remove')
    assert main(argv) == 1

    assert target.read_binary() == b'x = 1\n'


def test_integration_remove_ok(tmpdir):
    """``--remove`` on a file with no pragma is a no-op (exit 0)."""
    target = tmpdir.join('foo.py')
    target.write_binary(b'x = 1\n')
    argv = (str(target), '--remove')
    ret = main(argv)
    assert ret == 0


@pytest.mark.parametrize(
    'input_str',
    (
        b'',
        (
            b'# -*- coding: utf-8 -*-\n'
            b'x = 1\n'
        ),
        (
            b'#!/usr/bin/env python\n'
            b'# -*- coding: utf-8 -*-\n'
            b'foo = "bar"\n'
        ),
    ),
)
def test_ok_inputs(input_str):
    """Already-conforming inputs return 0 and the buffer is left unmodified."""
    buf = io.BytesIO(input_str)
    ret = fix_encoding_pragma(buf)
    assert ret == 0
    # getvalue() returns the whole buffer regardless of the current position
    assert buf.getvalue() == input_str


@pytest.mark.parametrize(
    ('input_str', 'output'),
    (
        (
            b'import httplib\n',
            b'# -*- coding: utf-8 -*-\n'
            b'import httplib\n',
        ),
        (
            b'#!/usr/bin/env python\n'
            b'x = 1\n',
            b'#!/usr/bin/env python\n'
            b'# -*- coding: utf-8 -*-\n'
            b'x = 1\n',
        ),
        (
            b'#coding=utf-8\n'
            b'x = 1\n',
            b'# -*- coding: utf-8 -*-\n'
            b'x = 1\n',
        ),
        (
            b'#!/usr/bin/env python\n'
            b'#coding=utf8\n'
            b'x = 1\n',
            b'#!/usr/bin/env python\n'
            b'# -*- coding: utf-8 -*-\n'
            b'x = 1\n',
        ),
        # Files holding nothing but a shebang and/or a pragma are emptied
        (b'#coding: utf-8\n', b''),
        (b'# -*- coding: utf-8 -*-\n', b''),
        (b'#!/usr/bin/env python\n', b''),
        (b'#!/usr/bin/env python\n#coding: utf8\n', b''),
        (b'#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n', b''),
    ),
)
def test_not_ok_inputs(input_str, output):
    """Non-conforming inputs return 1 and the buffer is rewritten to ``output``."""
    buf = io.BytesIO(input_str)
    ret = fix_encoding_pragma(buf)
    assert ret == 1
    assert buf.getvalue() == output


def test_ok_input_alternate_pragma():
    """A custom ``expected_pragma`` already present is accepted unchanged."""
    original = b'# coding: utf-8\nx = 1\n'
    buf = io.BytesIO(original)
    assert fix_encoding_pragma(buf, expected_pragma=b'# coding: utf-8') == 0
    assert buf.getvalue() == original


def test_not_ok_input_alternate_pragma():
    """A missing custom ``expected_pragma`` is inserted at the top (returns 1)."""
    buf = io.BytesIO(b'x = 1\n')
    assert fix_encoding_pragma(buf, expected_pragma=b'# coding: utf-8') == 1
    assert buf.getvalue() == b'# coding: utf-8\nx = 1\n'


@pytest.mark.parametrize(
    ('input_s', 'expected'),
    (
        ('# coding: utf-8', b'# coding: utf-8'),
        # trailing whitespace (here a newline) is dropped
        ('# coding: utf-8\n', b'# coding: utf-8'),
    ),
)
def test_normalize_pragma(input_s, expected):
    """``_normalize_pragma`` turns a str pragma into bytes without trailing ws."""
    normalized = _normalize_pragma(input_s)
    assert normalized == expected


def test_integration_alternate_pragma(tmpdir, capsys):
    """``--pragma`` inserts a custom pragma and prints which file it changed."""
    f = tmpdir.join('f.py')
    f.write('x = 1\n')

    pragma = '# coding: utf-8'
    assert main((str(f), '--pragma', pragma)) == 1
    assert f.read() == '# coding: utf-8\nx = 1\n'
    out, _ = capsys.readouterr()
    # f-string formatting already calls str() on the path; no explicit str()
    assert out == f'Added `# coding: utf-8` to {f}\n'


def test_crlf_ok(tmpdir):
    """A CRLF file that already has the pragma is accepted and left untouched."""
    contents = b'# -*- coding: utf-8 -*-\r\nx = 1\r\n'
    f = tmpdir.join('f.py')
    f.write_binary(contents)
    # Pin the exact exit code (``not main(...)`` would also accept None/'').
    assert main((str(f),)) == 0
    # CRLF line endings must survive unchanged.
    assert f.read_binary() == contents


def test_crlf_adds(tmpdir):
    """Inserting a pragma into a CRLF file uses CRLF for the new line too."""
    f = tmpdir.join('f.py')
    f.write_binary(b'x = 1\r\n')
    # A modified file is reported with exit code 1 (same as the LF case).
    assert main((str(f),)) == 1
    assert f.read_binary() == b'# -*- coding: utf-8 -*-\r\nx = 1\r\n'