summary refs log tree commit diff stats
path: root/pre_commit/languages/pygrep.py
blob: 72a9345fafe5cd205d12e35d2909049f97b1a177 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from __future__ import annotations

import argparse
import re
import sys
from collections.abc import Sequence
from re import Pattern
from typing import NamedTuple

from pre_commit import lang_base
from pre_commit import output
from pre_commit.prefix import Prefix
from pre_commit.xargs import xargs

# Language-plugin attributes for `pygrep`; everything except ENVIRONMENT_DIR
# is delegated unchanged to a generic helper from `lang_base`.
# NOTE(review): the helper names (`no_install`, `no_env`) and the `None`
# environment dir suggest this language needs no installed environment --
# confirm against `lang_base`.
ENVIRONMENT_DIR = None
get_default_version = lang_base.basic_get_default_version
health_check = lang_base.basic_health_check
install_environment = lang_base.no_install
in_env = lang_base.no_env


def _process_filename_by_line(pattern: Pattern[bytes], filename: str) -> int:
    """Report every line of ``filename`` on which ``pattern`` is found.

    The file is read in binary mode, one line at a time.  For each matching
    line a ``<filename>:<line number>:`` prefix (1-based) is written, followed
    by the line's bytes with trailing CR / LF bytes stripped.

    Returns 1 if at least one line matched, otherwise 0.
    """
    found = False
    with open(filename, 'rb') as stream:
        line_no = 0
        for raw_line in stream:
            line_no += 1
            if not pattern.search(raw_line):
                continue
            found = True
            output.write(f'{filename}:{line_no}:')
            output.write_line_b(raw_line.rstrip(b'\r\n'))
    return int(found)


def _process_filename_at_once(pattern: Pattern[bytes], filename: str) -> int:
    """Search the entire contents of ``filename`` for ``pattern`` in one go.

    Only the first match is reported.  The report is a
    ``<filename>:<line number>:`` prefix (1-based line on which the match
    begins) followed by the matched text, except that the first matched line
    is shown in full rather than starting where the match starts.

    Returns 1 if the pattern was found, otherwise 0.
    """
    with open(filename, 'rb') as stream:
        data = stream.read()

    hit = pattern.search(data)
    if hit is None:
        return 0

    # Zero-based index of the line on which the match begins.
    first_line_idx = data.count(b'\n', 0, hit.start())
    output.write(f'{filename}:{first_line_idx + 1}:')

    # Swap the (possibly partial) first matched line for the complete line
    # from the file so the reader sees the match in context.
    _, *remaining = hit[0].split(b'\n')
    whole_first_line = data.split(b'\n')[first_line_idx]
    output.write_line_b(b'\n'.join([whole_first_line, *remaining]))
    return 1


def _process_filename_by_line_negated(
        pattern: Pattern[bytes],
        filename: str,
) -> int:
    """Report ``filename`` when none of its lines match ``pattern``.

    Lines are read in binary mode and scanning stops at the first match.

    Returns 0 if some line matched; otherwise writes the filename and
    returns 1.
    """
    with open(filename, 'rb') as stream:
        # any() short-circuits, so reading stops at the first matching line.
        has_match = any(pattern.search(raw_line) for raw_line in stream)

    if has_match:
        return 0

    output.write_line(filename)
    return 1


def _process_filename_at_once_negated(
        pattern: Pattern[bytes],
        filename: str,
) -> int:
    """Report ``filename`` when its whole contents do not match ``pattern``.

    The file is read fully in binary mode and searched once.

    Returns 0 if the pattern was found; otherwise writes the filename and
    returns 1.
    """
    with open(filename, 'rb') as stream:
        data = stream.read()

    if pattern.search(data) is None:
        output.write_line(filename)
        return 1
    return 0


# Key type for the `FNS` dispatch table: one value per combination of the
# `--multiline` and `--negate` command-line flags parsed in `main`.
class Choice(NamedTuple):
    # True when the whole file is searched at once rather than line by line.
    multiline: bool
    # True when a file is reported for *lacking* a match.
    negate: bool


# Dispatch table mapping each (multiline, negate) combination to the function
# that processes a single file in that mode.  Every function takes
# (pattern, filename) and returns 0 or 1; `main` ORs the results together.
FNS = {
    Choice(multiline=True, negate=True): _process_filename_at_once_negated,
    Choice(multiline=True, negate=False): _process_filename_at_once,
    Choice(multiline=False, negate=True): _process_filename_by_line_negated,
    Choice(multiline=False, negate=False): _process_filename_by_line,
}


def run_hook(
        prefix: Prefix,
        entry: str,
        args: Sequence[str],
        file_args: Sequence[str],
        *,
        is_local: bool,
        require_serial: bool,
        color: bool,
) -> tuple[int, bytes]:
    """Run this module as a script over ``file_args`` via ``xargs``.

    The command is the current interpreter running this module with ``-m``,
    followed by the hook ``args`` and finally ``entry``; given the argument
    parser in ``main``, ``entry`` therefore lands in the ``pattern`` position.

    ``prefix``, ``is_local`` and ``require_serial`` are accepted but not used
    by this function.

    Returns whatever ``xargs`` returns.
    NOTE(review): presumably ``xargs`` appends ``file_args`` to the command
    and yields (exit code, captured output) -- confirm in ``pre_commit.xargs``.
    """
    interpreter_and_module = (sys.executable, '-m', __name__)
    command = interpreter_and_module + tuple(args) + (entry,)
    return xargs(command, file_args, color=color)


def main(argv: Sequence[str] | None = None) -> int:
    """Command-line entry point: search files for a python regex.

    Arguments (parsed from ``argv``, or from the process arguments when it is
    ``None``): optional ``-i/--ignore-case``, ``--multiline`` and ``--negate``
    flags, a required ``pattern`` and zero or more ``filenames``.

    The pattern is compiled as a bytes regex.  ``--multiline`` additionally
    turns on ``re.MULTILINE`` and ``re.DOTALL``.  Every file is processed with
    the function selected from ``FNS`` for the chosen flag combination.

    Returns the bitwise OR of the per-file results (each 0 or 1).
    """
    cli = argparse.ArgumentParser(
        description=(
            'grep-like finder using python regexes.  Unlike grep, this tool '
            'returns nonzero when it finds a match and zero otherwise.  The '
            'idea here being that matches are "problems".'
        ),
    )
    cli.add_argument('-i', '--ignore-case', action='store_true')
    cli.add_argument('--multiline', action='store_true')
    cli.add_argument('--negate', action='store_true')
    cli.add_argument('pattern', help='python regex pattern.')
    cli.add_argument('filenames', nargs='*')
    opts = cli.parse_args(argv)

    regex_flags = 0
    if opts.ignore_case:
        regex_flags |= re.IGNORECASE
    if opts.multiline:
        regex_flags |= re.MULTILINE | re.DOTALL

    # Files are opened in binary mode, so the regex must be a bytes pattern.
    regex = re.compile(opts.pattern.encode(), regex_flags)
    handler = FNS[Choice(multiline=opts.multiline, negate=opts.negate)]

    # Every file is processed (no short-circuit) so all findings are printed.
    exit_code = 0
    for path in opts.filenames:
        exit_code |= handler(regex, path)
    return exit_code


# Script entry point: `run_hook` invokes this module with `python -m`, and
# the value returned by main() becomes the process exit status.
if __name__ == '__main__':
    raise SystemExit(main())