Diffstat (limited to 'pre_commit_hooks/string_fixer.py')
-rw-r--r--  pre_commit_hooks/string_fixer.py  97
1 file changed, 97 insertions(+), 0 deletions(-)
diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py
new file mode 100644
index 0000000..d1b1c4a
--- /dev/null
+++ b/pre_commit_hooks/string_fixer.py
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+import argparse
+import io
+import re
+import sys
+import tokenize
+from typing import Sequence
+
+if sys.version_info >= (3, 12):  # pragma: >=3.12 cover
+    FSTRING_START = tokenize.FSTRING_START
+    FSTRING_END = tokenize.FSTRING_END
+else:  # pragma: <3.12 cover
+    FSTRING_START = FSTRING_END = -1
+
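+# An optional string prefix (r, b, f, u, ...) followed by the opening double quote.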
+START_QUOTE_RE = re.compile('^[a-zA-Z]*"')
+
+
+def handle_match(token_text: str) -> str:
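+    # Leave triple-quoted strings alone; only simple quoted strings are rewritten.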
+    if '"""' in token_text or "'''" in token_text:
+        return token_text
+
+    match = START_QUOTE_RE.match(token_text)
+    if match is not None:
+        meat = token_text[match.end():-1]
+        if '"' in meat or "'" in meat:
+            return token_text
+        else:
+            return match.group().replace('"', "'") + meat + "'"
+    else:
+        return token_text
+
+
+def get_line_offsets_by_line_no(src: str) -> list[int]:
+    # Padded so we can index with line number
+    offsets = [-1, 0]
+    for line in src.splitlines(True):
+        offsets.append(offsets[-1] + len(line))
+    return offsets
+
+
+def fix_strings(filename: str) -> int:
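+    # Rewrite double-quoted strings in the file; return 1 if it was modified.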
+    with open(filename, encoding='UTF-8', newline='') as f:
+        contents = f.read()
+    line_offsets = get_line_offsets_by_line_no(contents)
+
+    # Basically a mutable string
+    splitcontents = list(contents)
+
+    fstring_depth = 0
+
+    # Iterate in reverse so the offsets are always correct
+    tokens_l = list(tokenize.generate_tokens(io.StringIO(contents).readline))
+    tokens = reversed(tokens_l)
+    for token_type, token_text, (srow, scol), (erow, ecol), _ in tokens:
+        if token_type == FSTRING_START:  # pragma: >=3.12 cover
+            fstring_depth += 1
+        elif token_type == FSTRING_END:  # pragma: >=3.12 cover
+            fstring_depth -= 1
+        elif fstring_depth == 0 and token_type == tokenize.STRING:
+            new_text = handle_match(token_text)
+            splitcontents[
+                line_offsets[srow] + scol:
+                line_offsets[erow] + ecol
+            ] = new_text
+
+    new_contents = ''.join(splitcontents)
+    if contents != new_contents:
+        with open(filename, 'w', encoding='UTF-8', newline='') as f:
+            f.write(new_contents)
+        return 1
+    else:
+        return 0
+
+
+def main(argv: Sequence[str] | None = None) -> int:
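+    # A non-zero exit status tells pre-commit that files were modified.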
+    parser = argparse.ArgumentParser()
+    parser.add_argument('filenames', nargs='*', help='Filenames to fix')
+    args = parser.parse_args(argv)
+
+    retv = 0
+
+    for filename in args.filenames:
+        return_value = fix_strings(filename)
+        if return_value != 0:
+            print(f'Fixing strings in {filename}')
+        retv |= return_value
+
+    return retv
+
+
+if __name__ == '__main__':
+    raise SystemExit(main())