summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/dialect.py
diff options
context:
space:
mode:
authorDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
committerDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
commit8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch)
tree6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dialects/dialect.py
parentReleasing debian version 19.0.1-1. (diff)
downloadsqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz
sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r--sqlglot/dialects/dialect.py220
1 file changed, 159 insertions, 61 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 21e7889..c7cea64 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -1,14 +1,14 @@
from __future__ import annotations
import typing as t
-from enum import Enum
+from enum import Enum, auto
from functools import reduce
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
-from sqlglot.helper import flatten, seq_get
+from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
@@ -16,6 +16,9 @@ from sqlglot.trie import new_trie
B = t.TypeVar("B", bound=exp.Binary)
+DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
+DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
+
class Dialects(str, Enum):
DIALECT = ""
@@ -43,6 +46,15 @@ class Dialects(str, Enum):
Doris = "doris"
class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    # Unquoted identifiers are lowercased
    LOWERCASE = auto()

    # Unquoted identifiers are uppercased
    UPPERCASE = auto()

    # Always case-sensitive, regardless of quotes
    CASE_SENSITIVE = auto()

    # Always case-insensitive, regardless of quotes
    CASE_INSENSITIVE = auto()
class _Dialect(type):
classes: t.Dict[str, t.Type[Dialect]] = {}
@@ -106,26 +118,8 @@ class _Dialect(type):
klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
- dialect_properties = {
- **{
- k: v
- for k, v in vars(klass).items()
- if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
- },
- "TOKENIZER_CLASS": klass.tokenizer_class,
- }
-
if enum not in ("", "bigquery"):
- dialect_properties["SELECT_KINDS"] = ()
-
- # Pass required dialect properties to the tokenizer, parser and generator classes
- for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
- for name, value in dialect_properties.items():
- if hasattr(subclass, name):
- setattr(subclass, name, value)
-
- if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT:
- klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
+ klass.generator_class.SELECT_KINDS = ()
if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
@@ -133,8 +127,6 @@ class _Dialect(type):
TokenType.SEMI,
}
- klass.generator_class.can_identify = klass.can_identify
-
return klass
@@ -148,9 +140,8 @@ class Dialect(metaclass=_Dialect):
# Determines whether or not the table alias comes after tablesample
ALIAS_POST_TABLESAMPLE = False
- # Determines whether or not unquoted identifiers are resolved as uppercase
- # When set to None, it means that the dialect treats all identifiers as case-insensitive
- RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False
+ # Specifies the strategy according to which identifiers should be normalized.
+ NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
# Determines whether or not an unquoted identifier can start with a digit
IDENTIFIERS_CAN_START_WITH_DIGIT = False
@@ -177,6 +168,18 @@ class Dialect(metaclass=_Dialect):
# Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
NULL_ORDERING = "nulls_are_small"
+ # Whether the behavior of a / b depends on the types of a and b.
+ # False means a / b is always float division.
+ # True means a / b is integer division if both a and b are integers.
+ TYPED_DIVISION = False
+
+ # False means 1 / 0 throws an error.
+ # True means 1 / 0 returns null.
+ SAFE_DIVISION = False
+
+ # A NULL arg in CONCAT yields NULL by default, but in some dialects it yields an empty string
+ CONCAT_COALESCE = False
+
DATE_FORMAT = "'%Y-%m-%d'"
DATEINT_FORMAT = "'%Y%m%d'"
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
@@ -197,7 +200,8 @@ class Dialect(metaclass=_Dialect):
# Such columns may be excluded from SELECT * queries, for example
PSEUDOCOLUMNS: t.Set[str] = set()
- # Autofilled
+ # --- Autofilled ---
+
tokenizer_class = Tokenizer
parser_class = Parser
generator_class = Generator
@@ -211,26 +215,61 @@ class Dialect(metaclass=_Dialect):
INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
- def __eq__(self, other: t.Any) -> bool:
- return type(self) == other
+ # Delimiters for quotes, identifiers and the corresponding escape characters
+ QUOTE_START = "'"
+ QUOTE_END = "'"
+ IDENTIFIER_START = '"'
+ IDENTIFIER_END = '"'
- def __hash__(self) -> int:
- return hash(type(self))
+ # Delimiters for bit, hex and byte literals
+ BIT_START: t.Optional[str] = None
+ BIT_END: t.Optional[str] = None
+ HEX_START: t.Optional[str] = None
+ HEX_END: t.Optional[str] = None
+ BYTE_START: t.Optional[str] = None
+ BYTE_END: t.Optional[str] = None
@classmethod
- def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]:
+ def get_or_raise(cls, dialect: DialectType) -> Dialect:
+ """
+ Look up a dialect in the global dialect registry and return it if it exists.
+
+ Args:
+ dialect: The target dialect. If this is a string, it can be optionally followed by
+ additional key-value pairs that are separated by commas and are used to specify
+ dialect settings, such as whether the dialect's identifiers are case-sensitive.
+
+ Example:
+ >>> dialect = dialect_class = get_or_raise("duckdb")
+ >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
+
+ Returns:
+ The corresponding Dialect instance.
+ """
+
if not dialect:
- return cls
+ return cls()
if isinstance(dialect, _Dialect):
- return dialect
+ return dialect()
if isinstance(dialect, Dialect):
- return dialect.__class__
+ return dialect
+ if isinstance(dialect, str):
+ try:
+ dialect_name, *kv_pairs = dialect.split(",")
+ kwargs = {k.strip(): v.strip() for k, v in (kv.split("=") for kv in kv_pairs)}
+ except ValueError:
+ raise ValueError(
+ f"Invalid dialect format: '{dialect}'. "
+ "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
+ )
+
+ result = cls.get(dialect_name.strip())
+ if not result:
+ raise ValueError(f"Unknown dialect '{dialect_name}'.")
- result = cls.get(dialect)
- if not result:
- raise ValueError(f"Unknown dialect '{dialect}'")
+ return result(**kwargs)
- return result
+ raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
@classmethod
def format_time(
@@ -247,36 +286,71 @@ class Dialect(metaclass=_Dialect):
return expression
- @classmethod
- def normalize_identifier(cls, expression: E) -> E:
+ def __init__(self, **kwargs) -> None:
+ normalization_strategy = kwargs.get("normalization_strategy")
+
+ if normalization_strategy is None:
+ self.normalization_strategy = self.NORMALIZATION_STRATEGY
+ else:
+ self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())
+
+ def __eq__(self, other: t.Any) -> bool:
+ # Does not currently take dialect state into account
+ return type(self) == other
+
+ def __hash__(self) -> int:
+ # Does not currently take dialect state into account
+ return hash(type(self))
+
+ def normalize_identifier(self, expression: E) -> E:
"""
- Normalizes an unquoted identifier to either lower or upper case, thus essentially
- making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
- they will be normalized to lowercase regardless of being quoted or not.
+ Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
+
+ For example, an identifier like FoO would be resolved as foo in Postgres, because it
+ lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
+ it would resolve it as FOO. If it was quoted, it'd need to be treated as case-sensitive,
+ and so any normalization would be prohibited in order to avoid "breaking" the identifier.
+
+ There are also dialects like Spark, which are case-insensitive even when quotes are
+ present, and dialects like MySQL, whose resolution rules match those employed by the
+ underlying operating system, for example they may always be case-sensitive in Linux.
+
+ Finally, the normalization behavior of some engines can even be controlled through flags,
+ like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
+
+ SQLGlot aims to understand and handle all of these different behaviors gracefully, so
+ that it can analyze queries in the optimizer and successfully capture their semantics.
"""
- if isinstance(expression, exp.Identifier) and (
- not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
+ if (
+ isinstance(expression, exp.Identifier)
+ and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
+ and (
+ not expression.quoted
+ or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
+ )
):
expression.set(
"this",
expression.this.upper()
- if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
+ if self.normalization_strategy is NormalizationStrategy.UPPERCASE
else expression.this.lower(),
)
return expression
- @classmethod
- def case_sensitive(cls, text: str) -> bool:
+ def case_sensitive(self, text: str) -> bool:
"""Checks if text contains any case sensitive characters, based on the dialect's rules."""
- if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
+ if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
return False
- unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
+ unsafe = (
+ str.islower
+ if self.normalization_strategy is NormalizationStrategy.UPPERCASE
+ else str.isupper
+ )
return any(unsafe(char) for char in text)
- @classmethod
- def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
+ def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
"""Checks if text can be identified given an identify option.
Args:
@@ -292,17 +366,16 @@ class Dialect(metaclass=_Dialect):
return True
if identify == "safe":
- return not cls.case_sensitive(text)
+ return not self.case_sensitive(text)
return False
- @classmethod
- def quote_identifier(cls, expression: E, identify: bool = True) -> E:
+ def quote_identifier(self, expression: E, identify: bool = True) -> E:
if isinstance(expression, exp.Identifier):
name = expression.this
expression.set(
"quoted",
- identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
+ identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
)
return expression
@@ -330,14 +403,14 @@ class Dialect(metaclass=_Dialect):
@property
def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
- self._tokenizer = self.tokenizer_class()
+ self._tokenizer = self.tokenizer_class(dialect=self)
return self._tokenizer
def parser(self, **opts) -> Parser:
- return self.parser_class(**opts)
+ return self.parser_class(dialect=self, **opts)
def generator(self, **opts) -> Generator:
- return self.generator_class(**opts)
+ return self.generator_class(dialect=self, **opts)
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
@@ -713,7 +786,7 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
return _ts_or_ds_to_date_sql
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    """Render a CONCAT node as a left-associated chain of `||` (DPipe) operators."""
    operands = expression.expressions
    dpipe_chain = reduce(lambda left, right: exp.DPipe(this=left, expression=right), operands)
    return self.sql(dpipe_chain)
@@ -821,3 +894,28 @@ def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | ex
return self.func(name, expression.this, expression.expression)
return _arg_max_or_min_sql
+
+
def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    """Wrap the TS_OR_DS_ADD operand in an explicit cast to the node's return type."""
    operand = expression.this.copy()
    target_type = expression.return_type

    if target_type.is_type(exp.DataType.Type.DATE):
        # Go through TIMESTAMP first so timestamp strings can be truncated,
        # because some dialects can't cast them to DATE directly.
        operand = exp.cast(operand, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(operand, target_type))
    return expression
+
+
def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    """Build a generator transform rendering date add/diff nodes as `name(unit, expr, this)`.

    When `cast` is true, TS_OR_DS_ADD nodes are first wrapped in an explicit cast
    to their return type via `ts_or_ds_add_cast`.
    """

    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        unit = expression.text("unit") or "day"
        return self.func(name, exp.var(unit), expression.expression, expression.this)

    return _delta_sql