diff options
Diffstat (limited to 'sqlglot/helper.py')
-rw-r--r-- | sqlglot/helper.py | 47 |
1 files changed, 29 insertions, 18 deletions
diff --git a/sqlglot/helper.py b/sqlglot/helper.py index b2f0520..4215fee 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -9,14 +9,14 @@ from collections.abc import Collection from contextlib import contextmanager from copy import copy from enum import Enum +from itertools import count if t.TYPE_CHECKING: from sqlglot import exp + from sqlglot._typing import E, T + from sqlglot.dialects.dialect import DialectType from sqlglot.expressions import Expression - T = t.TypeVar("T") - E = t.TypeVar("E", bound=Expression) - CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])") PYTHON_VERSION = sys.version_info[:2] logger = logging.getLogger("sqlglot") @@ -25,7 +25,7 @@ logger = logging.getLogger("sqlglot") class AutoName(Enum): """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name.""" - def _generate_next_value_(name, _start, _count, _last_values): # type: ignore + def _generate_next_value_(name, _start, _count, _last_values): return name @@ -92,7 +92,7 @@ def ensure_collection(value): ) -def csv(*args, sep: str = ", ") -> str: +def csv(*args: str, sep: str = ", ") -> str: """ Formats any number of string arguments as CSV. @@ -304,9 +304,18 @@ def find_new_name(taken: t.Collection[str], base: str) -> str: return new +def name_sequence(prefix: str) -> t.Callable[[], str]: + """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").""" + sequence = count() + return lambda: f"{prefix}{next(sequence)}" + + def object_to_dict(obj: t.Any, **kwargs) -> t.Dict: """Returns a dictionary created from an object's attributes.""" - return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs} + return { + **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()}, + **kwargs, + } def split_num_words( @@ -381,15 +390,6 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]: yield value -def count_params(function: t.Callable) -> int: - """ - Returns the number of formal parameters expected by a function, without counting "self" - and "cls", in case of instance and class methods, respectively. - """ - count = function.__code__.co_argcount - return count - 1 if inspect.ismethod(function) else count - - def dict_depth(d: t.Dict) -> int: """ Get the nesting depth of a dictionary. @@ -430,12 +430,23 @@ def first(it: t.Iterable[T]) -> T: return next(i for i in it) -def should_identify(text: str, identify: str | bool) -> bool: +def case_sensitive(text: str, dialect: DialectType) -> bool: + """Checks if text contains any case sensitive characters depending on dialect.""" + from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE + + unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper + return any(unsafe(char) for char in text) + + +def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool: """Checks if text should be identified given an identify option. Args: text: the text to check. - identify: "always" | True - always returns true, "safe" - true if no upper case + identify: + "always" or `True`: always returns true. + "safe": true if there is no uppercase or lowercase character in `text`, depending on `dialect`. + dialect: the dialect to use in order to decide whether a text should be identified. Returns: Whether or not a string should be identified. @@ -443,5 +454,5 @@ def should_identify(text: str, identify: str | bool) -> bool: if identify is True or identify == "always": return True if identify == "safe": - return not any(char.isupper() for char in text) + return not case_sensitive(text, dialect) return False |