summaryrefslogtreecommitdiffstats
path: root/sqlglot/helper.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/helper.py')
-rw-r--r--sqlglot/helper.py47
1 files changed, 29 insertions, 18 deletions
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index b2f0520..4215fee 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -9,14 +9,14 @@ from collections.abc import Collection
from contextlib import contextmanager
from copy import copy
from enum import Enum
+from itertools import count
if t.TYPE_CHECKING:
from sqlglot import exp
+ from sqlglot._typing import E, T
+ from sqlglot.dialects.dialect import DialectType
from sqlglot.expressions import Expression
- T = t.TypeVar("T")
- E = t.TypeVar("E", bound=Expression)
-
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")
@@ -25,7 +25,7 @@ logger = logging.getLogger("sqlglot")
class AutoName(Enum):
"""This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
- def _generate_next_value_(name, _start, _count, _last_values): # type: ignore
+ def _generate_next_value_(name, _start, _count, _last_values):
return name
@@ -92,7 +92,7 @@ def ensure_collection(value):
)
-def csv(*args, sep: str = ", ") -> str:
+def csv(*args: str, sep: str = ", ") -> str:
"""
Formats any number of string arguments as CSV.
@@ -304,9 +304,18 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:
return new
+def name_sequence(prefix: str) -> t.Callable[[], str]:
+ """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a")."""
+ sequence = count()
+ return lambda: f"{prefix}{next(sequence)}"
+
+
def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
"""Returns a dictionary created from an object's attributes."""
- return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}
+ return {
+ **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
+ **kwargs,
+ }
def split_num_words(
@@ -381,15 +390,6 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]:
yield value
-def count_params(function: t.Callable) -> int:
- """
- Returns the number of formal parameters expected by a function, without counting "self"
- and "cls", in case of instance and class methods, respectively.
- """
- count = function.__code__.co_argcount
- return count - 1 if inspect.ismethod(function) else count
-
-
def dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
@@ -430,12 +430,23 @@ def first(it: t.Iterable[T]) -> T:
return next(i for i in it)
-def should_identify(text: str, identify: str | bool) -> bool:
+def case_sensitive(text: str, dialect: DialectType) -> bool:
+ """Checks if text contains any case sensitive characters depending on dialect."""
+ from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
+
+ unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
+ return any(unsafe(char) for char in text)
+
+
+def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool:
"""Checks if text should be identified given an identify option.
Args:
text: the text to check.
- identify: "always" | True - always returns true, "safe" - true if no upper case
+ identify:
+ "always" or `True`: always returns true.
+ "safe": true if there is no uppercase or lowercase character in `text`, depending on `dialect`.
+ dialect: the dialect to use in order to decide whether a text should be identified.
Returns:
Whether or not a string should be identified.
@@ -443,5 +454,5 @@ def should_identify(text: str, identify: str | bool) -> bool:
if identify is True or identify == "always":
return True
if identify == "safe":
- return not any(char.isupper() for char in text)
+ return not case_sensitive(text, dialect)
return False