Diffstat (limited to 'pre_commit_hooks/string_fixer.py')
-rw-r--r--  pre_commit_hooks/string_fixer.py  93
1 file changed, 93 insertions, 0 deletions
diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py
new file mode 100644
index 0000000..d1b1c4a
--- /dev/null
+++ b/pre_commit_hooks/string_fixer.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+import argparse
+import io
+import re
+import sys
+import tokenize
+from typing import Sequence
+
+if sys.version_info >= (3, 12):  # pragma: >=3.12 cover
+    FSTRING_START = tokenize.FSTRING_START
+    FSTRING_END = tokenize.FSTRING_END
+else:  # pragma: <3.12 cover
+    FSTRING_START = FSTRING_END = -1
+
+START_QUOTE_RE = re.compile('^[a-zA-Z]*"')
+
+
+def handle_match(token_text: str) -> str:
+    if '"""' in token_text or "'''" in token_text:
+        return token_text
+
+    match = START_QUOTE_RE.match(token_text)
+    if match is not None:
+        meat = token_text[match.end():-1]
+        if '"' in meat or "'" in meat:
+            return token_text
+        else:
+            return match.group().replace('"', "'") + meat + "'"
+    else:
+        return token_text
+
+
+def get_line_offsets_by_line_no(src: str) -> list[int]:
+    # Padded so we can index with line number
+    offsets = [-1, 0]
+    for line in src.splitlines(True):
+        offsets.append(offsets[-1] + len(line))
+    return offsets
+
+
+def fix_strings(filename: str) -> int:
+    with open(filename, encoding='UTF-8', newline='') as f:
+        contents = f.read()
+    line_offsets = get_line_offsets_by_line_no(contents)
+
+    # Basically a mutable string
+    splitcontents = list(contents)
+
+    fstring_depth = 0
+
+    # Iterate in reverse so the offsets are always correct
+    tokens_l = list(tokenize.generate_tokens(io.StringIO(contents).readline))
+    tokens = reversed(tokens_l)
+    for token_type, token_text, (srow, scol), (erow, ecol), _ in tokens:
+        if token_type == FSTRING_START:  # pragma: >=3.12 cover
+            fstring_depth += 1
+        elif token_type == FSTRING_END:  # pragma: >=3.12 cover
+            fstring_depth -= 1
+        elif fstring_depth == 0 and token_type == tokenize.STRING:
+            new_text = handle_match(token_text)
+            splitcontents[
+                line_offsets[srow] + scol:
+                line_offsets[erow] + ecol
+            ] = new_text
+
+    new_contents = ''.join(splitcontents)
+    if contents != new_contents:
+        with open(filename, 'w', encoding='UTF-8', newline='') as f:
+            f.write(new_contents)
+        return 1
+    else:
+        return 0
+
+
+def main(argv: Sequence[str] | None = None) -> int:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+    args = parser.parse_args(argv)
+
+    retv = 0
+
+    for filename in args.filenames:
+        return_value = fix_strings(filename)
+        if return_value != 0:
+            print(f'Fixing strings in {filename}')
+        retv |= return_value
+
+    return retv
+
+
+if __name__ == '__main__':
+    raise SystemExit(main())
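
As a quick illustrative sketch (not part of the commit), assuming the new module is importable as pre_commit_hooks.string_fixer: handle_match() rewrites only unambiguous double-quoted string tokens to single quotes and leaves everything else untouched.

    from pre_commit_hooks.string_fixer import handle_match

    assert handle_match('"hello"') == "'hello'"      # plain double quotes -> single
    assert handle_match('b"bytes"') == "b'bytes'"    # string prefixes are preserved
    assert handle_match('"it\'s"') == '"it\'s"'      # contains a quote: left alone
    assert handle_match('"""doc"""') == '"""doc"""'  # triple-quoted: left alone

The hook itself operates on whole files, e.g. python -m pre_commit_hooks.string_fixer some_file.py: it rewrites each file in place and returns a non-zero exit status when anything changed, which is the convention pre-commit uses to flag fixed files.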