summaryrefslogtreecommitdiffstats
path: root/sqlglot/helper.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/helper.py')
-rw-r--r--sqlglot/helper.py97
1 files changed, 33 insertions, 64 deletions
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 2f48ab5..a863017 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -14,7 +14,6 @@ from itertools import count
if t.TYPE_CHECKING:
from sqlglot import exp
from sqlglot._typing import E, T
- from sqlglot.dialects.dialect import DialectType
from sqlglot.expressions import Expression
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
@@ -23,7 +22,12 @@ logger = logging.getLogger("sqlglot")
class AutoName(Enum):
- """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
+ """
+ This is used for creating Enum classes where `auto()` is the string form
+ of the corresponding enum's identifier (e.g. FOO.value results in "FOO").
+
+ Reference: https://docs.python.org/3/howto/enum.html#using-automatic-values
+ """
def _generate_next_value_(name, _start, _count, _last_values):
return name
@@ -52,7 +56,7 @@ def ensure_list(value):
Ensures that a value is a list, otherwise casts or wraps it into one.
Args:
- value: the value of interest.
+ value: The value of interest.
Returns:
The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
@@ -80,7 +84,7 @@ def ensure_collection(value):
Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
Args:
- value: the value of interest.
+ value: The value of interest.
Returns:
The value if it's a collection, or else the value wrapped in a list.
@@ -97,8 +101,8 @@ def csv(*args: str, sep: str = ", ") -> str:
Formats any number of string arguments as CSV.
Args:
- args: the string arguments to format.
- sep: the argument separator.
+ args: The string arguments to format.
+ sep: The argument separator.
Returns:
The arguments formatted as a CSV string.
@@ -115,9 +119,9 @@ def subclasses(
Returns all subclasses for a collection of classes, possibly excluding some of them.
Args:
- module_name: the name of the module to search for subclasses in.
- classes: class(es) we want to find the subclasses of.
- exclude: class(es) we want to exclude from the returned list.
+ module_name: The name of the module to search for subclasses in.
+ classes: Class(es) we want to find the subclasses of.
+ exclude: Class(es) we want to exclude from the returned list.
Returns:
The target subclasses.
@@ -140,13 +144,13 @@ def apply_index_offset(
Applies an offset to a given integer literal expression.
Args:
- this: the target of the index
- expressions: the expression the offset will be applied to, wrapped in a list.
- offset: the offset that will be applied.
+ this: The target of the index.
+ expressions: The expression the offset will be applied to, wrapped in a list.
+ offset: The offset that will be applied.
Returns:
The original expression with the offset applied to it, wrapped in a list. If the provided
- `expressions` argument contains more than one expressions, it's returned unaffected.
+ `expressions` argument contains more than one expression, it's returned unaffected.
"""
if not offset or len(expressions) != 1:
return expressions
@@ -189,8 +193,8 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
Applies a transformation to a given expression until a fix point is reached.
Args:
- expression: the expression to be transformed.
- func: the transformation to be applied.
+ expression: The expression to be transformed.
+ func: The transformation to be applied.
Returns:
The transformed expression.
@@ -198,6 +202,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
while True:
for n, *_ in reversed(tuple(expression.walk())):
n._hash = hash(n)
+
start = hash(expression)
expression = func(expression)
@@ -205,6 +210,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
n._hash = None
if start == hash(expression):
break
+
return expression
@@ -213,7 +219,7 @@ def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
Sorts a given directed acyclic graph in topological order.
Args:
- dag: the graph to be sorted.
+ dag: The graph to be sorted.
Returns:
A list that contains all of the graph's nodes in topological order.
@@ -261,7 +267,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Args:
- read_csv: a `ReadCSV` function call
+ read_csv: A `ReadCSV` function call.
Yields:
A python csv reader.
@@ -288,8 +294,8 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
Searches for a new name.
Args:
- taken: a collection of taken names.
- base: base name to alter.
+ taken: A collection of taken names.
+ base: Base name to alter.
Returns:
The new, available name.
@@ -327,10 +333,10 @@ def split_num_words(
Perform a split on a value and return N words as a result with `None` used for words that don't exist.
Args:
- value: the value to be split.
- sep: the value to use to split on.
- min_num_words: the minimum number of words that are going to be in the result.
- fill_from_start: indicates that if `None` values should be inserted at the start or end of the list.
+ value: The value to be split.
+ sep: The value to use to split on.
+ min_num_words: The minimum number of words that are going to be in the result.
+ fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
@@ -360,7 +366,7 @@ def is_iterable(value: t.Any) -> bool:
False
Args:
- value: the value to check if it is an iterable.
+ value: The value to check if it is an iterable.
Returns:
A `bool` value indicating if it is an iterable.
@@ -380,7 +386,7 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
[1, 2, 3]
Args:
- values: the value to be flattened.
+ values: The value to be flattened.
Yields:
Non-iterable elements in `values`.
@@ -396,7 +402,7 @@ def dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
- For example:
+ Example:
>>> dict_depth(None)
0
>>> dict_depth({})
@@ -407,12 +413,6 @@ def dict_depth(d: t.Dict) -> int:
2
>>> dict_depth({"a": {"b": {}}})
3
-
- Args:
- d (dict): dictionary
-
- Returns:
- int: depth
"""
try:
return 1 + dict_depth(next(iter(d.values())))
@@ -425,36 +425,5 @@ def dict_depth(d: t.Dict) -> int:
def first(it: t.Iterable[T]) -> T:
- """Returns the first element from an iterable.
-
- Useful for sets.
- """
+ """Returns the first element from an iterable (useful for sets)."""
return next(i for i in it)
-
-
-def case_sensitive(text: str, dialect: DialectType) -> bool:
- """Checks if text contains any case sensitive characters depending on dialect."""
- from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
-
- unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
- return any(unsafe(char) for char in text)
-
-
-def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool:
- """Checks if text should be identified given an identify option.
-
- Args:
- text: the text to check.
- identify:
- "always" or `True`: always returns true.
- "safe": true if there is no uppercase or lowercase character in `text`, depending on `dialect`.
- dialect: the dialect to use in order to decide whether a text should be identified.
-
- Returns:
- Whether or not a string should be identified.
- """
- if identify is True or identify == "always":
- return True
- if identify == "safe":
- return not case_sensitive(text, dialect)
- return False