summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/dialect.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r--sqlglot/dialects/dialect.py52
1 files changed, 35 insertions, 17 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 33985a7..3af08bb 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
+import typing as t
from enum import Enum
from sqlglot import exp
from sqlglot.generator import Generator
-from sqlglot.helper import flatten, list_get
+from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
@@ -32,7 +35,7 @@ class Dialects(str, Enum):
class _Dialect(type):
- classes = {}
+ classes: t.Dict[str, Dialect] = {}
@classmethod
def __getitem__(cls, key):
@@ -56,19 +59,30 @@ class _Dialect(type):
klass.generator_class = getattr(klass, "Generator", Generator)
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
- klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
-
- if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
+ klass.identifier_start, klass.identifier_end = list(
+ klass.tokenizer_class._IDENTIFIERS.items()
+ )[0]
+
+ if (
+ klass.tokenizer_class._BIT_STRINGS
+ and exp.BitString not in klass.generator_class.TRANSFORMS
+ ):
bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.BitString
] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
- if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
+ if (
+ klass.tokenizer_class._HEX_STRINGS
+ and exp.HexString not in klass.generator_class.TRANSFORMS
+ ):
hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.HexString
] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
- if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
+ if (
+ klass.tokenizer_class._BYTE_STRINGS
+ and exp.ByteString not in klass.generator_class.TRANSFORMS
+ ):
be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.ByteString
@@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect):
index_offset = 0
unnest_column_only = False
alias_post_tablesample = False
- normalize_functions = "upper"
+ normalize_functions: t.Optional[str] = "upper"
null_ordering = "nulls_are_small"
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
time_format = "'%Y-%m-%d %H:%M:%S'"
- time_mapping = {}
+ time_mapping: t.Dict[str, str] = {}
# autofilled
quote_start = None
@@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect):
"quote_end": self.quote_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
- "escape": self.tokenizer_class.ESCAPE,
+ "escape": self.tokenizer_class.ESCAPES[0],
"index_offset": self.index_offset,
"time_mapping": self.inverse_time_mapping,
"time_trie": self.inverse_time_trie,
@@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression):
def if_sql(self, expression):
- expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false"))
+ expressions = self.format_args(
+ expression.this, expression.args.get("true"), expression.args.get("false")
+ )
return f"IF({expressions})"
@@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None):
def _format_time(args):
return exp_class(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
format=Dialect[dialect].format_time(
- list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
+ seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
),
)
@@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression):
"expressions",
[e for e in schema.expressions if e not in partitions],
)
- prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
+ prop.replace(
+ exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
+ )
expression.set("this", schema)
return self.create_sql(expression)
@@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression):
def parse_date_delta(exp_class, unit_mapping=None):
def inner_func(args):
unit_based = len(args) == 3
- this = list_get(args, 2) if unit_based else list_get(args, 0)
- expression = list_get(args, 1) if unit_based else list_get(args, 1)
- unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY")
+ this = seq_get(args, 2) if unit_based else seq_get(args, 0)
+ expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
+ unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
return exp_class(this=this, expression=expression, unit=unit)