diff options
author | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
---|---|---|
committer | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
commit | 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch) | |
tree | 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dialects/dialect.py | |
parent | Releasing debian version 19.0.1-1. (diff) | |
download | sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip |
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- | sqlglot/dialects/dialect.py | 220 |
1 files changed, 159 insertions, 61 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 21e7889..c7cea64 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -1,14 +1,14 @@ from __future__ import annotations import typing as t -from enum import Enum +from enum import Enum, auto from functools import reduce from sqlglot import exp from sqlglot._typing import E from sqlglot.errors import ParseError from sqlglot.generator import Generator -from sqlglot.helper import flatten, seq_get +from sqlglot.helper import AutoName, flatten, seq_get from sqlglot.parser import Parser from sqlglot.time import TIMEZONES, format_time from sqlglot.tokens import Token, Tokenizer, TokenType @@ -16,6 +16,9 @@ from sqlglot.trie import new_trie B = t.TypeVar("B", bound=exp.Binary) +DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] +DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] + class Dialects(str, Enum): DIALECT = "" @@ -43,6 +46,15 @@ class Dialects(str, Enum): Doris = "doris" +class NormalizationStrategy(str, AutoName): + """Specifies the strategy according to which identifiers should be normalized.""" + + LOWERCASE = auto() # Unquoted identifiers are lowercased + UPPERCASE = auto() # Unquoted identifiers are uppercased + CASE_SENSITIVE = auto() # Always case-sensitive, regardless of quotes + CASE_INSENSITIVE = auto() # Always case-insensitive, regardless of quotes + + class _Dialect(type): classes: t.Dict[str, t.Type[Dialect]] = {} @@ -106,26 +118,8 @@ class _Dialect(type): klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) - dialect_properties = { - **{ - k: v - for k, v in vars(klass).items() - if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__") - }, - "TOKENIZER_CLASS": klass.tokenizer_class, - } - if enum not in ("", "bigquery"): - dialect_properties["SELECT_KINDS"] = () - - # Pass required dialect properties to the tokenizer, parser and generator classes - for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class): - for name, value in dialect_properties.items(): - if hasattr(subclass, name): - setattr(subclass, name, value) - - if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT: - klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe + klass.generator_class.SELECT_KINDS = () if not klass.SUPPORTS_SEMI_ANTI_JOIN: klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { @@ -133,8 +127,6 @@ class _Dialect(type): TokenType.SEMI, } - klass.generator_class.can_identify = klass.can_identify - return klass @@ -148,9 +140,8 @@ class Dialect(metaclass=_Dialect): # Determines whether or not the table alias comes after tablesample ALIAS_POST_TABLESAMPLE = False - # Determines whether or not unquoted identifiers are resolved as uppercase - # When set to None, it means that the dialect treats all identifiers as case-insensitive - RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False + # Specifies the strategy according to which identifiers should be normalized. + NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE # Determines whether or not an unquoted identifier can start with a digit IDENTIFIERS_CAN_START_WITH_DIGIT = False @@ -177,6 +168,18 @@ class Dialect(metaclass=_Dialect): # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" NULL_ORDERING = "nulls_are_small" + # Whether the behavior of a / b depends on the types of a and b. + # False means a / b is always float division. + # True means a / b is integer division if both a and b are integers. + TYPED_DIVISION = False + + # False means 1 / 0 throws an error. + # True means 1 / 0 returns null. + SAFE_DIVISION = False + + # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string + CONCAT_COALESCE = False + DATE_FORMAT = "'%Y-%m-%d'" DATEINT_FORMAT = "'%Y%m%d'" TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" @@ -197,7 +200,8 @@ class Dialect(metaclass=_Dialect): # Such columns may be excluded from SELECT * queries, for example PSEUDOCOLUMNS: t.Set[str] = set() - # Autofilled + # --- Autofilled --- + tokenizer_class = Tokenizer parser_class = Parser generator_class = Generator @@ -211,26 +215,61 @@ class Dialect(metaclass=_Dialect): INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} - def __eq__(self, other: t.Any) -> bool: - return type(self) == other + # Delimiters for quotes, identifiers and the corresponding escape characters + QUOTE_START = "'" + QUOTE_END = "'" + IDENTIFIER_START = '"' + IDENTIFIER_END = '"' - def __hash__(self) -> int: - return hash(type(self)) + # Delimiters for bit, hex and byte literals + BIT_START: t.Optional[str] = None + BIT_END: t.Optional[str] = None + HEX_START: t.Optional[str] = None + HEX_END: t.Optional[str] = None + BYTE_START: t.Optional[str] = None + BYTE_END: t.Optional[str] = None @classmethod - def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: + def get_or_raise(cls, dialect: DialectType) -> Dialect: + """ + Look up a dialect in the global dialect registry and return it if it exists. + + Args: + dialect: The target dialect. If this is a string, it can be optionally followed by + additional key-value pairs that are separated by commas and are used to specify + dialect settings, such as whether the dialect's identifiers are case-sensitive. + + Example: + >>> dialect = dialect_class = get_or_raise("duckdb") + >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") + + Returns: + The corresponding Dialect instance. + """ + if not dialect: - return cls + return cls() if isinstance(dialect, _Dialect): - return dialect + return dialect() if isinstance(dialect, Dialect): - return dialect.__class__ + return dialect + if isinstance(dialect, str): + try: + dialect_name, *kv_pairs = dialect.split(",") + kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)} + except ValueError: + raise ValueError( + f"Invalid dialect format: '{dialect}'. " + "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." + ) + + result = cls.get(dialect_name.strip()) + if not result: + raise ValueError(f"Unknown dialect '{dialect_name}'.") - result = cls.get(dialect) - if not result: - raise ValueError(f"Unknown dialect '{dialect}'") + return result(**kwargs) - return result + raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") @classmethod def format_time( @@ -247,36 +286,71 @@ class Dialect(metaclass=_Dialect): return expression - @classmethod - def normalize_identifier(cls, expression: E) -> E: + def __init__(self, **kwargs) -> None: + normalization_strategy = kwargs.get("normalization_strategy") + + if normalization_strategy is None: + self.normalization_strategy = self.NORMALIZATION_STRATEGY + else: + self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) + + def __eq__(self, other: t.Any) -> bool: + # Does not currently take dialect state into account + return type(self) == other + + def __hash__(self) -> int: + # Does not currently take dialect state into account + return hash(type(self)) + + def normalize_identifier(self, expression: E) -> E: """ - Normalizes an unquoted identifier to either lower or upper case, thus essentially - making it case-insensitive. If a dialect treats all identifiers as case-insensitive, - they will be normalized to lowercase regardless of being quoted or not. + Transforms an identifier in a way that resembles how it'd be resolved by this dialect. + + For example, an identifier like FoO would be resolved as foo in Postgres, because it + lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so + it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive, + and so any normalization would be prohibited in order to avoid "breaking" the identifier. + + There are also dialects like Spark, which are case-insensitive even when quotes are + present, and dialects like MySQL, whose resolution rules match those employed by the + underlying operating system, for example they may always be case-sensitive in Linux. + + Finally, the normalization behavior of some engines can even be controlled through flags, + like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. + + SQLGlot aims to understand and handle all of these different behaviors gracefully, so + that it can analyze queries in the optimizer and successfully capture their semantics. """ - if isinstance(expression, exp.Identifier) and ( - not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None + if ( + isinstance(expression, exp.Identifier) + and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE + and ( + not expression.quoted + or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE + ) ): expression.set( "this", expression.this.upper() - if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE + if self.normalization_strategy is NormalizationStrategy.UPPERCASE else expression.this.lower(), ) return expression - @classmethod - def case_sensitive(cls, text: str) -> bool: + def case_sensitive(self, text: str) -> bool: """Checks if text contains any case sensitive characters, based on the dialect's rules.""" - if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None: + if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: return False - unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper + unsafe = ( + str.islower + if self.normalization_strategy is NormalizationStrategy.UPPERCASE + else str.isupper + ) return any(unsafe(char) for char in text) - @classmethod - def can_identify(cls, text: str, identify: str | bool = "safe") -> bool: + def can_identify(self, text: str, identify: str | bool = "safe") -> bool: """Checks if text can be identified given an identify option. Args: @@ -292,17 +366,16 @@ class Dialect(metaclass=_Dialect): return True if identify == "safe": - return not cls.case_sensitive(text) + return not self.case_sensitive(text) return False - @classmethod - def quote_identifier(cls, expression: E, identify: bool = True) -> E: + def quote_identifier(self, expression: E, identify: bool = True) -> E: if isinstance(expression, exp.Identifier): name = expression.this expression.set( "quoted", - identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), + identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), ) return expression @@ -330,14 +403,14 @@ class Dialect(metaclass=_Dialect): @property def tokenizer(self) -> Tokenizer: if not hasattr(self, "_tokenizer"): - self._tokenizer = self.tokenizer_class() + self._tokenizer = self.tokenizer_class(dialect=self) return self._tokenizer def parser(self, **opts) -> Parser: - return self.parser_class(**opts) + return self.parser_class(dialect=self, **opts) def generator(self, **opts) -> Generator: - return self.generator_class(**opts) + return self.generator_class(dialect=self, **opts) DialectType = t.Union[str, Dialect, t.Type[Dialect], None] @@ -713,7 +786,7 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: return _ts_or_ds_to_date_sql -def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str: +def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) @@ -821,3 +894,28 @@ def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | ex return self.func(name, expression.this, expression.expression) return _arg_max_or_min_sql + + +def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: + this = expression.this.copy() + + return_type = expression.return_type + if return_type.is_type(exp.DataType.Type.DATE): + # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we + # can truncate timestamp strings, because some dialects can't cast them to DATE + this = exp.cast(this, exp.DataType.Type.TIMESTAMP) + + expression.this.replace(exp.cast(this, return_type)) + return expression + + +def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: + def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: + if cast and isinstance(expression, exp.TsOrDsAdd): + expression = ts_or_ds_add_cast(expression) + + return self.func( + name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this + ) + + return _delta_sql |