From 67578a7602a5be7eb51f324086c8d49bcf8b7498 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 16 Jun 2023 11:41:18 +0200 Subject: Merging upstream version 16.2.1. Signed-off-by: Daniel Baumann --- sqlglot/dialects/dialect.py | 201 ++++++++++++++++++++++---------------------- 1 file changed, 101 insertions(+), 100 deletions(-) (limited to 'sqlglot/dialects/dialect.py') diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 4958bc6..f5d523b 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -25,6 +25,8 @@ class Dialects(str, Enum): BIGQUERY = "bigquery" CLICKHOUSE = "clickhouse" + DATABRICKS = "databricks" + DRILL = "drill" DUCKDB = "duckdb" HIVE = "hive" MYSQL = "mysql" @@ -38,11 +40,9 @@ class Dialects(str, Enum): SQLITE = "sqlite" STARROCKS = "starrocks" TABLEAU = "tableau" + TERADATA = "teradata" TRINO = "trino" TSQL = "tsql" - DATABRICKS = "databricks" - DRILL = "drill" - TERADATA = "teradata" class _Dialect(type): @@ -76,16 +76,19 @@ class _Dialect(type): enum = Dialects.__members__.get(clsname.upper()) cls.classes[enum.value if enum is not None else clsname.lower()] = klass - klass.time_trie = new_trie(klass.time_mapping) - klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()} - klass.inverse_time_trie = new_trie(klass.inverse_time_mapping) + klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) + klass.FORMAT_TRIE = ( + new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE + ) + klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} + klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) klass.parser_class = getattr(klass, "Parser", Parser) klass.generator_class = getattr(klass, "Generator", Generator) - klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0] - klass.identifier_start, klass.identifier_end = list( + klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0] + klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( klass.tokenizer_class._IDENTIFIERS.items() )[0] @@ -99,43 +102,80 @@ class _Dialect(type): (None, None), ) - klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING) - klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING) - klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING) - klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING) + klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) + klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) + klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) + klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING) - klass.tokenizer_class.identifiers_can_start_with_digit = ( - klass.identifiers_can_start_with_digit - ) + dialect_properties = { + **{ + k: v + for k, v in vars(klass).items() + if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__") + }, + "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0], + "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0], + } + + # Pass required dialect properties to the tokenizer, parser and generator classes + for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class): + for name, value in dialect_properties.items(): + if hasattr(subclass, name): + setattr(subclass, name, value) + + if not klass.STRICT_STRING_CONCAT: + klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe return klass class Dialect(metaclass=_Dialect): - index_offset = 0 - unnest_column_only = False - alias_post_tablesample = False - identifiers_can_start_with_digit = False - normalize_functions: t.Optional[str] = "upper" - null_ordering = "nulls_are_small" - - date_format = "'%Y-%m-%d'" - dateint_format = "'%Y%m%d'" - time_format = "'%Y-%m-%d %H:%M:%S'" - time_mapping: t.Dict[str, str] = {} - - # autofilled - quote_start = None - quote_end = None - identifier_start = None - identifier_end = None - - time_trie = None - inverse_time_mapping = None - inverse_time_trie = None - tokenizer_class = None - parser_class = None - generator_class = None + # Determines the base index offset for arrays + INDEX_OFFSET = 0 + + # If true unnest table aliases are considered only as column aliases + UNNEST_COLUMN_ONLY = False + + # Determines whether or not the table alias comes after tablesample + ALIAS_POST_TABLESAMPLE = False + + # Determines whether or not an unquoted identifier can start with a digit + IDENTIFIERS_CAN_START_WITH_DIGIT = False + + # Determines whether or not CONCAT's arguments must be strings + STRICT_STRING_CONCAT = False + + # Determines how function names are going to be normalized + NORMALIZE_FUNCTIONS: bool | str = "upper" + + # Indicates the default null ordering method to use if not explicitly set + # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last" + NULL_ORDERING = "nulls_are_small" + + DATE_FORMAT = "'%Y-%m-%d'" + DATEINT_FORMAT = "'%Y%m%d'" + TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" + + # Custom time mappings in which the key represents dialect time format + # and the value represents a python time format + TIME_MAPPING: t.Dict[str, str] = {} + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time + # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE + # special syntax cast(x as date format 'yyyy') defaults to time_mapping + FORMAT_MAPPING: t.Dict[str, str] = {} + + # Autofilled + tokenizer_class = Tokenizer + parser_class = Parser + generator_class = Generator + + # A trie of the time_mapping keys + TIME_TRIE: t.Dict = {} + FORMAT_TRIE: t.Dict = {} + + INVERSE_TIME_MAPPING: t.Dict[str, str] = {} + INVERSE_TIME_TRIE: t.Dict = {} def __eq__(self, other: t.Any) -> bool: return type(self) == other @@ -164,20 +204,13 @@ class Dialect(metaclass=_Dialect): ) -> t.Optional[exp.Expression]: if isinstance(expression, str): return exp.Literal.string( - format_time( - expression[1:-1], # the time formats are quoted - cls.time_mapping, - cls.time_trie, - ) + # the time formats are quoted + format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) ) + if expression and expression.is_string: - return exp.Literal.string( - format_time( - expression.this, - cls.time_mapping, - cls.time_trie, - ) - ) + return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) + return expression def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: @@ -200,48 +233,14 @@ class Dialect(metaclass=_Dialect): @property def tokenizer(self) -> Tokenizer: if not hasattr(self, "_tokenizer"): - self._tokenizer = self.tokenizer_class() # type: ignore + self._tokenizer = self.tokenizer_class() return self._tokenizer def parser(self, **opts) -> Parser: - return self.parser_class( # type: ignore - **{ - "index_offset": self.index_offset, - "unnest_column_only": self.unnest_column_only, - "alias_post_tablesample": self.alias_post_tablesample, - "null_ordering": self.null_ordering, - **opts, - }, - ) + return self.parser_class(**opts) def generator(self, **opts) -> Generator: - return self.generator_class( # type: ignore - **{ - "quote_start": self.quote_start, - "quote_end": self.quote_end, - "bit_start": self.bit_start, - "bit_end": self.bit_end, - "hex_start": self.hex_start, - "hex_end": self.hex_end, - "byte_start": self.byte_start, - "byte_end": self.byte_end, - "raw_start": self.raw_start, - "raw_end": self.raw_end, - "identifier_start": self.identifier_start, - "identifier_end": self.identifier_end, - "string_escape": self.tokenizer_class.STRING_ESCAPES[0], - "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0], - "index_offset": self.index_offset, - "time_mapping": self.inverse_time_mapping, - "time_trie": self.inverse_time_trie, - "unnest_column_only": self.unnest_column_only, - "alias_post_tablesample": self.alias_post_tablesample, - "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit, - "normalize_functions": self.normalize_functions, - "null_ordering": self.null_ordering, - **opts, - } - ) + return self.generator_class(**opts) DialectType = t.Union[str, Dialect, t.Type[Dialect], None] @@ -279,10 +278,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str: def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: return self.like_sql( - exp.Like( - this=exp.Lower(this=expression.this), - expression=expression.args["expression"], - ) + exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression) ) @@ -359,6 +355,7 @@ def var_map_sql( for key, value in zip(keys.expressions, values.expressions): args.append(self.sql(key)) args.append(self.sql(value)) + return self.func(map_func_name, *args) @@ -381,7 +378,7 @@ def format_time_lambda( this=seq_get(args, 0), format=Dialect[dialect].format_time( seq_get(args, 1) - or (Dialect[dialect].time_format if default is True else default or None) + or (Dialect[dialect].TIME_FORMAT if default is True else default or None) ), ) @@ -437,9 +434,7 @@ def parse_date_delta_with_interval( expression = exp.Literal.number(expression.this) return expression_class( - this=args[0], - expression=expression, - unit=exp.Literal.string(interval.text("unit")), + this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) ) return func @@ -462,9 +457,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: def locate_to_strposition(args: t.List) -> exp.Expression: return exp.StrPosition( - this=seq_get(args, 1), - substr=seq_get(args, 0), - position=seq_get(args, 2), + this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) ) @@ -546,13 +539,21 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: _dialect = Dialect.get_or_raise(dialect) time_format = self.format_time(expression) - if time_format and time_format not in (_dialect.time_format, _dialect.date_format): + if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): return f"CAST({str_to_time_sql(self, expression)} AS DATE)" return f"CAST({self.sql(expression, 'this')} AS DATE)" return _ts_or_ds_to_date_sql +def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str: + this, *rest_args = expression.expressions + for arg in rest_args: + this = exp.DPipe(this=this, expression=arg) + + return self.sql(this) + + # Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: names = [] -- cgit v1.2.3