summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-11-11 08:54:35 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-11-11 08:54:35 +0000
commitd1f00706bff58b863b0a1c5bf4adf39d36049d4c (patch)
tree3a8ecc5d1509d655d5df6b1455bc1e309da2c02c /sqlglot/dialects
parentReleasing debian version 9.0.6-1. (diff)
downloadsqlglot-d1f00706bff58b863b0a1c5bf4adf39d36049d4c.tar.xz
sqlglot-d1f00706bff58b863b0a1c5bf4adf39d36049d4c.zip
Merging upstream version 10.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--sqlglot/dialects/bigquery.py57
-rw-r--r--sqlglot/dialects/clickhouse.py24
-rw-r--r--sqlglot/dialects/databricks.py4
-rw-r--r--sqlglot/dialects/dialect.py52
-rw-r--r--sqlglot/dialects/duckdb.py33
-rw-r--r--sqlglot/dialects/hive.py57
-rw-r--r--sqlglot/dialects/mysql.py329
-rw-r--r--sqlglot/dialects/oracle.py20
-rw-r--r--sqlglot/dialects/postgres.py25
-rw-r--r--sqlglot/dialects/presto.py41
-rw-r--r--sqlglot/dialects/redshift.py13
-rw-r--r--sqlglot/dialects/snowflake.py46
-rw-r--r--sqlglot/dialects/spark.py37
-rw-r--r--sqlglot/dialects/sqlite.py24
-rw-r--r--sqlglot/dialects/starrocks.py7
-rw-r--r--sqlglot/dialects/tableau.py14
-rw-r--r--sqlglot/dialects/trino.py4
-rw-r--r--sqlglot/dialects/tsql.py54
18 files changed, 596 insertions, 245 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 62d042e..5bbff9d 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -1,21 +1,21 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
inline_array_sql,
no_ilike_sql,
rename_func,
)
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
def _date_add(expression_class):
def func(args):
- interval = list_get(args, 1)
+ interval = seq_get(args, 1)
return expression_class(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
expression=interval.this,
unit=interval.args.get("unit"),
)
@@ -23,6 +23,13 @@ def _date_add(expression_class):
return func
+def _date_trunc(args):
+ unit = seq_get(args, 1)
+ if isinstance(unit, exp.Column):
+ unit = exp.Var(this=unit.name)
+ return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
+
+
def _date_add_sql(data_type, kind):
def func(self, expression):
this = self.sql(expression, "this")
@@ -40,7 +47,8 @@ def _derived_table_values_to_unnest(self, expression):
structs = []
for row in rows:
aliases = [
- exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"])
+ exp.alias_(value, column_name)
+ for value, column_name in zip(row, expression.args["alias"].args["columns"])
]
structs.append(exp.Struct(expressions=aliases))
unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
@@ -89,18 +97,19 @@ class BigQuery(Dialect):
"%j": "%-j",
}
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
QUOTES = [
(prefix + quote, quote) if prefix else quote
for quote in ["'", '"', '"""', "'''"]
for prefix in ["", "r", "R"]
]
+ COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
- ESCAPE = "\\"
+ ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"GEOGRAPHY": TokenType.GEOGRAPHY,
@@ -111,35 +120,40 @@ class BigQuery(Dialect):
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
}
+ KEYWORDS.pop("DIV")
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
+ "DATE_TRUNC": _date_trunc,
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
+ "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
"DATETIME_SUB": _date_add(exp.DatetimeSub),
"TIME_SUB": _date_add(exp.TimeSub),
"TIMESTAMP_SUB": _date_add(exp.TimestampSub),
- "PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)),
+ "PARSE_TIMESTAMP": lambda args: exp.StrToTime(
+ this=seq_get(args, 1), format=seq_get(args, 0)
+ ),
}
NO_PAREN_FUNCTIONS = {
- **Parser.NO_PAREN_FUNCTIONS,
+ **parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
TokenType.CURRENT_TIME: exp.CurrentTime,
}
NESTED_TYPE_TOKENS = {
- *Parser.NESTED_TYPE_TOKENS,
+ *parser.Parser.NESTED_TYPE_TOKENS,
TokenType.TABLE,
}
- class Generator(Generator):
+ class Generator(generator.Generator):
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
@@ -148,6 +162,7 @@ class BigQuery(Dialect):
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.ILike: no_ilike_sql,
+ exp.IntDiv: rename_func("DIV"),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -157,11 +172,13 @@ class BigQuery(Dialect):
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
- exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
+ exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
+ if e.name == "IMMUTABLE"
+ else "NOT DETERMINISTIC",
}
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64",
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index f446e6d..332b4c1 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -1,8 +1,9 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser, parse_var_map
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.parser import parse_var_map
+from sqlglot.tokens import TokenType
def _lower_func(sql):
@@ -14,11 +15,12 @@ class ClickHouse(Dialect):
normalize_functions = None
null_ordering = "nulls_are_last"
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
+ COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"FINAL": TokenType.FINAL,
"DATETIME64": TokenType.DATETIME,
"INT8": TokenType.TINYINT,
@@ -30,9 +32,9 @@ class ClickHouse(Dialect):
"TUPLE": TokenType.STRUCT,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"MAP": parse_var_map,
}
@@ -44,11 +46,11 @@ class ClickHouse(Dialect):
return this
- class Generator(Generator):
+ class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.DATETIME: "DateTime64",
exp.DataType.Type.MAP: "Map",
@@ -63,7 +65,7 @@ class ClickHouse(Dialect):
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 9dc3c38..2498c62 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
@@ -15,7 +17,7 @@ class Databricks(Spark):
class Generator(Spark.Generator):
TRANSFORMS = {
- **Spark.Generator.TRANSFORMS,
+ **Spark.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
}
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 33985a7..3af08bb 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
+import typing as t
from enum import Enum
from sqlglot import exp
from sqlglot.generator import Generator
-from sqlglot.helper import flatten, list_get
+from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
@@ -32,7 +35,7 @@ class Dialects(str, Enum):
class _Dialect(type):
- classes = {}
+ classes: t.Dict[str, Dialect] = {}
@classmethod
def __getitem__(cls, key):
@@ -56,19 +59,30 @@ class _Dialect(type):
klass.generator_class = getattr(klass, "Generator", Generator)
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
- klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
-
- if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
+ klass.identifier_start, klass.identifier_end = list(
+ klass.tokenizer_class._IDENTIFIERS.items()
+ )[0]
+
+ if (
+ klass.tokenizer_class._BIT_STRINGS
+ and exp.BitString not in klass.generator_class.TRANSFORMS
+ ):
bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.BitString
] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
- if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
+ if (
+ klass.tokenizer_class._HEX_STRINGS
+ and exp.HexString not in klass.generator_class.TRANSFORMS
+ ):
hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.HexString
] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
- if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
+ if (
+ klass.tokenizer_class._BYTE_STRINGS
+ and exp.ByteString not in klass.generator_class.TRANSFORMS
+ ):
be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.ByteString
@@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect):
index_offset = 0
unnest_column_only = False
alias_post_tablesample = False
- normalize_functions = "upper"
+ normalize_functions: t.Optional[str] = "upper"
null_ordering = "nulls_are_small"
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
time_format = "'%Y-%m-%d %H:%M:%S'"
- time_mapping = {}
+ time_mapping: t.Dict[str, str] = {}
# autofilled
quote_start = None
@@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect):
"quote_end": self.quote_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
- "escape": self.tokenizer_class.ESCAPE,
+ "escape": self.tokenizer_class.ESCAPES[0],
"index_offset": self.index_offset,
"time_mapping": self.inverse_time_mapping,
"time_trie": self.inverse_time_trie,
@@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression):
def if_sql(self, expression):
- expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false"))
+ expressions = self.format_args(
+ expression.this, expression.args.get("true"), expression.args.get("false")
+ )
return f"IF({expressions})"
@@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None):
def _format_time(args):
return exp_class(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
format=Dialect[dialect].format_time(
- list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
+ seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
),
)
@@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression):
"expressions",
[e for e in schema.expressions if e not in partitions],
)
- prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
+ prop.replace(
+ exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
+ )
expression.set("this", schema)
return self.create_sql(expression)
@@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression):
def parse_date_delta(exp_class, unit_mapping=None):
def inner_func(args):
unit_based = len(args) == 3
- this = list_get(args, 2) if unit_based else list_get(args, 0)
- expression = list_get(args, 1) if unit_based else list_get(args, 1)
- unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY")
+ this = seq_get(args, 2) if unit_based else seq_get(args, 0)
+ expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
+ unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
return exp_class(this=this, expression=expression, unit=unit)
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f3ff6d3..781edff 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
@@ -12,10 +14,8 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
)
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
def _unix_to_time(self, expression):
@@ -61,11 +61,14 @@ def _sort_array_sql(self, expression):
def _sort_array_reverse(args):
- return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE)
+ return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
def _struct_pack_sql(self, expression):
- args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
+ args = [
+ self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
+ for e in expression.expressions
+ ]
return f"STRUCT_PACK({', '.join(args)})"
@@ -76,15 +79,15 @@ def _datatype_sql(self, expression):
class DuckDB(Dialect):
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
":=": TokenType.EQ,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -92,7 +95,7 @@ class DuckDB(Dialect):
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
expression=exp.Literal.number(1000),
)
),
@@ -112,11 +115,11 @@ class DuckDB(Dialect):
"UNNEST": exp.Explode.from_arg_list,
}
- class Generator(Generator):
+ class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: rename_func("LIST_VALUE"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -160,7 +163,7 @@ class DuckDB(Dialect):
}
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
}
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 03049ff..ed7357c 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -1,4 +1,6 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
@@ -13,10 +15,8 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
var_map_sql,
)
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser, parse_var_map
-from sqlglot.tokens import Tokenizer
+from sqlglot.helper import seq_get
+from sqlglot.parser import parse_var_map
# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {
@@ -34,7 +34,9 @@ def _add_date_sql(self, expression):
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
modified_increment = (
- int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression
+ int(expression.text("expression")) * multiplier
+ if expression.expression.is_number
+ else expression.expression
)
modified_increment = exp.Literal.number(modified_increment)
return f"{func}({self.format_args(expression.this, modified_increment.this)})"
@@ -165,10 +167,10 @@ class Hive(Dialect):
dateint_format = "'yyyyMMdd'"
time_format = "'yyyy-MM-dd HH:mm:ss'"
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
- ESCAPE = "\\"
+ ESCAPES = ["\\"]
ENCODE = "utf-8"
NUMERIC_LITERALS = {
@@ -180,40 +182,44 @@ class Hive(Dialect):
"BD": "DECIMAL",
}
- class Parser(Parser):
+ class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
- this=list_get(args, 0),
- expression=list_get(args, 1),
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
unit=exp.Literal.string("DAY"),
),
"DATEDIFF": lambda args: exp.DateDiff(
- this=exp.TsOrDsToDate(this=list_get(args, 0)),
- expression=exp.TsOrDsToDate(this=list_get(args, 1)),
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
),
"DATE_SUB": lambda args: exp.TsOrDsAdd(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
expression=exp.Mul(
- this=list_get(args, 1),
+ this=seq_get(args, 1),
expression=exp.Literal.number(-1),
),
unit=exp.Literal.string("DAY"),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
- "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))),
+ "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": lambda args: exp.StrPosition(
- this=list_get(args, 1),
- substr=list_get(args, 0),
- position=list_get(args, 2),
+ this=seq_get(args, 1),
+ substr=seq_get(args, 0),
+ position=seq_get(args, 2),
+ ),
+ "LOG": (
+ lambda args: exp.Log.from_arg_list(args)
+ if len(args) > 1
+ else exp.Ln.from_arg_list(args)
),
- "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
"MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
@@ -226,15 +232,16 @@ class Hive(Dialect):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
- class Generator(Generator):
+ class Generator(generator.Generator):
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
+ exp.DataType.Type.VARBINARY: "BINARY",
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
- **transforms.UNALIAS_GROUP,
+ **generator.Generator.TRANSFORMS,
+ **transforms.UNALIAS_GROUP, # type: ignore
exp.AnonymousProperty: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 524390f..e742640 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -1,4 +1,8 @@
-from sqlglot import exp
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
no_ilike_sql,
@@ -6,42 +10,47 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
no_trycast_sql,
)
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
+
+
+def _show_parser(*args, **kwargs):
+ def _parse(self):
+ return self._parse_show_mysql(*args, **kwargs)
+
+ return _parse
def _date_trunc_sql(self, expression):
- unit = expression.text("unit").lower()
+ unit = expression.name.lower()
- this = self.sql(expression.this)
+ expr = self.sql(expression.expression)
if unit == "day":
- return f"DATE({this})"
+ return f"DATE({expr})"
if unit == "week":
- concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')"
+ concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
date_format = "%Y %u %w"
elif unit == "month":
- concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')"
+ concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
date_format = "%Y %c %e"
elif unit == "quarter":
- concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')"
+ concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
date_format = "%Y %c %e"
elif unit == "year":
- concat = f"CONCAT(YEAR({this}), ' 1 1')"
+ concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e"
else:
self.unsupported("Unexpected interval unit: {unit}")
- return f"DATE({this})"
+ return f"DATE({expr})"
return f"STR_TO_DATE({concat}, '{date_format}')"
def _str_to_date(args):
- date_format = MySQL.format_time(list_get(args, 1))
- return exp.StrToDate(this=list_get(args, 0), format=date_format)
+ date_format = MySQL.format_time(seq_get(args, 1))
+ return exp.StrToDate(this=seq_get(args, 0), format=date_format)
def _str_to_date_sql(self, expression):
@@ -66,9 +75,9 @@ def _trim_sql(self, expression):
def _date_add(expression_class):
def func(args):
- interval = list_get(args, 1)
+ interval = seq_get(args, 1)
return expression_class(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
expression=interval.this,
unit=exp.Literal.string(interval.text("unit").lower()),
)
@@ -101,15 +110,16 @@ class MySQL(Dialect):
"%l": "%-I",
}
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
+ ESCAPES = ["'", "\\"]
BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
@@ -156,20 +166,23 @@ class MySQL(Dialect):
"_UTF32": TokenType.INTRODUCER,
"_UTF8MB3": TokenType.INTRODUCER,
"_UTF8MB4": TokenType.INTRODUCER,
+ "@@": TokenType.SESSION_PARAMETER,
}
- class Parser(Parser):
+ COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
+
+ class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
}
FUNCTION_PARSERS = {
- **Parser.FUNCTION_PARSERS,
+ **parser.Parser.FUNCTION_PARSERS,
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@@ -178,15 +191,212 @@ class MySQL(Dialect):
}
PROPERTY_PARSERS = {
- **Parser.PROPERTY_PARSERS,
+ **parser.Parser.PROPERTY_PARSERS,
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
}
- class Generator(Generator):
+ STATEMENT_PARSERS = {
+ **parser.Parser.STATEMENT_PARSERS,
+ TokenType.SHOW: lambda self: self._parse_show(),
+ TokenType.SET: lambda self: self._parse_set(),
+ }
+
+ SHOW_PARSERS = {
+ "BINARY LOGS": _show_parser("BINARY LOGS"),
+ "MASTER LOGS": _show_parser("BINARY LOGS"),
+ "BINLOG EVENTS": _show_parser("BINLOG EVENTS"),
+ "CHARACTER SET": _show_parser("CHARACTER SET"),
+ "CHARSET": _show_parser("CHARACTER SET"),
+ "COLLATION": _show_parser("COLLATION"),
+ "FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True),
+ "COLUMNS": _show_parser("COLUMNS", target="FROM"),
+ "CREATE DATABASE": _show_parser("CREATE DATABASE", target=True),
+ "CREATE EVENT": _show_parser("CREATE EVENT", target=True),
+ "CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True),
+ "CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True),
+ "CREATE TABLE": _show_parser("CREATE TABLE", target=True),
+ "CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True),
+ "CREATE VIEW": _show_parser("CREATE VIEW", target=True),
+ "DATABASES": _show_parser("DATABASES"),
+ "ENGINE": _show_parser("ENGINE", target=True),
+ "STORAGE ENGINES": _show_parser("ENGINES"),
+ "ENGINES": _show_parser("ENGINES"),
+ "ERRORS": _show_parser("ERRORS"),
+ "EVENTS": _show_parser("EVENTS"),
+ "FUNCTION CODE": _show_parser("FUNCTION CODE", target=True),
+ "FUNCTION STATUS": _show_parser("FUNCTION STATUS"),
+ "GRANTS": _show_parser("GRANTS", target="FOR"),
+ "INDEX": _show_parser("INDEX", target="FROM"),
+ "MASTER STATUS": _show_parser("MASTER STATUS"),
+ "OPEN TABLES": _show_parser("OPEN TABLES"),
+ "PLUGINS": _show_parser("PLUGINS"),
+ "PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True),
+ "PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"),
+ "PRIVILEGES": _show_parser("PRIVILEGES"),
+ "FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True),
+ "PROCESSLIST": _show_parser("PROCESSLIST"),
+ "PROFILE": _show_parser("PROFILE"),
+ "PROFILES": _show_parser("PROFILES"),
+ "RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"),
+ "REPLICAS": _show_parser("REPLICAS"),
+ "SLAVE HOSTS": _show_parser("REPLICAS"),
+ "REPLICA STATUS": _show_parser("REPLICA STATUS"),
+ "SLAVE STATUS": _show_parser("REPLICA STATUS"),
+ "GLOBAL STATUS": _show_parser("STATUS", global_=True),
+ "SESSION STATUS": _show_parser("STATUS"),
+ "STATUS": _show_parser("STATUS"),
+ "TABLE STATUS": _show_parser("TABLE STATUS"),
+ "FULL TABLES": _show_parser("TABLES", full=True),
+ "TABLES": _show_parser("TABLES"),
+ "TRIGGERS": _show_parser("TRIGGERS"),
+ "GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True),
+ "SESSION VARIABLES": _show_parser("VARIABLES"),
+ "VARIABLES": _show_parser("VARIABLES"),
+ "WARNINGS": _show_parser("WARNINGS"),
+ }
+
+ SET_PARSERS = {
+ "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
+ "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
+ "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"),
+ "SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
+ "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
+ "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
+ "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
+ "NAMES": lambda self: self._parse_set_item_names(),
+ }
+
+ PROFILE_TYPES = {
+ "ALL",
+ "BLOCK IO",
+ "CONTEXT SWITCHES",
+ "CPU",
+ "IPC",
+ "MEMORY",
+ "PAGE FAULTS",
+ "SOURCE",
+ "SWAPS",
+ }
+
+ def _parse_show_mysql(self, this, target=False, full=None, global_=None):
+ if target:
+ if isinstance(target, str):
+ self._match_text(target)
+ target_id = self._parse_id_var()
+ else:
+ target_id = None
+
+ log = self._parse_string() if self._match_text("IN") else None
+
+ if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
+ position = self._parse_number() if self._match_text("FROM") else None
+ db = None
+ else:
+ position = None
+ db = self._parse_id_var() if self._match_text("FROM") else None
+
+ channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
+
+ like = self._parse_string() if self._match_text("LIKE") else None
+ where = self._parse_where()
+
+ if this == "PROFILE":
+ types = self._parse_csv(self._parse_show_profile_type)
+ query = self._parse_number() if self._match_text("FOR", "QUERY") else None
+ offset = self._parse_number() if self._match_text("OFFSET") else None
+ limit = self._parse_number() if self._match_text("LIMIT") else None
+ else:
+ types, query = None, None
+ offset, limit = self._parse_oldstyle_limit()
+
+ mutex = True if self._match_text("MUTEX") else None
+ mutex = False if self._match_text("STATUS") else mutex
+
+ return self.expression(
+ exp.Show,
+ this=this,
+ target=target_id,
+ full=full,
+ log=log,
+ position=position,
+ db=db,
+ channel=channel,
+ like=like,
+ where=where,
+ types=types,
+ query=query,
+ offset=offset,
+ limit=limit,
+ mutex=mutex,
+ **{"global": global_},
+ )
+
+ def _parse_show_profile_type(self):
+ for type_ in self.PROFILE_TYPES:
+ if self._match_text(*type_.split(" ")):
+ return exp.Var(this=type_)
+ return None
+
+ def _parse_oldstyle_limit(self):
+ limit = None
+ offset = None
+ if self._match_text("LIMIT"):
+ parts = self._parse_csv(self._parse_number)
+ if len(parts) == 1:
+ limit = parts[0]
+ elif len(parts) == 2:
+ limit = parts[1]
+ offset = parts[0]
+ return offset, limit
+
+ def _default_parse_set_item(self):
+ return self._parse_set_item_assignment(kind=None)
+
+ def _parse_set_item_assignment(self, kind):
+ left = self._parse_primary() or self._parse_id_var()
+ if not self._match(TokenType.EQ):
+ self.raise_error("Expected =")
+ right = self._parse_statement() or self._parse_id_var()
+
+ this = self.expression(
+ exp.EQ,
+ this=left,
+ expression=right,
+ )
+
+ return self.expression(
+ exp.SetItem,
+ this=this,
+ kind=kind,
+ )
+
+ def _parse_set_item_charset(self, kind):
+ this = self._parse_string() or self._parse_id_var()
+
+ return self.expression(
+ exp.SetItem,
+ this=this,
+ kind=kind,
+ )
+
+ def _parse_set_item_names(self):
+ charset = self._parse_string() or self._parse_id_var()
+ if self._match_text("COLLATE"):
+ collate = self._parse_string() or self._parse_id_var()
+ else:
+ collate = None
+ return self.expression(
+ exp.SetItem,
+ this=charset,
+ collate=collate,
+ kind="NAMES",
+ )
+
+ class Generator(generator.Generator):
NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql,
@@ -199,6 +409,8 @@ class MySQL(Dialect):
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
exp.Trim: _trim_sql,
+ exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
+ exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
}
ROOT_PROPERTIES = {
@@ -209,4 +421,69 @@ class MySQL(Dialect):
exp.SchemaCommentProperty,
}
- WITH_PROPERTIES = {}
+ WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
+
+ def show_sql(self, expression):
+ this = f" {expression.name}"
+ full = " FULL" if expression.args.get("full") else ""
+ global_ = " GLOBAL" if expression.args.get("global") else ""
+
+ target = self.sql(expression, "target")
+ target = f" {target}" if target else ""
+ if expression.name in {"COLUMNS", "INDEX"}:
+ target = f" FROM{target}"
+ elif expression.name == "GRANTS":
+ target = f" FOR{target}"
+
+ db = self._prefixed_sql("FROM", expression, "db")
+
+ like = self._prefixed_sql("LIKE", expression, "like")
+ where = self.sql(expression, "where")
+
+ types = self.expressions(expression, key="types")
+ types = f" {types}" if types else types
+ query = self._prefixed_sql("FOR QUERY", expression, "query")
+
+ if expression.name == "PROFILE":
+ offset = self._prefixed_sql("OFFSET", expression, "offset")
+ limit = self._prefixed_sql("LIMIT", expression, "limit")
+ else:
+ offset = ""
+ limit = self._oldstyle_limit_sql(expression)
+
+ log = self._prefixed_sql("IN", expression, "log")
+ position = self._prefixed_sql("FROM", expression, "position")
+
+ channel = self._prefixed_sql("FOR CHANNEL", expression, "channel")
+
+ if expression.name == "ENGINE":
+ mutex_or_status = " MUTEX" if expression.args.get("mutex") else " STATUS"
+ else:
+ mutex_or_status = ""
+
+ return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}"
+
+ def _prefixed_sql(self, prefix, expression, arg):
+ sql = self.sql(expression, arg)
+ if not sql:
+ return ""
+ return f" {prefix} {sql}"
+
+ def _oldstyle_limit_sql(self, expression):
+ limit = self.sql(expression, "limit")
+ offset = self.sql(expression, "offset")
+ if limit:
+ limit_offset = f"{offset}, {limit}" if offset else limit
+ return f" LIMIT {limit_offset}"
+ return ""
+
+ def setitem_sql(self, expression):
+ kind = self.sql(expression, "kind")
+ kind = f"{kind} " if kind else ""
+ this = self.sql(expression, "this")
+ collate = self.sql(expression, "collate")
+ collate = f" COLLATE {collate}" if collate else ""
+ return f"{kind}{this}{collate}"
+
+ def set_sql(self, expression):
+ return f"SET {self.expressions(expression)}"
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 144dba5..3bc1109 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -1,8 +1,9 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql
-from sqlglot.generator import Generator
from sqlglot.helper import csv
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
def _limit_sql(self, expression):
@@ -36,9 +37,9 @@ class Oracle(Dialect):
"YYYY": "%Y", # 2015
}
- class Generator(Generator):
+ class Generator(generator.Generator):
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
@@ -49,11 +50,12 @@ class Oracle(Dialect):
exp.DataType.Type.NVARCHAR: "NVARCHAR2",
exp.DataType.Type.TEXT: "CLOB",
exp.DataType.Type.BINARY: "BLOB",
+ exp.DataType.Type.VARBINARY: "BLOB",
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
- **transforms.UNALIAS_GROUP,
+ **generator.Generator.TRANSFORMS,
+ **transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@@ -86,9 +88,9 @@ class Oracle(Dialect):
def table_sql(self, expression):
return super().table_sql(expression, sep=" ")
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 459e926..553a73b 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -9,9 +11,7 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
str_position_sql,
)
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess
@@ -160,12 +160,12 @@ class Postgres(Dialect):
"YYYY": "%Y", # 2015
}
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"COMMENT ON": TokenType.COMMENT_ON,
@@ -179,31 +179,32 @@ class Postgres(Dialect):
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
- **Tokenizer.SINGLE_TOKENS,
+ **tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
}
- class Generator(Generator):
+ class Generator(generator.Generator):
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.BINARY: "BYTEA",
+ exp.DataType.Type.VARBINARY: "BYTEA",
exp.DataType.Type.DATETIME: "TIMESTAMP",
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.ColumnDef: preprocess(
[
_auto_increment_to_serial,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index a2d392c..11ea778 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -1,4 +1,6 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@@ -10,10 +12,8 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
def _approx_distinct_sql(self, expression):
@@ -110,30 +110,29 @@ class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
time_format = "'%Y-%m-%d %H:%i:%S'"
- time_mapping = MySQL.time_mapping
+ time_mapping = MySQL.time_mapping # type: ignore
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
- **Tokenizer.KEYWORDS,
- "VARBINARY": TokenType.BINARY,
+ **tokens.Tokenizer.KEYWORDS,
"ROW": TokenType.STRUCT,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
- this=list_get(args, 2),
- expression=list_get(args, 1),
- unit=list_get(args, 0),
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=seq_get(args, 0),
),
"DATE_DIFF": lambda args: exp.DateDiff(
- this=list_get(args, 2),
- expression=list_get(args, 1),
- unit=list_get(args, 0),
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=seq_get(args, 0),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
@@ -143,7 +142,7 @@ class Presto(Dialect):
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
}
- class Generator(Generator):
+ class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
@@ -159,7 +158,7 @@ class Presto(Dialect):
}
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
@@ -169,8 +168,8 @@ class Presto(Dialect):
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
- **transforms.UNALIAS_GROUP,
+ **generator.Generator.TRANSFORMS,
+ **transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index e1f7b78..a9b12fb 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from sqlglot import exp
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
@@ -6,29 +8,30 @@ from sqlglot.tokens import TokenType
class Redshift(Postgres):
time_format = "'YYYY-MM-DD HH:MI:SS'"
time_mapping = {
- **Postgres.time_mapping,
+ **Postgres.time_mapping, # type: ignore
"MON": "%b",
"HH": "%H",
}
class Tokenizer(Postgres.Tokenizer):
- ESCAPE = "\\"
+ ESCAPES = ["\\"]
KEYWORDS = {
- **Postgres.Tokenizer.KEYWORDS,
+ **Postgres.Tokenizer.KEYWORDS, # type: ignore
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ,
- "VARBYTE": TokenType.BINARY,
+ "VARBYTE": TokenType.VARBINARY,
"SIMILAR TO": TokenType.SIMILAR_TO,
}
class Generator(Postgres.Generator):
TYPE_MAPPING = {
- **Postgres.Generator.TYPE_MAPPING,
+ **Postgres.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BINARY: "VARBYTE",
+ exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
}
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 3b97e6d..d1aaded 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@@ -6,10 +8,8 @@ from sqlglot.dialects.dialect import (
rename_func,
)
from sqlglot.expressions import Literal
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
def _check_int(s):
@@ -28,7 +28,9 @@ def _snowflake_to_timestamp(args):
# case: <numeric_expr> [ , <scale> ]
if second_arg.name not in ["0", "3", "9"]:
- raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
+ raise ValueError(
+ f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
+ )
if second_arg.name == "0":
timescale = exp.UnixToTime.SECONDS
@@ -39,7 +41,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime(this=first_arg, scale=timescale)
- first_arg = list_get(args, 0)
+ first_arg = seq_get(args, 0)
if not isinstance(first_arg, Literal):
# case: <variant_expr>
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
@@ -56,7 +58,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime.from_arg_list(args)
-def _unix_to_time(self, expression):
+def _unix_to_time_sql(self, expression):
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@@ -132,9 +134,9 @@ class Snowflake(Dialect):
"ff6": "%f",
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"IFF": exp.If.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
@@ -143,18 +145,18 @@ class Snowflake(Dialect):
}
FUNCTION_PARSERS = {
- **Parser.FUNCTION_PARSERS,
+ **parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
}
FUNC_TOKENS = {
- *Parser.FUNC_TOKENS,
+ *parser.Parser.FUNC_TOKENS,
TokenType.RLIKE,
TokenType.TABLE,
}
COLUMN_OPERATORS = {
- **Parser.COLUMN_OPERATORS,
+ **parser.Parser.COLUMN_OPERATORS, # type: ignore
TokenType.COLON: lambda self, this, path: self.expression(
exp.Bracket,
this=this,
@@ -163,21 +165,21 @@ class Snowflake(Dialect):
}
PROPERTY_PARSERS = {
- **Parser.PROPERTY_PARSERS,
+ **parser.Parser.PROPERTY_PARSERS,
TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
}
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
- ESCAPE = "\\"
+ ESCAPES = ["\\"]
SINGLE_TOKENS = {
- **Tokenizer.SINGLE_TOKENS,
+ **tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"QUALIFY": TokenType.QUALIFY,
"DOUBLE PRECISION": TokenType.DOUBLE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
@@ -187,15 +189,15 @@ class Snowflake(Dialect):
"SAMPLE": TokenType.TABLE_SAMPLE,
}
- class Generator(Generator):
+ class Generator(generator.Generator):
CREATE_TRANSIENT = True
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.UnixToTime: _unix_to_time,
+ exp.UnixToTime: _unix_to_time_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Array: inline_array_sql,
exp.StrPosition: rename_func("POSITION"),
@@ -204,7 +206,7 @@ class Snowflake(Dialect):
}
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 572f411..4e404b8 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -1,8 +1,9 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
from sqlglot.dialects.hive import Hive
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
+from sqlglot.helper import seq_get
def _create_sql(self, e):
@@ -46,36 +47,36 @@ def _unix_to_time(self, expression):
class Spark(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
- **Hive.Parser.FUNCTIONS,
+ **Hive.Parser.FUNCTIONS, # type: ignore
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
start=exp.Literal.number(1),
- length=list_get(args, 1),
+ length=seq_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
- this=list_get(args, 0),
- expression=list_get(args, 1),
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
),
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
- this=list_get(args, 0),
- expression=list_get(args, 1),
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
start=exp.Sub(
- this=exp.Length(this=list_get(args, 0)),
- expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
+ this=exp.Length(this=seq_get(args, 0)),
+ expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
),
- length=list_get(args, 1),
+ length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
}
FUNCTION_PARSERS = {
- **Parser.FUNCTION_PARSERS,
+ **parser.Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -88,14 +89,14 @@ class Spark(Hive):
class Generator(Hive.Generator):
TYPE_MAPPING = {
- **Hive.Generator.TYPE_MAPPING,
+ **Hive.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
TRANSFORMS = {
- **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
+ **Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
@@ -114,6 +115,8 @@ class Spark(Hive):
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
}
+ TRANSFORMS.pop(exp.ArraySort)
+ TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 62b7617..8c9fb76 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -8,31 +10,28 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
rename_func,
)
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
class SQLite(Dialect):
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
- "VARBINARY": TokenType.BINARY,
+ **tokens.Tokenizer.KEYWORDS,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
- class Generator(Generator):
+ class Generator(generator.Generator):
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
@@ -46,6 +45,7 @@ class SQLite(Dialect):
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
exp.DataType.Type.BINARY: "BLOB",
+ exp.DataType.Type.VARBINARY: "BLOB",
}
TOKEN_MAPPING = {
@@ -53,7 +53,7 @@ class SQLite(Dialect):
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 0cba6fe..3519c09 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
from sqlglot import exp
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.mysql import MySQL
class StarRocks(MySQL):
- class Generator(MySQL.Generator):
+ class Generator(MySQL.Generator): # type: ignore
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
@@ -13,7 +15,7 @@ class StarRocks(MySQL):
}
TRANSFORMS = {
- **MySQL.Generator.TRANSFORMS,
+ **MySQL.Generator.TRANSFORMS, # type: ignore
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),
@@ -22,3 +24,4 @@ class StarRocks(MySQL):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}
+ TRANSFORMS.pop(exp.DateTrunc)
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 45aa041..63e7275 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -1,7 +1,7 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser
from sqlglot.dialects.dialect import Dialect
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
def _if_sql(self, expression):
@@ -20,17 +20,17 @@ def _count_sql(self, expression):
class Tableau(Dialect):
- class Generator(Generator):
+ class Generator(generator.Generator):
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.If: _if_sql,
exp.Coalesce: _coalesce_sql,
exp.Count: _count_sql,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"IFNULL": exp.Coalesce.from_arg_list,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index 9a6f7fe..c7b34fe 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from sqlglot import exp
from sqlglot.dialects.presto import Presto
@@ -5,7 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
- **Presto.Generator.TRANSFORMS,
+ **Presto.Generator.TRANSFORMS, # type: ignore
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 0f93c75..a233d4b 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -1,15 +1,22 @@
+from __future__ import annotations
+
import re
-from sqlglot import exp
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
from sqlglot.expressions import DataType
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
+from sqlglot.helper import seq_get
from sqlglot.time import format_time
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
-FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"}
+FULL_FORMAT_TIME_MAPPING = {
+ "weekday": "%A",
+ "dw": "%A",
+ "w": "%A",
+ "month": "%B",
+ "mm": "%B",
+ "m": "%B",
+}
DATE_DELTA_INTERVAL = {
"year": "year",
"yyyy": "year",
@@ -37,11 +44,13 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time(args):
return exp_class(
- this=list_get(args, 1),
+ this=seq_get(args, 1),
format=exp.Literal.string(
format_time(
- list_get(args, 0).name or (TSQL.time_format if default is True else default),
- {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping,
+ seq_get(args, 0).name or (TSQL.time_format if default is True else default),
+ {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
+ if full_format_mapping
+ else TSQL.time_mapping,
)
),
)
@@ -50,12 +59,12 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def parse_format(args):
- fmt = list_get(args, 1)
+ fmt = seq_get(args, 1)
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
if number_fmt:
- return exp.NumberToStr(this=list_get(args, 0), format=fmt)
+ return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
return exp.TimeToStr(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
format=exp.Literal.string(
format_time(fmt.name, TSQL.format_time_mapping)
if len(fmt.name) == 1
@@ -188,11 +197,11 @@ class TSQL(Dialect):
"Y": "%a %Y",
}
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"BIT": TokenType.BOOLEAN,
"REAL": TokenType.FLOAT,
"NTEXT": TokenType.TEXT,
@@ -200,7 +209,6 @@ class TSQL(Dialect):
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"TIME": TokenType.TIMESTAMP,
- "VARBINARY": TokenType.BINARY,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"SMALLMONEY": TokenType.SMALLMONEY,
@@ -213,9 +221,9 @@ class TSQL(Dialect):
"TOP": TokenType.TOP,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"CHARINDEX": exp.StrPosition.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@@ -243,14 +251,16 @@ class TSQL(Dialect):
this = self._parse_column()
# Retrieve length of datatype and override to default if not specified
- if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
+ if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
# Check whether a conversion with format is applicable
if self._match(TokenType.COMMA):
format_val = self._parse_number().name
if format_val not in TSQL.convert_format_mapping:
- raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}")
+ raise ValueError(
+ f"CONVERT function at T-SQL does not support format style {format_val}"
+ )
format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
# Check whether the convert entails a string to date format
@@ -272,9 +282,9 @@ class TSQL(Dialect):
# Entails a simple cast without any format requirement
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
- class Generator(Generator):
+ class Generator(generator.Generator):
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
@@ -283,7 +293,7 @@ class TSQL(Dialect):
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),