From d1f00706bff58b863b0a1c5bf4adf39d36049d4c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 11 Nov 2022 09:54:35 +0100 Subject: Merging upstream version 10.0.1. Signed-off-by: Daniel Baumann --- sqlglot/dialects/dialect.py | 52 ++++++++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 17 deletions(-) (limited to 'sqlglot/dialects/dialect.py') diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 33985a7..3af08bb 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -1,8 +1,11 @@ +from __future__ import annotations + +import typing as t from enum import Enum from sqlglot import exp from sqlglot.generator import Generator -from sqlglot.helper import flatten, list_get +from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser from sqlglot.time import format_time from sqlglot.tokens import Tokenizer @@ -32,7 +35,7 @@ class Dialects(str, Enum): class _Dialect(type): - classes = {} + classes: t.Dict[str, Dialect] = {} @classmethod def __getitem__(cls, key): @@ -56,19 +59,30 @@ class _Dialect(type): klass.generator_class = getattr(klass, "Generator", Generator) klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0] - klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0] - - if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS: + klass.identifier_start, klass.identifier_end = list( + klass.tokenizer_class._IDENTIFIERS.items() + )[0] + + if ( + klass.tokenizer_class._BIT_STRINGS + and exp.BitString not in klass.generator_class.TRANSFORMS + ): bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0] klass.generator_class.TRANSFORMS[ exp.BitString ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}" - if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS: + if ( + klass.tokenizer_class._HEX_STRINGS + and exp.HexString not in klass.generator_class.TRANSFORMS + ): hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0] klass.generator_class.TRANSFORMS[ exp.HexString ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}" - if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS: + if ( + klass.tokenizer_class._BYTE_STRINGS + and exp.ByteString not in klass.generator_class.TRANSFORMS + ): be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0] klass.generator_class.TRANSFORMS[ exp.ByteString @@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect): index_offset = 0 unnest_column_only = False alias_post_tablesample = False - normalize_functions = "upper" + normalize_functions: t.Optional[str] = "upper" null_ordering = "nulls_are_small" date_format = "'%Y-%m-%d'" dateint_format = "'%Y%m%d'" time_format = "'%Y-%m-%d %H:%M:%S'" - time_mapping = {} + time_mapping: t.Dict[str, str] = {} # autofilled quote_start = None @@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect): "quote_end": self.quote_end, "identifier_start": self.identifier_start, "identifier_end": self.identifier_end, - "escape": self.tokenizer_class.ESCAPE, + "escape": self.tokenizer_class.ESCAPES[0], "index_offset": self.index_offset, "time_mapping": self.inverse_time_mapping, "time_trie": self.inverse_time_trie, @@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression): def if_sql(self, expression): - expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false")) + expressions = self.format_args( + expression.this, expression.args.get("true"), expression.args.get("false") + ) return f"IF({expressions})" @@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None): def _format_time(args): return exp_class( - this=list_get(args, 0), + this=seq_get(args, 0), format=Dialect[dialect].format_time( - list_get(args, 1) or (Dialect[dialect].time_format if default is True else default) + seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default) ), ) @@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression): "expressions", [e for e in schema.expressions if e not in partitions], ) - prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))) + prop.replace( + exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)) + ) expression.set("this", schema) return self.create_sql(expression) @@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression): def parse_date_delta(exp_class, unit_mapping=None): def inner_func(args): unit_based = len(args) == 3 - this = list_get(args, 2) if unit_based else list_get(args, 0) - expression = list_get(args, 1) if unit_based else list_get(args, 1) - unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY") + this = seq_get(args, 2) if unit_based else seq_get(args, 0) + expression = seq_get(args, 1) if unit_based else seq_get(args, 1) + unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY") unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit return exp_class(this=this, expression=expression, unit=unit) -- cgit v1.2.3