diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/helper.py | 97 |
1 files changed, 33 insertions, 64 deletions
diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 2f48ab5..a863017 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -14,7 +14,6 @@ from itertools import count if t.TYPE_CHECKING: from sqlglot import exp from sqlglot._typing import E, T - from sqlglot.dialects.dialect import DialectType from sqlglot.expressions import Expression CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])") @@ -23,7 +22,12 @@ logger = logging.getLogger("sqlglot") class AutoName(Enum): - """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name.""" + """ + This is used for creating Enum classes where `auto()` is the string form + of the corresponding enum's identifier (e.g. FOO.value results in "FOO"). + + Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values + """ def _generate_next_value_(name, _start, _count, _last_values): return name @@ -52,7 +56,7 @@ def ensure_list(value): Ensures that a value is a list, otherwise casts or wraps it into one. Args: - value: the value of interest. + value: The value of interest. Returns: The value cast as a list if it's a list or a tuple, or else the value wrapped in a list. @@ -80,7 +84,7 @@ def ensure_collection(value): Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list. Args: - value: the value of interest. + value: The value of interest. Returns: The value if it's a collection, or else the value wrapped in a list. @@ -97,8 +101,8 @@ def csv(*args: str, sep: str = ", ") -> str: Formats any number of string arguments as CSV. Args: - args: the string arguments to format. - sep: the argument separator. + args: The string arguments to format. + sep: The argument separator. Returns: The arguments formatted as a CSV string. @@ -115,9 +119,9 @@ def subclasses( Returns all subclasses for a collection of classes, possibly excluding some of them. Args: - module_name: the name of the module to search for subclasses in. - classes: class(es) we want to find the subclasses of. - exclude: class(es) we want to exclude from the returned list. + module_name: The name of the module to search for subclasses in. + classes: Class(es) we want to find the subclasses of. + exclude: Class(es) we want to exclude from the returned list. Returns: The target subclasses. @@ -140,13 +144,13 @@ def apply_index_offset( Applies an offset to a given integer literal expression. Args: - this: the target of the index - expressions: the expression the offset will be applied to, wrapped in a list. - offset: the offset that will be applied. + this: The target of the index. + expressions: The expression the offset will be applied to, wrapped in a list. + offset: The offset that will be applied. Returns: The original expression with the offset applied to it, wrapped in a list. If the provided - `expressions` argument contains more than one expressions, it's returned unaffected. + `expressions` argument contains more than one expression, it's returned unaffected. """ if not offset or len(expressions) != 1: return expressions @@ -189,8 +193,8 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> Applies a transformation to a given expression until a fix point is reached. Args: - expression: the expression to be transformed. - func: the transformation to be applied. + expression: The expression to be transformed. + func: The transformation to be applied. Returns: The transformed expression. @@ -198,6 +202,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> while True: for n, *_ in reversed(tuple(expression.walk())): n._hash = hash(n) + start = hash(expression) expression = func(expression) @@ -205,6 +210,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> n._hash = None if start == hash(expression): break + return expression @@ -213,7 +219,7 @@ def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]: Sorts a given directed acyclic graph in topological order. Args: - dag: the graph to be sorted. + dag: The graph to be sorted. Returns: A list that contains all of the graph's nodes in topological order. @@ -261,7 +267,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any: Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`. Args: - read_csv: a `ReadCSV` function call + read_csv: A `ReadCSV` function call. Yields: A python csv reader. @@ -288,8 +294,8 @@ def find_new_name(taken: t.Collection[str], base: str) -> str: Searches for a new name. Args: - taken: a collection of taken names. - base: base name to alter. + taken: A collection of taken names. + base: Base name to alter. Returns: The new, available name. @@ -327,10 +333,10 @@ def split_num_words( Perform a split on a value and return N words as a result with `None` used for words that don't exist. Args: - value: the value to be split. - sep: the value to use to split on. - min_num_words: the minimum number of words that are going to be in the result. - fill_from_start: indicates that if `None` values should be inserted at the start or end of the list. + value: The value to be split. + sep: The value to use to split on. + min_num_words: The minimum number of words that are going to be in the result. + fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list. Examples: >>> split_num_words("db.table", ".", 3) @@ -360,7 +366,7 @@ def is_iterable(value: t.Any) -> bool: False Args: - value: the value to check if it is an iterable. + value: The value to check if it is an iterable. Returns: A `bool` value indicating if it is an iterable. @@ -380,7 +386,7 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]: [1, 2, 3] Args: - values: the value to be flattened. + values: The value to be flattened. Yields: Non-iterable elements in `values`. @@ -396,7 +402,7 @@ def dict_depth(d: t.Dict) -> int: """ Get the nesting depth of a dictionary. - For example: + Example: >>> dict_depth(None) 0 >>> dict_depth({}) @@ -407,12 +413,6 @@ def dict_depth(d: t.Dict) -> int: 2 >>> dict_depth({"a": {"b": {}}}) 3 - - Args: - d (dict): dictionary - - Returns: - int: depth """ try: return 1 + dict_depth(next(iter(d.values()))) @@ -425,36 +425,5 @@ def dict_depth(d: t.Dict) -> int: def first(it: t.Iterable[T]) -> T: - """Returns the first element from an iterable. - - Useful for sets. - """ + """Returns the first element from an iterable (useful for sets).""" return next(i for i in it) - - -def case_sensitive(text: str, dialect: DialectType) -> bool: - """Checks if text contains any case sensitive characters depending on dialect.""" - from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE - - unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper - return any(unsafe(char) for char in text) - - -def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool: - """Checks if text should be identified given an identify option. - - Args: - text: the text to check. - identify: - "always" or `True`: always returns true. - "safe": true if there is no uppercase or lowercase character in `text`, depending on `dialect`. - dialect: the dialect to use in order to decide whether a text should be identified. - - Returns: - Whether or not a string should be identified. - """ - if identify is True or identify == "always": - return True - if identify == "safe": - return not case_sensitive(text, dialect) - return False |