summary refs log tree commit diff stats
path: root/sqlglot/dialects/dialect.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- sqlglot/dialects/dialect.py 201
1 files changed, 101 insertions, 100 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 4958bc6..f5d523b 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -25,6 +25,8 @@ class Dialects(str, Enum):
BIGQUERY = "bigquery"
CLICKHOUSE = "clickhouse"
+ DATABRICKS = "databricks"
+ DRILL = "drill"
DUCKDB = "duckdb"
HIVE = "hive"
MYSQL = "mysql"
@@ -38,11 +40,9 @@ class Dialects(str, Enum):
SQLITE = "sqlite"
STARROCKS = "starrocks"
TABLEAU = "tableau"
+ TERADATA = "teradata"
TRINO = "trino"
TSQL = "tsql"
- DATABRICKS = "databricks"
- DRILL = "drill"
- TERADATA = "teradata"
class _Dialect(type):
@@ -76,16 +76,19 @@ class _Dialect(type):
enum = Dialects.__members__.get(clsname.upper())
cls.classes[enum.value if enum is not None else clsname.lower()] = klass
- klass.time_trie = new_trie(klass.time_mapping)
- klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
- klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)
+ klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
+ klass.FORMAT_TRIE = (
+ new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
+ )
+ klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
+ klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
klass.parser_class = getattr(klass, "Parser", Parser)
klass.generator_class = getattr(klass, "Generator", Generator)
- klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
- klass.identifier_start, klass.identifier_end = list(
+ klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
+ klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
@@ -99,43 +102,80 @@ class _Dialect(type):
(None, None),
)
- klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
- klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
- klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
- klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
+ klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
+ klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
+ klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
+ klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING)
- klass.tokenizer_class.identifiers_can_start_with_digit = (
- klass.identifiers_can_start_with_digit
- )
+ dialect_properties = {
+ **{
+ k: v
+ for k, v in vars(klass).items()
+ if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
+ },
+ "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
+ "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
+ }
+
+ # Pass required dialect properties to the tokenizer, parser and generator classes
+ for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
+ for name, value in dialect_properties.items():
+ if hasattr(subclass, name):
+ setattr(subclass, name, value)
+
+ if not klass.STRICT_STRING_CONCAT:
+ klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
return klass
class Dialect(metaclass=_Dialect):
- index_offset = 0
- unnest_column_only = False
- alias_post_tablesample = False
- identifiers_can_start_with_digit = False
- normalize_functions: t.Optional[str] = "upper"
- null_ordering = "nulls_are_small"
-
- date_format = "'%Y-%m-%d'"
- dateint_format = "'%Y%m%d'"
- time_format = "'%Y-%m-%d %H:%M:%S'"
- time_mapping: t.Dict[str, str] = {}
-
- # autofilled
- quote_start = None
- quote_end = None
- identifier_start = None
- identifier_end = None
-
- time_trie = None
- inverse_time_mapping = None
- inverse_time_trie = None
- tokenizer_class = None
- parser_class = None
- generator_class = None
+ # Determines the base index offset for arrays
+ INDEX_OFFSET = 0
+
+ # If true unnest table aliases are considered only as column aliases
+ UNNEST_COLUMN_ONLY = False
+
+ # Determines whether or not the table alias comes after tablesample
+ ALIAS_POST_TABLESAMPLE = False
+
+ # Determines whether or not an unquoted identifier can start with a digit
+ IDENTIFIERS_CAN_START_WITH_DIGIT = False
+
+ # Determines whether or not CONCAT's arguments must be strings
+ STRICT_STRING_CONCAT = False
+
+ # Determines how function names are going to be normalized
+ NORMALIZE_FUNCTIONS: bool | str = "upper"
+
+ # Indicates the default null ordering method to use if not explicitly set
+ # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
+ NULL_ORDERING = "nulls_are_small"
+
+ DATE_FORMAT = "'%Y-%m-%d'"
+ DATEINT_FORMAT = "'%Y%m%d'"
+ TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
+
+ # Custom time mappings in which the key represents dialect time format
+ # and the value represents a python time format
+ TIME_MAPPING: t.Dict[str, str] = {}
+
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
+ # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
+ # special syntax cast(x as date format 'yyyy') defaults to time_mapping
+ FORMAT_MAPPING: t.Dict[str, str] = {}
+
+ # Autofilled
+ tokenizer_class = Tokenizer
+ parser_class = Parser
+ generator_class = Generator
+
+ # A trie of the time_mapping keys
+ TIME_TRIE: t.Dict = {}
+ FORMAT_TRIE: t.Dict = {}
+
+ INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
+ INVERSE_TIME_TRIE: t.Dict = {}
def __eq__(self, other: t.Any) -> bool:
return type(self) == other
@@ -164,20 +204,13 @@ class Dialect(metaclass=_Dialect):
) -> t.Optional[exp.Expression]:
if isinstance(expression, str):
return exp.Literal.string(
- format_time(
- expression[1:-1], # the time formats are quoted
- cls.time_mapping,
- cls.time_trie,
- )
+ # the time formats are quoted
+ format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
)
+
if expression and expression.is_string:
- return exp.Literal.string(
- format_time(
- expression.this,
- cls.time_mapping,
- cls.time_trie,
- )
- )
+ return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
+
return expression
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
@@ -200,48 +233,14 @@ class Dialect(metaclass=_Dialect):
@property
def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
- self._tokenizer = self.tokenizer_class() # type: ignore
+ self._tokenizer = self.tokenizer_class()
return self._tokenizer
def parser(self, **opts) -> Parser:
- return self.parser_class( # type: ignore
- **{
- "index_offset": self.index_offset,
- "unnest_column_only": self.unnest_column_only,
- "alias_post_tablesample": self.alias_post_tablesample,
- "null_ordering": self.null_ordering,
- **opts,
- },
- )
+ return self.parser_class(**opts)
def generator(self, **opts) -> Generator:
- return self.generator_class( # type: ignore
- **{
- "quote_start": self.quote_start,
- "quote_end": self.quote_end,
- "bit_start": self.bit_start,
- "bit_end": self.bit_end,
- "hex_start": self.hex_start,
- "hex_end": self.hex_end,
- "byte_start": self.byte_start,
- "byte_end": self.byte_end,
- "raw_start": self.raw_start,
- "raw_end": self.raw_end,
- "identifier_start": self.identifier_start,
- "identifier_end": self.identifier_end,
- "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
- "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
- "index_offset": self.index_offset,
- "time_mapping": self.inverse_time_mapping,
- "time_trie": self.inverse_time_trie,
- "unnest_column_only": self.unnest_column_only,
- "alias_post_tablesample": self.alias_post_tablesample,
- "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
- "normalize_functions": self.normalize_functions,
- "null_ordering": self.null_ordering,
- **opts,
- }
- )
+ return self.generator_class(**opts)
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
@@ -279,10 +278,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
- exp.Like(
- this=exp.Lower(this=expression.this),
- expression=expression.args["expression"],
- )
+ exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
)
@@ -359,6 +355,7 @@ def var_map_sql(
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
+
return self.func(map_func_name, *args)
@@ -381,7 +378,7 @@ def format_time_lambda(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
seq_get(args, 1)
- or (Dialect[dialect].time_format if default is True else default or None)
+ or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
),
)
@@ -437,9 +434,7 @@ def parse_date_delta_with_interval(
expression = exp.Literal.number(expression.this)
return expression_class(
- this=args[0],
- expression=expression,
- unit=exp.Literal.string(interval.text("unit")),
+ this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
)
return func
@@ -462,9 +457,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
def locate_to_strposition(args: t.List) -> exp.Expression:
return exp.StrPosition(
- this=seq_get(args, 1),
- substr=seq_get(args, 0),
- position=seq_get(args, 2),
+ this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
)
@@ -546,13 +539,21 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
_dialect = Dialect.get_or_raise(dialect)
time_format = self.format_time(expression)
- if time_format and time_format not in (_dialect.time_format, _dialect.date_format):
+ if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
return f"CAST({self.sql(expression, 'this')} AS DATE)"
return _ts_or_ds_to_date_sql
+def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
+ this, *rest_args = expression.expressions
+ for arg in rest_args:
+ this = exp.DPipe(this=this, expression=arg)
+
+ return self.sql(this)
+
+
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
names = []