author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-06-16 09:41:18 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-06-16 09:41:18 +0000
commit     67578a7602a5be7eb51f324086c8d49bcf8b7498 (patch)
tree       0b7515c922d1c383cea24af5175379cfc8edfd15 /sqlglot
parent     Releasing debian version 15.2.0-1. (diff)
download   sqlglot-67578a7602a5be7eb51f324086c8d49bcf8b7498.tar.xz
           sqlglot-67578a7602a5be7eb51f324086c8d49bcf8b7498.zip
Merging upstream version 16.2.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dialects/bigquery.py               |  65
-rw-r--r--  sqlglot/dialects/clickhouse.py             |  38
-rw-r--r--  sqlglot/dialects/dialect.py                | 201
-rw-r--r--  sqlglot/dialects/drill.py                  |  34
-rw-r--r--  sqlglot/dialects/duckdb.py                 |  35
-rw-r--r--  sqlglot/dialects/hive.py                   |  40
-rw-r--r--  sqlglot/dialects/mysql.py                  |  26
-rw-r--r--  sqlglot/dialects/oracle.py                 |  17
-rw-r--r--  sqlglot/dialects/postgres.py               |  13
-rw-r--r--  sqlglot/dialects/presto.py                 |  64
-rw-r--r--  sqlglot/dialects/redshift.py               |  14
-rw-r--r--  sqlglot/dialects/snowflake.py              |  19
-rw-r--r--  sqlglot/dialects/spark2.py                 |  10
-rw-r--r--  sqlglot/dialects/sqlite.py                 |   7
-rw-r--r--  sqlglot/dialects/tableau.py                |   6
-rw-r--r--  sqlglot/dialects/teradata.py               |  44
-rw-r--r--  sqlglot/dialects/tsql.py                   |  34
-rw-r--r--  sqlglot/executor/env.py                    |   2
-rw-r--r--  sqlglot/executor/python.py                 |   2
-rw-r--r--  sqlglot/expressions.py                     |  90
-rw-r--r--  sqlglot/generator.py                       | 383
-rw-r--r--  sqlglot/helper.py                          |  28
-rw-r--r--  sqlglot/optimizer/annotate_types.py        | 516
-rw-r--r--  sqlglot/optimizer/canonicalize.py          |   2
-rw-r--r--  sqlglot/optimizer/eliminate_joins.py       |   4
-rw-r--r--  sqlglot/optimizer/isolate_table_selects.py |   2
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py      |   9
-rw-r--r--  sqlglot/optimizer/optimize_joins.py        |  33
-rw-r--r--  sqlglot/optimizer/optimizer.py             |   2
-rw-r--r--  sqlglot/optimizer/pushdown_predicates.py   |   8
-rw-r--r--  sqlglot/optimizer/qualify_columns.py       |   6
-rw-r--r--  sqlglot/optimizer/qualify_tables.py        |   6
-rw-r--r--  sqlglot/optimizer/scope.py                 |   2
-rw-r--r--  sqlglot/parser.py                          | 682
-rw-r--r--  sqlglot/planner.py                         |   2
-rw-r--r--  sqlglot/schema.py                          |   2
-rw-r--r--  sqlglot/tokens.py                          |  40
37 files changed, 1304 insertions, 1184 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 5b10852..2166e65 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -7,6 +7,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
+ format_time_lambda,
inline_array_sql,
max_or_greatest,
min_or_least,
@@ -103,16 +104,26 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
class BigQuery(Dialect):
- unnest_column_only = True
- time_mapping = {
- "%M": "%-M",
- "%d": "%-d",
- "%m": "%-m",
- "%y": "%-y",
- "%H": "%-H",
- "%I": "%-I",
- "%S": "%-S",
- "%j": "%-j",
+ UNNEST_COLUMN_ONLY = True
+
+ TIME_MAPPING = {
+ "%D": "%m/%d/%y",
+ }
+
+ FORMAT_MAPPING = {
+ "DD": "%d",
+ "MM": "%m",
+ "MON": "%b",
+ "MONTH": "%B",
+ "YYYY": "%Y",
+ "YY": "%y",
+ "HH": "%I",
+ "HH12": "%I",
+ "HH24": "%H",
+ "MI": "%M",
+ "SS": "%S",
+ "SSSSS": "%f",
+ "TZH": "%z",
}
class Tokenizer(tokens.Tokenizer):
@@ -142,6 +153,7 @@ class BigQuery(Dialect):
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
"RECORD": TokenType.STRUCT,
+ "TIMESTAMP": TokenType.TIMESTAMPTZ,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
}
@@ -155,13 +167,21 @@ class BigQuery(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
+ "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
+ "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
),
- "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
+ "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
+ "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
+ [seq_get(args, 1), seq_get(args, 0)]
+ ),
+ "PARSE_TIMESTAMP": lambda args: format_time_lambda(exp.StrToTime, "bigquery")(
+ [seq_get(args, 1), seq_get(args, 0)]
+ ),
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0),
@@ -172,15 +192,15 @@ class BigQuery(Dialect):
if re.compile(str(seq_get(args, 1))).groups == 1
else None,
),
+ "SPLIT": lambda args: exp.Split(
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1) or exp.Literal.string(","),
+ ),
"TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
- "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
- "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
- "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
+ "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
"TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
- "PARSE_TIMESTAMP": lambda args: exp.StrToTime(
- this=seq_get(args, 1), format=seq_get(args, 0)
- ),
}
FUNCTION_PARSERS = {
@@ -274,9 +294,18 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
+ exp.RegexpExtract: lambda self, e: self.func(
+ "REGEXP_EXTRACT",
+ e.this,
+ e.expression,
+ e.args.get("position"),
+ e.args.get("occurrence"),
+ ),
+ exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.Select: transforms.preprocess(
[_unqualify_unnest, transforms.eliminate_distinct_on]
),
+ exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -295,7 +324,6 @@ class BigQuery(Dialect):
exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
- exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
}
TYPE_MAPPING = {
@@ -315,6 +343,7 @@ class BigQuery(Dialect):
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+ exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.VARBINARY: "BYTES",
exp.DataType.Type.VARCHAR: "STRING",
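
The PARSE_DATE/PARSE_TIMESTAMP changes above swap BigQuery's format-first argument order into sqlglot's expression-first StrToDate/StrToTime nodes, and the new generator entries swap it back when emitting BigQuery SQL. A minimal round-trip sketch (not part of the commit), assuming sqlglot 16.2:

    import sqlglot

    # BigQuery puts the format string first; the parser lambdas reorder the
    # arguments into StrToDate(this=..., format=...), and the generator
    # restores BigQuery's original order.
    sql = "SELECT PARSE_DATE('%d/%m/%y', '16/06/23')"
    print(sqlglot.transpile(sql, read="bigquery", write="bigquery")[0])
    # -> SELECT PARSE_DATE('%d/%m/%y', '16/06/23')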
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index fc48379..cfa9a7e 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -21,8 +21,9 @@ def _lower_func(sql: str) -> str:
class ClickHouse(Dialect):
- normalize_functions = None
- null_ordering = "nulls_are_last"
+ NORMALIZE_FUNCTIONS: bool | str = False
+ NULL_ORDERING = "nulls_are_last"
+ STRICT_STRING_CONCAT = True
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
@@ -163,11 +164,11 @@ class ClickHouse(Dialect):
return this
- def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
+ def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition:
return super()._parse_position(haystack_first=True)
# https://clickhouse.com/docs/en/sql-reference/statements/select/with/
- def _parse_cte(self) -> exp.Expression:
+ def _parse_cte(self) -> exp.CTE:
index = self._index
try:
# WITH <identifier> AS <subquery expression>
@@ -187,17 +188,19 @@ class ClickHouse(Dialect):
) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
is_global = self._match(TokenType.GLOBAL) and self._prev
kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev
+
if kind_pre:
kind = self._match_set(self.JOIN_KINDS) and self._prev
side = self._match_set(self.JOIN_SIDES) and self._prev
return is_global, side, kind
+
return (
is_global,
self._match_set(self.JOIN_SIDES) and self._prev,
self._match_set(self.JOIN_KINDS) and self._prev,
)
- def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]:
join = super()._parse_join(skip_join_token)
if join:
@@ -205,9 +208,14 @@ class ClickHouse(Dialect):
return join
def _parse_function(
- self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
+ self,
+ functions: t.Optional[t.Dict[str, t.Callable]] = None,
+ anonymous: bool = False,
+ optional_parens: bool = True,
) -> t.Optional[exp.Expression]:
- func = super()._parse_function(functions, anonymous)
+ func = super()._parse_function(
+ functions=functions, anonymous=anonymous, optional_parens=optional_parens
+ )
if isinstance(func, exp.Anonymous):
params = self._parse_func_params(func)
@@ -227,10 +235,12 @@ class ClickHouse(Dialect):
) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
return self._parse_csv(self._parse_lambda)
+
if self._match(TokenType.L_PAREN):
params = self._parse_csv(self._parse_lambda)
self._match_r_paren(this)
return params
+
return None
def _parse_quantile(self) -> exp.Quantile:
@@ -247,12 +257,12 @@ class ClickHouse(Dialect):
def _parse_primary_key(
self, wrapped_optional: bool = False, in_props: bool = False
- ) -> exp.Expression:
+ ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey:
return super()._parse_primary_key(
wrapped_optional=wrapped_optional or in_props, in_props=in_props
)
- def _parse_on_property(self) -> t.Optional[exp.Property]:
+ def _parse_on_property(self) -> t.Optional[exp.Expression]:
index = self._index
if self._match_text_seq("CLUSTER"):
this = self._parse_id_var()
@@ -329,6 +339,16 @@ class ClickHouse(Dialect):
"NAMED COLLECTION",
}
+ def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
+ # Clickhouse errors out if we try to cast a NULL value to TEXT
+ return self.func(
+ "CONCAT",
+ *[
+ exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text"))
+ for e in expression.expressions
+ ],
+ )
+
def cte_sql(self, expression: exp.CTE) -> str:
if isinstance(expression.this, exp.Alias):
return self.sql(expression, "this")
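
The new safeconcat_sql guards each argument so NULLs are passed through untouched rather than cast (which ClickHouse rejects). A sketch of the generated shape, assuming sqlglot 16.2; the exact output spelling is approximate:

    from sqlglot import exp
    from sqlglot.dialects.clickhouse import ClickHouse

    # Build a SafeConcat node directly and render it with the ClickHouse
    # generator to see the per-argument NULL guards.
    concat = exp.SafeConcat(expressions=[exp.column("a"), exp.column("b")])
    print(ClickHouse().generate(concat))
    # roughly: CONCAT(IF(a IS NULL, a, CAST(a AS String)),
    #                 IF(b IS NULL, b, CAST(b AS String)))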
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 4958bc6..f5d523b 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -25,6 +25,8 @@ class Dialects(str, Enum):
BIGQUERY = "bigquery"
CLICKHOUSE = "clickhouse"
+ DATABRICKS = "databricks"
+ DRILL = "drill"
DUCKDB = "duckdb"
HIVE = "hive"
MYSQL = "mysql"
@@ -38,11 +40,9 @@ class Dialects(str, Enum):
SQLITE = "sqlite"
STARROCKS = "starrocks"
TABLEAU = "tableau"
+ TERADATA = "teradata"
TRINO = "trino"
TSQL = "tsql"
- DATABRICKS = "databricks"
- DRILL = "drill"
- TERADATA = "teradata"
class _Dialect(type):
@@ -76,16 +76,19 @@ class _Dialect(type):
enum = Dialects.__members__.get(clsname.upper())
cls.classes[enum.value if enum is not None else clsname.lower()] = klass
- klass.time_trie = new_trie(klass.time_mapping)
- klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
- klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)
+ klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
+ klass.FORMAT_TRIE = (
+ new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
+ )
+ klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
+ klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
klass.parser_class = getattr(klass, "Parser", Parser)
klass.generator_class = getattr(klass, "Generator", Generator)
- klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
- klass.identifier_start, klass.identifier_end = list(
+ klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
+ klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
@@ -99,43 +102,80 @@ class _Dialect(type):
(None, None),
)
- klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING)
- klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING)
- klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
- klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
+ klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
+ klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
+ klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
+ klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING)
- klass.tokenizer_class.identifiers_can_start_with_digit = (
- klass.identifiers_can_start_with_digit
- )
+ dialect_properties = {
+ **{
+ k: v
+ for k, v in vars(klass).items()
+ if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
+ },
+ "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
+ "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
+ }
+
+ # Pass required dialect properties to the tokenizer, parser and generator classes
+ for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class):
+ for name, value in dialect_properties.items():
+ if hasattr(subclass, name):
+ setattr(subclass, name, value)
+
+ if not klass.STRICT_STRING_CONCAT:
+ klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
return klass
class Dialect(metaclass=_Dialect):
- index_offset = 0
- unnest_column_only = False
- alias_post_tablesample = False
- identifiers_can_start_with_digit = False
- normalize_functions: t.Optional[str] = "upper"
- null_ordering = "nulls_are_small"
-
- date_format = "'%Y-%m-%d'"
- dateint_format = "'%Y%m%d'"
- time_format = "'%Y-%m-%d %H:%M:%S'"
- time_mapping: t.Dict[str, str] = {}
-
- # autofilled
- quote_start = None
- quote_end = None
- identifier_start = None
- identifier_end = None
-
- time_trie = None
- inverse_time_mapping = None
- inverse_time_trie = None
- tokenizer_class = None
- parser_class = None
- generator_class = None
+ # Determines the base index offset for arrays
+ INDEX_OFFSET = 0
+
+ # If true unnest table aliases are considered only as column aliases
+ UNNEST_COLUMN_ONLY = False
+
+ # Determines whether or not the table alias comes after tablesample
+ ALIAS_POST_TABLESAMPLE = False
+
+ # Determines whether or not an unquoted identifier can start with a digit
+ IDENTIFIERS_CAN_START_WITH_DIGIT = False
+
+ # Determines whether or not CONCAT's arguments must be strings
+ STRICT_STRING_CONCAT = False
+
+ # Determines how function names are going to be normalized
+ NORMALIZE_FUNCTIONS: bool | str = "upper"
+
+ # Indicates the default null ordering method to use if not explicitly set
+ # Options are: "nulls_are_small", "nulls_are_large", "nulls_are_last"
+ NULL_ORDERING = "nulls_are_small"
+
+ DATE_FORMAT = "'%Y-%m-%d'"
+ DATEINT_FORMAT = "'%Y%m%d'"
+ TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
+
+ # Custom time mappings in which the key represents dialect time format
+ # and the value represents a python time format
+ TIME_MAPPING: t.Dict[str, str] = {}
+
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
+ # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
+ # special syntax cast(x as date format 'yyyy') defaults to time_mapping
+ FORMAT_MAPPING: t.Dict[str, str] = {}
+
+ # Autofilled
+ tokenizer_class = Tokenizer
+ parser_class = Parser
+ generator_class = Generator
+
+ # A trie of the time_mapping keys
+ TIME_TRIE: t.Dict = {}
+ FORMAT_TRIE: t.Dict = {}
+
+ INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
+ INVERSE_TIME_TRIE: t.Dict = {}
def __eq__(self, other: t.Any) -> bool:
return type(self) == other
@@ -164,20 +204,13 @@ class Dialect(metaclass=_Dialect):
) -> t.Optional[exp.Expression]:
if isinstance(expression, str):
return exp.Literal.string(
- format_time(
- expression[1:-1], # the time formats are quoted
- cls.time_mapping,
- cls.time_trie,
- )
+ # the time formats are quoted
+ format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
)
+
if expression and expression.is_string:
- return exp.Literal.string(
- format_time(
- expression.this,
- cls.time_mapping,
- cls.time_trie,
- )
- )
+ return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))
+
return expression
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
@@ -200,48 +233,14 @@ class Dialect(metaclass=_Dialect):
@property
def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
- self._tokenizer = self.tokenizer_class() # type: ignore
+ self._tokenizer = self.tokenizer_class()
return self._tokenizer
def parser(self, **opts) -> Parser:
- return self.parser_class( # type: ignore
- **{
- "index_offset": self.index_offset,
- "unnest_column_only": self.unnest_column_only,
- "alias_post_tablesample": self.alias_post_tablesample,
- "null_ordering": self.null_ordering,
- **opts,
- },
- )
+ return self.parser_class(**opts)
def generator(self, **opts) -> Generator:
- return self.generator_class( # type: ignore
- **{
- "quote_start": self.quote_start,
- "quote_end": self.quote_end,
- "bit_start": self.bit_start,
- "bit_end": self.bit_end,
- "hex_start": self.hex_start,
- "hex_end": self.hex_end,
- "byte_start": self.byte_start,
- "byte_end": self.byte_end,
- "raw_start": self.raw_start,
- "raw_end": self.raw_end,
- "identifier_start": self.identifier_start,
- "identifier_end": self.identifier_end,
- "string_escape": self.tokenizer_class.STRING_ESCAPES[0],
- "identifier_escape": self.tokenizer_class.IDENTIFIER_ESCAPES[0],
- "index_offset": self.index_offset,
- "time_mapping": self.inverse_time_mapping,
- "time_trie": self.inverse_time_trie,
- "unnest_column_only": self.unnest_column_only,
- "alias_post_tablesample": self.alias_post_tablesample,
- "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
- "normalize_functions": self.normalize_functions,
- "null_ordering": self.null_ordering,
- **opts,
- }
- )
+ return self.generator_class(**opts)
DialectType = t.Union[str, Dialect, t.Type[Dialect], None]
@@ -279,10 +278,7 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
- exp.Like(
- this=exp.Lower(this=expression.this),
- expression=expression.args["expression"],
- )
+ exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
)
@@ -359,6 +355,7 @@ def var_map_sql(
for key, value in zip(keys.expressions, values.expressions):
args.append(self.sql(key))
args.append(self.sql(value))
+
return self.func(map_func_name, *args)
@@ -381,7 +378,7 @@ def format_time_lambda(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
seq_get(args, 1)
- or (Dialect[dialect].time_format if default is True else default or None)
+ or (Dialect[dialect].TIME_FORMAT if default is True else default or None)
),
)
@@ -437,9 +434,7 @@ def parse_date_delta_with_interval(
expression = exp.Literal.number(expression.this)
return expression_class(
- this=args[0],
- expression=expression,
- unit=exp.Literal.string(interval.text("unit")),
+ this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
)
return func
@@ -462,9 +457,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
def locate_to_strposition(args: t.List) -> exp.Expression:
return exp.StrPosition(
- this=seq_get(args, 1),
- substr=seq_get(args, 0),
- position=seq_get(args, 2),
+ this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
)
@@ -546,13 +539,21 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
_dialect = Dialect.get_or_raise(dialect)
time_format = self.format_time(expression)
- if time_format and time_format not in (_dialect.time_format, _dialect.date_format):
+ if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
return f"CAST({self.sql(expression, 'this')} AS DATE)"
return _ts_or_ds_to_date_sql
+def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
+ this, *rest_args = expression.expressions
+ for arg in rest_args:
+ this = exp.DPipe(this=this, expression=arg)
+
+ return self.sql(this)
+
+
# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator
def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
names = []
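
The dialect_properties loop above replaces the old keyword plumbing in parser() and generator(): at class-creation time, any non-callable class attribute that also exists on the tokenizer, parser, or generator class is copied onto it. A sketch of the effect, assuming sqlglot 16.2 (MyDialect is hypothetical, defined only to show the propagation):

    from sqlglot.dialects.dialect import Dialect

    class MyDialect(Dialect):
        # Hypothetical dialect: just two settings to demonstrate propagation.
        NULL_ORDERING = "nulls_are_last"
        UNNEST_COLUMN_ONLY = True

    # The metaclass copied the settings onto the helper classes, so no
    # per-instance keyword arguments are needed anymore.
    assert MyDialect.generator_class.NULL_ORDERING == "nulls_are_last"
    assert MyDialect.parser_class.UNNEST_COLUMN_ONLY is True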
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 924b979..3cca986 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -16,21 +16,10 @@ from sqlglot.dialects.dialect import (
)
-def _str_to_time_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
- return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
-
-
-def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
- time_format = self.format_time(expression)
- if time_format and time_format not in (Drill.time_format, Drill.date_format):
- return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
- return f"CAST({self.sql(expression, 'this')} AS DATE)"
-
-
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
- unit = exp.Var(this=expression.text("unit").upper() or "DAY")
+ unit = exp.var(expression.text("unit").upper() or "DAY")
return (
f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
)
@@ -41,19 +30,19 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
- if time_format == Drill.date_format:
+ if time_format == Drill.DATE_FORMAT:
return f"CAST({this} AS DATE)"
return f"TO_DATE({this}, {time_format})"
class Drill(Dialect):
- normalize_functions = None
- null_ordering = "nulls_are_last"
- date_format = "'yyyy-MM-dd'"
- dateint_format = "'yyyyMMdd'"
- time_format = "'yyyy-MM-dd HH:mm:ss'"
+ NORMALIZE_FUNCTIONS: bool | str = False
+ NULL_ORDERING = "nulls_are_last"
+ DATE_FORMAT = "'yyyy-MM-dd'"
+ DATEINT_FORMAT = "'yyyyMMdd'"
+ TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
- time_mapping = {
+ TIME_MAPPING = {
"y": "%Y",
"Y": "%Y",
"YYYY": "%Y",
@@ -93,6 +82,7 @@ class Drill(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
+ CONCAT_NULL_OUTPUTS_STRING = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -135,8 +125,8 @@ class Drill(Dialect):
exp.DateAdd: _date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("SUB"),
- exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
- exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
+ exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
+ exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
@@ -154,7 +144,7 @@ class Drill(Dialect):
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
- exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})",
+ exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
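
Drill's lowercased time_mapping is gone; the same data now lives in TIME_MAPPING and is compiled into TIME_TRIE by the dialect metaclass. A sketch of how the Java-style tokens translate, assuming sqlglot 16.2:

    from sqlglot.dialects.drill import Drill
    from sqlglot.time import format_time

    # Translate a Drill format string into a Python strftime format using
    # the renamed mapping and its precomputed trie.
    print(format_time("yyyy-MM-dd", Drill.TIME_MAPPING, Drill.TIME_TRIE))
    # -> %Y-%m-%d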
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f31da73..f0c1820 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -56,11 +56,7 @@ def _sort_array_reverse(args: t.List) -> exp.Expression:
def _parse_date_diff(args: t.List) -> exp.Expression:
- return exp.DateDiff(
- this=seq_get(args, 2),
- expression=seq_get(args, 1),
- unit=seq_get(args, 0),
- )
+ return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
@@ -90,7 +86,7 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract
class DuckDB(Dialect):
- null_ordering = "nulls_are_last"
+ NULL_ORDERING = "nulls_are_last"
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
@@ -118,6 +114,8 @@ class DuckDB(Dialect):
}
class Parser(parser.Parser):
+ CONCAT_NULL_OUTPUTS_STRING = True
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
@@ -127,10 +125,7 @@ class DuckDB(Dialect):
"DATE_DIFF": _parse_date_diff,
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
- this=exp.Div(
- this=seq_get(args, 0),
- expression=exp.Literal.number(1000),
- )
+ this=exp.Div(this=seq_get(args, 0), expression=exp.Literal.number(1000))
),
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
@@ -191,8 +186,8 @@ class DuckDB(Dialect):
"DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
- exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
- exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
+ exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
+ exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
exp.Explode: rename_func("UNNEST"),
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.JSONExtract: arrow_json_extract_sql,
@@ -242,11 +237,27 @@ class DuckDB(Dialect):
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
+ UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)
+
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def interval_sql(self, expression: exp.Interval) -> str:
+ multiplier: t.Optional[int] = None
+ unit = expression.text("unit").lower()
+
+ if unit.startswith("week"):
+ multiplier = 7
+ if unit.startswith("quarter"):
+ multiplier = 90
+
+ if multiplier:
+ return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
+
+ return super().interval_sql(expression)
+
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
) -> str:
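
DuckDB's INTERVAL syntax has no WEEK or QUARTER unit, so the new interval_sql override rewrites those units into day multiples. A sketch, assuming sqlglot 16.2; the exact spelling is approximate:

    import sqlglot

    # WEEK is rewritten as 7 * INTERVAL ... day for DuckDB.
    print(sqlglot.transpile("SELECT INTERVAL '2' WEEK", write="duckdb")[0])
    # roughly: SELECT (7 * INTERVAL '2' day)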
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 650a1e1..8847119 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -80,12 +80,12 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})"
+
return f"{diff_sql}{multiplier_sql}"
def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
this = expression.this
-
if not this.type:
from sqlglot.optimizer.annotate_types import annotate_types
@@ -113,7 +113,7 @@ def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> st
def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
- if time_format not in (Hive.time_format, Hive.date_format):
+ if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
return f"CAST({this} AS DATE)"
@@ -121,7 +121,7 @@ def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> st
def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
- if time_format not in (Hive.time_format, Hive.date_format):
+ if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
return f"CAST({this} AS TIMESTAMP)"
@@ -130,7 +130,7 @@ def _time_format(
self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
) -> t.Optional[str]:
time_format = self.format_time(expression)
- if time_format == Hive.time_format:
+ if time_format == Hive.TIME_FORMAT:
return None
return time_format
@@ -144,16 +144,16 @@ def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
- if time_format and time_format not in (Hive.time_format, Hive.date_format):
+ if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
return f"TO_DATE({this}, {time_format})"
return f"TO_DATE({this})"
class Hive(Dialect):
- alias_post_tablesample = True
- identifiers_can_start_with_digit = True
+ ALIAS_POST_TABLESAMPLE = True
+ IDENTIFIERS_CAN_START_WITH_DIGIT = True
- time_mapping = {
+ TIME_MAPPING = {
"y": "%Y",
"Y": "%Y",
"YYYY": "%Y",
@@ -184,9 +184,9 @@ class Hive(Dialect):
"EEEE": "%A",
}
- date_format = "'yyyy-MM-dd'"
- dateint_format = "'yyyyMMdd'"
- time_format = "'yyyy-MM-dd HH:mm:ss'"
+ DATE_FORMAT = "'yyyy-MM-dd'"
+ DATEINT_FORMAT = "'yyyyMMdd'"
+ TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
@@ -224,9 +224,7 @@ class Hive(Dialect):
"BASE64": exp.ToBase64.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- unit=exp.Literal.string("DAY"),
+ this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
),
"DATEDIFF": lambda args: exp.DateDiff(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
@@ -234,10 +232,7 @@ class Hive(Dialect):
),
"DATE_SUB": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0),
- expression=exp.Mul(
- this=seq_get(args, 1),
- expression=exp.Literal.number(-1),
- ),
+ expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)),
unit=exp.Literal.string("DAY"),
),
"DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
@@ -349,8 +344,8 @@ class Hive(Dialect):
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: rename_func("TO_DATE"),
exp.DateSub: _add_date_sql,
- exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
- exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
+ exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)",
+ exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
@@ -415,10 +410,7 @@ class Hive(Dialect):
)
def with_properties(self, properties: exp.Properties) -> str:
- return self.properties(
- properties,
- prefix=self.seg("TBLPROPERTIES"),
- )
+ return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
def datatype_sql(self, expression: exp.DataType) -> str:
if (
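
The Hive.TIME_FORMAT/DATE_FORMAT renames feed the same comparisons as before, and format strings still round-trip through TIME_MAPPING and its autogenerated inverse. A sketch, assuming sqlglot 16.2:

    import sqlglot

    # Hive's Java-style tokens are mapped to strftime on parse and mapped
    # back via INVERSE_TIME_MAPPING on generation.
    sql = "SELECT DATE_FORMAT(x, 'yyyy-MM-dd')"
    print(sqlglot.transpile(sql, read="hive", write="hive")[0])
    # -> SELECT DATE_FORMAT(x, 'yyyy-MM-dd')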
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 75023ff..d2462e1 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -94,10 +94,10 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
class MySQL(Dialect):
- time_format = "'%Y-%m-%d %T'"
+ TIME_FORMAT = "'%Y-%m-%d %T'"
# https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
- time_mapping = {
+ TIME_MAPPING = {
"%M": "%B",
"%c": "%-m",
"%e": "%-d",
@@ -128,6 +128,7 @@ class MySQL(Dialect):
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
"SEPARATOR": TokenType.SEPARATOR,
+ "ENUM": TokenType.ENUM,
"START": TokenType.BEGIN,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
@@ -279,6 +280,16 @@ class MySQL(Dialect):
"SWAPS",
}
+ TYPE_TOKENS = {
+ *parser.Parser.TYPE_TOKENS,
+ TokenType.SET,
+ }
+
+ ENUM_TYPE_TOKENS = {
+ *parser.Parser.ENUM_TYPE_TOKENS,
+ TokenType.SET,
+ }
+
LOG_DEFAULTS_TO_LN = True
def _parse_show_mysql(
@@ -372,12 +383,7 @@ class MySQL(Dialect):
else:
collate = None
- return self.expression(
- exp.SetItem,
- this=charset,
- collate=collate,
- kind="NAMES",
- )
+ return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES")
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
@@ -472,9 +478,7 @@ class MySQL(Dialect):
def _prefixed_sql(self, prefix: str, expression: exp.Expression, arg: str) -> str:
sql = self.sql(expression, arg)
- if not sql:
- return ""
- return f" {prefix} {sql}"
+ return f" {prefix} {sql}" if sql else ""
def _oldstyle_limit_sql(self, expression: exp.Show) -> str:
limit = self.sql(expression, "limit")
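
With SET added to TYPE_TOKENS and ENUM_TYPE_TOKENS (and ENUM now tokenized as TokenType.ENUM), MySQL's value-list column types parse as proper data types. A sketch, assuming sqlglot 16.2:

    import sqlglot

    # Both value-list types should now survive a MySQL round trip.
    ddl = "CREATE TABLE t (c SET('a', 'b'), e ENUM('x', 'y'))"
    print(sqlglot.parse_one(ddl, read="mysql").sql(dialect="mysql"))
    # roughly round-trips the SET(...) and ENUM(...) types unchanged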
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 7722753..8d35e92 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -24,21 +24,15 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
if self._match_text_seq("COLUMNS"):
columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))
- return self.expression(
- exp.XMLTable,
- this=this,
- passing=passing,
- columns=columns,
- by_ref=by_ref,
- )
+ return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref)
class Oracle(Dialect):
- alias_post_tablesample = True
+ ALIAS_POST_TABLESAMPLE = True
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
- time_mapping = {
+ TIME_MAPPING = {
"AM": "%p", # Meridian indicator with or without periods
"A.M.": "%p", # Meridian indicator with or without periods
"PM": "%p", # Meridian indicator with or without periods
@@ -87,7 +81,7 @@ class Oracle(Dialect):
column.set("join_mark", self._match(TokenType.JOIN_MARKER))
return column
- def _parse_hint(self) -> t.Optional[exp.Expression]:
+ def _parse_hint(self) -> t.Optional[exp.Hint]:
if self._match(TokenType.HINT):
start = self._curr
while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH):
@@ -129,7 +123,7 @@ class Oracle(Dialect):
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
- exp.IfNull: rename_func("NVL"),
+ exp.Coalesce: rename_func("NVL"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
@@ -179,7 +173,6 @@ class Oracle(Dialect):
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NVARCHAR2": TokenType.NVARCHAR,
- "RETURNING": TokenType.RETURNING,
"SAMPLE": TokenType.TABLE_SAMPLE,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 8d84024..8c2a4ab 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -183,9 +183,10 @@ def _to_timestamp(args: t.List) -> exp.Expression:
class Postgres(Dialect):
- null_ordering = "nulls_are_large"
- time_format = "'YYYY-MM-DD HH24:MI:SS'"
- time_mapping = {
+ INDEX_OFFSET = 1
+ NULL_ORDERING = "nulls_are_large"
+ TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
+ TIME_MAPPING = {
"AM": "%p",
"PM": "%p",
"D": "%u", # 1-based day of week
@@ -241,7 +242,6 @@ class Postgres(Dialect):
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
"RESET": TokenType.COMMAND,
- "RETURNING": TokenType.RETURNING,
"REVOKE": TokenType.COMMAND,
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
@@ -258,6 +258,7 @@ class Postgres(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
+ CONCAT_NULL_OUTPUTS_STRING = True
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -268,6 +269,7 @@ class Postgres(Dialect):
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"TO_TIMESTAMP": _to_timestamp,
+ "UNNEST": exp.Explode.from_arg_list,
}
FUNCTION_PARSERS = {
@@ -303,7 +305,7 @@ class Postgres(Dialect):
value = self._parse_bitwise()
if part and part.is_string:
- part = exp.Var(this=part.name)
+ part = exp.var(part.name)
return self.expression(exp.Extract, this=part, expression=value)
@@ -328,6 +330,7 @@ class Postgres(Dialect):
**generator.Generator.TRANSFORMS,
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
+ exp.Explode: rename_func("UNNEST"),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
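
The new UNNEST parser entry and Explode generator entry pair up, so the function should transpile against dialects that use EXPLODE, e.g. Spark. A sketch, assuming sqlglot 16.2 and that UNNEST parses as a regular function in a select list:

    import sqlglot

    # Postgres UNNEST is parsed into an Explode node, which Spark renders
    # with its native function name.
    print(sqlglot.transpile("SELECT UNNEST(x)", read="postgres", write="spark")[0])
    # roughly: SELECT EXPLODE(x)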
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index d839864..a8a9884 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -102,7 +102,7 @@ def _str_to_time_sql(
def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
- if time_format and time_format not in (Presto.time_format, Presto.date_format):
+ if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
@@ -119,7 +119,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
exp.Literal.number(1),
exp.Literal.number(10),
),
- Presto.date_format,
+ Presto.DATE_FORMAT,
)
return self.func(
@@ -145,9 +145,7 @@ def _approx_percentile(args: t.List) -> exp.Expression:
)
if len(args) == 3:
return exp.ApproxQuantile(
- this=seq_get(args, 0),
- quantile=seq_get(args, 1),
- accuracy=seq_get(args, 2),
+ this=seq_get(args, 0), quantile=seq_get(args, 1), accuracy=seq_get(args, 2)
)
return exp.ApproxQuantile.from_arg_list(args)
@@ -160,10 +158,8 @@ def _from_unixtime(args: t.List) -> exp.Expression:
minutes=seq_get(args, 2),
)
if len(args) == 2:
- return exp.UnixToTime(
- this=seq_get(args, 0),
- zone=seq_get(args, 1),
- )
+ return exp.UnixToTime(this=seq_get(args, 0), zone=seq_get(args, 1))
+
return exp.UnixToTime.from_arg_list(args)
@@ -173,21 +169,17 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
unnest = exp.Unnest(expressions=[expression.this])
if expression.alias:
- return exp.alias_(
- unnest,
- alias="_u",
- table=[expression.alias],
- copy=False,
- )
+ return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
return unnest
return expression
class Presto(Dialect):
- index_offset = 1
- null_ordering = "nulls_are_last"
- time_format = MySQL.time_format
- time_mapping = MySQL.time_mapping
+ INDEX_OFFSET = 1
+ NULL_ORDERING = "nulls_are_last"
+ TIME_FORMAT = MySQL.TIME_FORMAT
+ TIME_MAPPING = MySQL.TIME_MAPPING
+ STRICT_STRING_CONCAT = True
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
@@ -205,14 +197,10 @@ class Presto(Dialect):
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
- this=seq_get(args, 2),
- expression=seq_get(args, 1),
- unit=seq_get(args, 0),
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATE_DIFF": lambda args: exp.DateDiff(
- this=seq_get(args, 2),
- expression=seq_get(args, 1),
- unit=seq_get(args, 0),
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
@@ -225,9 +213,7 @@ class Presto(Dialect):
"NOW": exp.CurrentTimestamp.from_arg_list,
"SEQUENCE": exp.GenerateSeries.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
- this=seq_get(args, 0),
- substr=seq_get(args, 1),
- instance=seq_get(args, 2),
+ this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
@@ -242,7 +228,7 @@ class Presto(Dialect):
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
- IS_BOOL = False
+ IS_BOOL_ALLOWED = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
@@ -284,10 +270,10 @@ class Presto(Dialect):
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
- exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
- exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
+ exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
+ exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.Decode: _decode_sql,
- exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
+ exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
exp.Encode: _encode_sql,
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.Group: transforms.preprocess([transforms.unalias_group]),
@@ -322,7 +308,7 @@ class Presto(Dialect):
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
- exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
+ exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),
exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
@@ -367,8 +353,16 @@ class Presto(Dialect):
to = target_type.copy()
if target_type is start.to:
- end = exp.Cast(this=end, to=to)
+ end = exp.cast(end, to)
else:
- start = exp.Cast(this=start, to=to)
+ start = exp.cast(start, to)
return self.func("SEQUENCE", start, end, step)
+
+ def offset_limit_modifiers(
+ self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
+ ) -> t.List[str]:
+ return [
+ self.sql(expression, "offset"),
+ self.sql(limit),
+ ]
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index b0a6774..a7e25fa 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, transforms
-from sqlglot.dialects.dialect import rename_func
+from sqlglot.dialects.dialect import concat_to_dpipe_sql, rename_func
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -14,9 +14,9 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx
class Redshift(Postgres):
- time_format = "'YYYY-MM-DD HH:MI:SS'"
- time_mapping = {
- **Postgres.time_mapping,
+ TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
+ TIME_MAPPING = {
+ **Postgres.TIME_MAPPING,
"MON": "%b",
"HH": "%H",
}
@@ -51,7 +51,7 @@ class Redshift(Postgres):
and this.expressions
and this.expressions[0].this == exp.column("MAX")
):
- this.set("expressions", [exp.Var(this="MAX")])
+ this.set("expressions", [exp.var("MAX")])
return this
@@ -94,6 +94,7 @@ class Redshift(Postgres):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS,
+ exp.Concat: concat_to_dpipe_sql,
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@@ -106,6 +107,7 @@ class Redshift(Postgres):
exp.FromBase: rename_func("STRTOL"),
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
+ exp.SafeConcat: concat_to_dpipe_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.TsOrDsToDate: lambda self, e: self.sql(e.this),
@@ -170,6 +172,6 @@ class Redshift(Postgres):
precision = expression.args.get("expressions")
if not precision:
- expression.append("expressions", exp.Var(this="MAX"))
+ expression.append("expressions", exp.var("MAX"))
return super().datatype_sql(expression)
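
concat_to_dpipe_sql (imported above from dialect.py) folds a CONCAT argument list into a left-nested chain of DPipe nodes, so Redshift gets || chains. A sketch, assuming sqlglot 16.2:

    import sqlglot

    # Each extra argument is folded into another DPipe, yielding a || chain.
    print(sqlglot.transpile("SELECT CONCAT(a, b, c)", write="redshift")[0])
    # -> SELECT a || b || c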
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 821d991..148b6d8 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -167,10 +167,10 @@ def _parse_convert_timezone(args: t.List) -> exp.Expression:
class Snowflake(Dialect):
- null_ordering = "nulls_are_large"
- time_format = "'yyyy-mm-dd hh24:mi:ss'"
+ NULL_ORDERING = "nulls_are_large"
+ TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
- time_mapping = {
+ TIME_MAPPING = {
"YYYY": "%Y",
"yyyy": "%Y",
"YY": "%y",
@@ -210,14 +210,10 @@ class Snowflake(Dialect):
"CONVERT_TIMEZONE": _parse_convert_timezone,
"DATE_TRUNC": date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
- this=seq_get(args, 2),
- expression=seq_get(args, 1),
- unit=seq_get(args, 0),
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DATEDIFF": lambda args: exp.DateDiff(
- this=seq_get(args, 2),
- expression=seq_get(args, 1),
- unit=seq_get(args, 0),
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
@@ -246,9 +242,7 @@ class Snowflake(Dialect):
COLUMN_OPERATORS = {
**parser.Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
- exp.Bracket,
- this=this,
- expressions=[path],
+ exp.Bracket, this=this, expressions=[path]
),
}
@@ -275,6 +269,7 @@ class Snowflake(Dialect):
QUOTES = ["'", "$$"]
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
+ COMMENTS = ["--", "//", ("/*", "*/")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
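
Snowflake also accepts // line comments, which the new COMMENTS entry teaches the tokenizer. A sketch, assuming sqlglot 16.2:

    import sqlglot

    # Without the COMMENTS entry the '//' would trip up the tokenizer; now
    # the trailing text is attached to the statement as a comment.
    expression = sqlglot.parse_one("SELECT 1 // trailing comment", read="snowflake")
    print(expression.sql(dialect="snowflake"))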
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index bf24240..ed6992d 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -38,7 +38,7 @@ def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
- if time_format == Hive.date_format:
+ if time_format == Hive.DATE_FORMAT:
return f"TO_DATE({this})"
return f"TO_DATE({this}, {time_format})"
@@ -133,13 +133,13 @@ class Spark2(Hive):
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
- "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1),
unit=exp.var(seq_get(args, 0)),
),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"BOOLEAN": _parse_as_cast("boolean"),
+ "DATE": _parse_as_cast("date"),
"DOUBLE": _parse_as_cast("double"),
"FLOAT": _parse_as_cast("float"),
"INT": _parse_as_cast("int"),
@@ -162,11 +162,9 @@ class Spark2(Hive):
def _parse_add_column(self) -> t.Optional[exp.Expression]:
return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
- def _parse_drop_column(self) -> t.Optional[exp.Expression]:
+ def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
return self._match_text_seq("DROP", "COLUMNS") and self.expression(
- exp.Drop,
- this=self._parse_schema(),
- kind="COLUMNS",
+ exp.Drop, this=self._parse_schema(), kind="COLUMNS"
)
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 4e800b0..3b837ea 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
+ concat_to_dpipe_sql,
count_if_to_sum,
no_ilike_sql,
no_pivot_sql,
@@ -62,10 +63,6 @@ class SQLite(Dialect):
IDENTIFIERS = ['"', ("[", "]"), "`"]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
- KEYWORDS = {
- **tokens.Tokenizer.KEYWORDS,
- }
-
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -100,6 +97,7 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.Concat: concat_to_dpipe_sql,
exp.CountIf: count_if_to_sum,
exp.Create: transforms.preprocess([_transform_create]),
exp.CurrentDate: lambda *_: "CURRENT_DATE",
@@ -116,6 +114,7 @@ class SQLite(Dialect):
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
exp.Pivot: no_pivot_sql,
+ exp.SafeConcat: concat_to_dpipe_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_qualify]
),
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index d5fba17..67ef76b 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, transforms
-from sqlglot.dialects.dialect import Dialect
+from sqlglot.dialects.dialect import Dialect, rename_func
class Tableau(Dialect):
@@ -11,6 +11,7 @@ class Tableau(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.Coalesce: rename_func("IFNULL"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
}
@@ -25,9 +26,6 @@ class Tableau(Dialect):
false = self.sql(expression, "false")
return f"IF {this} THEN {true} ELSE {false} END"
- def coalesce_sql(self, expression: exp.Coalesce) -> str:
- return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
-
def count_sql(self, expression: exp.Count) -> str:
this = expression.this
if isinstance(this, exp.Distinct):
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 514aecb..d5e5dd8 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -1,18 +1,32 @@
from __future__ import annotations
-import typing as t
-
from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot.dialects.dialect import (
- Dialect,
- format_time_lambda,
- max_or_greatest,
- min_or_least,
-)
+from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.tokens import TokenType
class Teradata(Dialect):
+ TIME_MAPPING = {
+ "Y": "%Y",
+ "YYYY": "%Y",
+ "YY": "%y",
+ "MMMM": "%B",
+ "MMM": "%b",
+ "DD": "%d",
+ "D": "%-d",
+ "HH": "%H",
+ "H": "%-H",
+ "MM": "%M",
+ "M": "%-M",
+ "SS": "%S",
+ "S": "%-S",
+ "SSSSSS": "%f",
+ "E": "%a",
+ "EE": "%a",
+ "EEE": "%a",
+ "EEEE": "%A",
+ }
+
class Tokenizer(tokens.Tokenizer):
# https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance
KEYWORDS = {
@@ -31,7 +45,7 @@ class Teradata(Dialect):
"ST_GEOMETRY": TokenType.GEOMETRY,
}
- # teradata does not support % for modulus
+ # Teradata does not support % as a modulo operator
SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS}
SINGLE_TOKENS.pop("%")
@@ -101,7 +115,7 @@ class Teradata(Dialect):
# FROM before SET in Teradata UPDATE syntax
# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
- def _parse_update(self) -> exp.Expression:
+ def _parse_update(self) -> exp.Update:
return self.expression(
exp.Update,
**{ # type: ignore
@@ -122,14 +136,6 @@ class Teradata(Dialect):
return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
- def _parse_cast(self, strict: bool) -> exp.Expression:
- cast = t.cast(exp.Cast, super()._parse_cast(strict))
- if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT):
- return format_time_lambda(exp.TimeToStr, "teradata")(
- [cast.this, self._parse_string()]
- )
- return cast
-
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
@@ -151,7 +157,7 @@ class Teradata(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
- exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
+ exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
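
The removed _parse_cast override is superseded by the shared FORMAT_MAPPING/TIME_MAPPING machinery in dialect.py, together with the new Teradata TIME_MAPPING and the StrToDate generator entry. A round-trip sketch, assuming sqlglot 16.2 handles CAST ... FORMAT in the shared parser:

    import sqlglot

    # The FORMAT tokens are mapped to strftime on parse and mapped back on
    # generation, so Teradata-to-Teradata should round-trip.
    sql = "SELECT CAST(x AS DATE FORMAT 'YYYY-DD')"
    print(sqlglot.transpile(sql, read="teradata", write="teradata")[0])
    # roughly: SELECT CAST(x AS DATE FORMAT 'YYYY-DD')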
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index f6ad888..6d674f5 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -64,9 +64,9 @@ def _format_time_lambda(
format=exp.Literal.string(
format_time(
args[0].name,
- {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
+ {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
if full_format_mapping
- else TSQL.time_mapping,
+ else TSQL.TIME_MAPPING,
)
),
)
@@ -86,9 +86,9 @@ def _parse_format(args: t.List) -> exp.Expression:
return exp.TimeToStr(
this=args[0],
format=exp.Literal.string(
- format_time(fmt.name, TSQL.format_time_mapping)
+ format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING)
if len(fmt.name) == 1
- else format_time(fmt.name, TSQL.time_mapping)
+ else format_time(fmt.name, TSQL.TIME_MAPPING)
),
)
@@ -138,7 +138,7 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
if isinstance(expression, exp.NumberToStr)
else exp.Literal.string(
format_time(
- expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping)
+ expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING)
)
)
)
@@ -166,10 +166,10 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
class TSQL(Dialect):
- null_ordering = "nulls_are_small"
- time_format = "'yyyy-mm-dd hh:mm:ss'"
+ NULL_ORDERING = "nulls_are_small"
+ TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
- time_mapping = {
+ TIME_MAPPING = {
"year": "%Y",
"qq": "%q",
"q": "%q",
@@ -213,7 +213,7 @@ class TSQL(Dialect):
"yy": "%y",
}
- convert_format_mapping = {
+ CONVERT_FORMAT_MAPPING = {
"0": "%b %d %Y %-I:%M%p",
"1": "%m/%d/%y",
"2": "%y.%m.%d",
@@ -253,8 +253,8 @@ class TSQL(Dialect):
"120": "%Y-%m-%d %H:%M:%S",
"121": "%Y-%m-%d %H:%M:%S.%f",
}
- # not sure if complete
- format_time_mapping = {
+
+ FORMAT_TIME_MAPPING = {
"y": "%B %Y",
"d": "%m/%d/%Y",
"H": "%-H",
@@ -312,9 +312,7 @@ class TSQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"CHARINDEX": lambda args: exp.StrPosition(
- this=seq_get(args, 1),
- substr=seq_get(args, 0),
- position=seq_get(args, 2),
+ this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
),
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
@@ -363,6 +361,8 @@ class TSQL(Dialect):
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True
+ CONCAT_NULL_OUTPUTS_STRING = True
+
def _parse_system_time(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("FOR", "SYSTEM_TIME"):
return None
@@ -400,7 +400,7 @@ class TSQL(Dialect):
table.set("system_time", self._parse_system_time())
return table
- def _parse_returns(self) -> exp.Expression:
+ def _parse_returns(self) -> exp.ReturnsProperty:
table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
returns = super()._parse_returns()
returns.set("table", table)
@@ -423,12 +423,12 @@ class TSQL(Dialect):
format_val = self._parse_number()
format_val_name = format_val.name if format_val else ""
- if format_val_name not in TSQL.convert_format_mapping:
+ if format_val_name not in TSQL.CONVERT_FORMAT_MAPPING:
raise ValueError(
f"CONVERT function at T-SQL does not support format style {format_val_name}"
)
- format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name])
+ format_norm = exp.Literal.string(TSQL.CONVERT_FORMAT_MAPPING[format_val_name])
# Check whether the convert entails a string to date format
if to.this == DataType.Type.DATE:
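As a quick sanity check of the renamed mapping (a sketch against the entries shown above, not part of the patch): CONVERT style numbers resolve to strftime-like format strings, and unrecognized styles raise the ValueError from _parse_convert.

    from sqlglot.dialects.tsql import TSQL

    # Style 120 is the ODBC canonical format, per the mapping above
    assert TSQL.CONVERT_FORMAT_MAPPING["120"] == "%Y-%m-%d %H:%M:%S"
    # A style missing from the mapping would raise in _parse_convert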
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index 51cffbd..d2c4e72 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -151,6 +151,7 @@ ENV = {
"CAST": cast,
"COALESCE": lambda *args: next((a for a in args if a is not None), None),
"CONCAT": null_if_any(lambda *args: "".join(args)),
+ "SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)),
"CONCATWS": null_if_any(lambda this, *args: this.join(args)),
"DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
"DIV": null_if_any(lambda e, this: e / this),
@@ -159,7 +160,6 @@ ENV = {
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
"GT": null_if_any(lambda this, e: this > e),
"GTE": null_if_any(lambda this, e: this >= e),
- "IFNULL": lambda e, alt: alt if e is None else e,
"IF": lambda predicate, true, false: true if predicate else false,
"INTDIV": null_if_any(lambda e, this: e // this),
"INTERVAL": interval,
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index f114e5c..3f96f90 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -394,7 +394,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str:
names = {e.name.lower() for e in e.expressions}
e = e.transform(
- lambda n: exp.Var(this=n.name)
+ lambda n: exp.var(n.name)
if isinstance(n, exp.Identifier) and n.name.lower() in names
else n
)
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index da4a4ed..c7d4664 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1500,6 +1500,7 @@ class Index(Expression):
arg_types = {
"this": False,
"table": False,
+ "using": False,
"where": False,
"columns": False,
"unique": False,
@@ -1623,7 +1624,7 @@ class Lambda(Expression):
class Limit(Expression):
- arg_types = {"this": False, "expression": True}
+ arg_types = {"this": False, "expression": True, "offset": False}
class Literal(Condition):
@@ -1869,6 +1870,10 @@ class EngineProperty(Property):
arg_types = {"this": True}
+class ToTableProperty(Property):
+ arg_types = {"this": True}
+
+
class ExecuteAsProperty(Property):
arg_types = {"this": True}
@@ -3072,12 +3077,35 @@ class Select(Subqueryable):
Returns:
The modified expression.
"""
-
inst = _maybe_copy(self, copy)
inst.set("locks", [Lock(update=update)])
return inst
+ def hint(self, *hints: ExpOrStr, dialect: DialectType = None, copy: bool = True) -> Select:
+ """
+ Set hints for this expression.
+
+ Examples:
+ >>> Select().select("x").from_("tbl").hint("BROADCAST(y)").sql(dialect="spark")
+ 'SELECT /*+ BROADCAST(y) */ x FROM tbl'
+
+ Args:
+ hints: The SQL code strings to parse as the hints.
+ If an `Expression` instance is passed, it will be used as-is.
+ dialect: The dialect used to parse the hints.
+ copy: If `False`, modify this expression instance in-place.
+
+ Returns:
+ The modified expression.
+ """
+ inst = _maybe_copy(self, copy)
+ inst.set(
+ "hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints])
+ )
+
+ return inst
+
@property
def named_selects(self) -> t.List[str]:
return [e.output_name for e in self.expressions if e.alias_or_name]
@@ -3244,6 +3272,7 @@ class DataType(Expression):
DATE = auto()
DATETIME = auto()
DATETIME64 = auto()
+ ENUM = auto()
INT4RANGE = auto()
INT4MULTIRANGE = auto()
INT8RANGE = auto()
@@ -3284,6 +3313,7 @@ class DataType(Expression):
OBJECT = auto()
ROWVERSION = auto()
SERIAL = auto()
+ SET = auto()
SMALLINT = auto()
SMALLMONEY = auto()
SMALLSERIAL = auto()
@@ -3334,6 +3364,7 @@ class DataType(Expression):
NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
TEMPORAL_TYPES = {
+ Type.TIME,
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
Type.TIMESTAMPLTZ,
@@ -3342,6 +3373,8 @@ class DataType(Expression):
Type.DATETIME64,
}
+ META_TYPES = {"UNKNOWN", "NULL"}
+
@classmethod
def build(
cls, dtype: str | DataType | DataType.Type, dialect: DialectType = None, **kwargs
@@ -3349,8 +3382,9 @@ class DataType(Expression):
from sqlglot import parse_one
if isinstance(dtype, str):
- if dtype.upper() in cls.Type.__members__:
- data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()])
+ upper = dtype.upper()
+ if upper in DataType.META_TYPES:
+ data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[upper])
else:
data_type_exp = parse_one(dtype, read=dialect, into=DataType)
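With this change, build() only short-circuits for the meta types UNKNOWN and NULL; every other type string goes through the parser, so parameterized types are handled uniformly. A sketch of the two paths:

    from sqlglot import exp

    assert exp.DataType.build("unknown").this == exp.DataType.Type.UNKNOWN   # META_TYPES path
    assert exp.DataType.build("ARRAY<INT>").this == exp.DataType.Type.ARRAY  # parser path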
@@ -3483,6 +3517,10 @@ class Dot(Binary):
def name(self) -> str:
return self.expression.name
+ @property
+ def output_name(self) -> str:
+ return self.name
+
@classmethod
def build(self, expressions: t.Sequence[Expression]) -> Dot:
"""Build a Dot object with a sequence of expressions."""
@@ -3502,6 +3540,10 @@ class DPipe(Binary):
pass
+class SafeDPipe(DPipe):
+ pass
+
+
class EQ(Binary, Predicate):
pass
@@ -3615,6 +3657,10 @@ class Not(Unary):
class Paren(Unary):
arg_types = {"this": True, "with": False}
+ @property
+ def output_name(self) -> str:
+ return self.this.name
+
class Neg(Unary):
pass
@@ -3904,6 +3950,7 @@ class Ceil(Func):
class Coalesce(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
+ _sql_names = ["COALESCE", "IFNULL", "NVL"]
class Concat(Func):
@@ -3911,12 +3958,17 @@ class Concat(Func):
is_var_len_args = True
+class SafeConcat(Concat):
+ pass
+
+
class ConcatWs(Concat):
_sql_names = ["CONCAT_WS"]
class Count(AggFunc):
- arg_types = {"this": False}
+ arg_types = {"this": False, "expressions": False}
+ is_var_len_args = True
class CountIf(AggFunc):
@@ -4049,6 +4101,11 @@ class DateToDi(Func):
pass
+class Date(Func):
+ arg_types = {"expressions": True}
+ is_var_len_args = True
+
+
class Day(Func):
pass
@@ -4102,11 +4159,6 @@ class If(Func):
arg_types = {"this": True, "true": True, "false": False}
-class IfNull(Func):
- arg_types = {"this": True, "expression": False}
- _sql_names = ["IFNULL", "NVL"]
-
-
class Initcap(Func):
arg_types = {"this": True, "expression": False}
@@ -5608,22 +5660,27 @@ def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) -
expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)
-def column_table_names(expression: Expression) -> t.List[str]:
+def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]:
"""
Return all table names referenced through columns in an expression.
Example:
>>> import sqlglot
- >>> column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e"))
- ['c', 'a']
+ >>> sorted(column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e")))
+ ['a', 'c']
Args:
expression: expression to find table names.
+ exclude: a table name to exclude
Returns:
A list of unique names.
"""
- return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
+ return {
+ table
+ for table in (column.table for column in expression.find_all(Column))
+ if table and table != exclude
+ }
def table_name(table: Table | str) -> str:
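The function now returns a set (deduplication for free) and takes an exclude parameter. A sketch of the new call shape:

    import sqlglot
    from sqlglot.expressions import column_table_names

    names = column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e"), exclude="c")
    assert names == {"a"}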
@@ -5649,12 +5706,13 @@ def table_name(table: Table | str) -> str:
return ".".join(part for part in (table.text("catalog"), table.text("db"), table.name) if part)
-def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E:
+def replace_tables(expression: E, mapping: t.Dict[str, str], copy: bool = True) -> E:
"""Replace all tables in expression according to the mapping.
Args:
expression: expression node to be transformed and replaced.
mapping: mapping of table names.
+ copy: whether or not to copy the expression.
Examples:
>>> from sqlglot import exp, parse_one
@@ -5675,7 +5733,7 @@ def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E:
)
return node
- return expression.transform(_replace_tables)
+ return expression.transform(_replace_tables, copy=copy)
def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression:
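The new copy flag allows in-place rewrites when the caller doesn't need to preserve the original tree. A sketch (assuming transform with copy=False returns the mutated root):

    from sqlglot import parse_one
    from sqlglot.expressions import replace_tables

    ast = parse_one("SELECT * FROM a")
    replace_tables(ast, {"a": "b"}, copy=False)  # mutates ast in place
    assert ast.sql() == "SELECT * FROM b"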
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 97cbe15..d3cf9f0 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -14,47 +14,32 @@ logger = logging.getLogger("sqlglot")
class Generator:
"""
- Generator interprets the given syntax tree and produces a SQL string as an output.
+ Generator converts a given syntax tree to the corresponding SQL string.
Args:
- time_mapping (dict): the dictionary of custom time mappings in which the key
- represents a python time format and the output the target time format
- time_trie (trie): a trie of the time_mapping keys
- pretty (bool): if set to True the returned string will be formatted. Default: False.
- quote_start (str): specifies which starting character to use to delimit quotes. Default: '.
- quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
- identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
- identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
- bit_start (str): specifies which starting character to use to delimit bit literals. Default: None.
- bit_end (str): specifies which ending character to use to delimit bit literals. Default: None.
- hex_start (str): specifies which starting character to use to delimit hex literals. Default: None.
- hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
- byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
- byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
- raw_start (str): specifies which starting character to use to delimit raw literals. Default: None.
- raw_end (str): specifies which ending character to use to delimit raw literals. Default: None.
- identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always.
- normalize (bool): if set to True all identifiers will lower cased
- string_escape (str): specifies a string escape character. Default: '.
- identifier_escape (str): specifies an identifier escape character. Default: ".
- pad (int): determines padding in a formatted string. Default: 2.
- indent (int): determines the size of indentation in a formatted string. Default: 4.
- unnest_column_only (bool): if true unnest table aliases are considered only as column aliases
- normalize_functions (str): normalize function names, "upper", "lower", or None
- Default: "upper"
- alias_post_tablesample (bool): if the table alias comes after tablesample
- Default: False
- identifiers_can_start_with_digit (bool): if an unquoted identifier can start with digit
- Default: False
- unsupported_level (ErrorLevel): determines the generator's behavior when it encounters
- unsupported expressions. Default ErrorLevel.WARN.
- null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
- Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
- Default: "nulls_are_small"
- max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
+ pretty: Whether or not to format the produced SQL string.
+ Default: False.
+ identify: Determines when an identifier should be quoted. Possible values are:
+ False (default): Never quote, except in cases where it's mandatory by the dialect.
+ True or 'always': Always quote.
+ 'safe': Only quote identifiers that are case insensitive.
+ normalize: Whether or not to normalize identifiers to lowercase.
+ Default: False.
+ pad: Determines the pad size in a formatted string.
+ Default: 2.
+ indent: Determines the indentation size in a formatted string.
+ Default: 2.
+ normalize_functions: Whether or not to normalize all function names. Possible values are:
+ "upper" or True (default): Convert names to uppercase.
+ "lower": Convert names to lowercase.
+ False: Disables function name normalization.
+ unsupported_level: Determines the generator's behavior when it encounters unsupported expressions.
+ Default ErrorLevel.WARN.
+ max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError.
This is only relevant if unsupported_level is ErrorLevel.RAISE.
Default: 3
- leading_comma (bool): if the the comma is leading or trailing in select statements
+ leading_comma: Determines whether or not the comma is leading or trailing in select expressions.
+ This is only relevant when generating in pretty mode.
Default: False
max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
The default is on the smaller end because the length only represents a segment and not the true
@@ -86,6 +71,7 @@ class Generator:
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
+ exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransientProperty: lambda self, e: "TRANSIENT",
exp.StabilityProperty: lambda self, e: e.name,
exp.VolatileProperty: lambda self, e: "VOLATILE",
@@ -138,15 +124,24 @@ class Generator:
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
- # Whether a table is allowed to be renamed with a db
+ # Whether or not a table is allowed to be renamed with a db
RENAME_TABLE_WITH_DB = True
# The separator for grouping sets and rollups
GROUPINGS_SEP = ","
- # The string used for creating index on a table
+ # The string used for creating an index on a table
INDEX_ON = "ON"
+ # Whether or not join hints should be generated
+ JOIN_HINTS = True
+
+ # Whether or not table hints should be generated
+ TABLE_HINTS = True
+
+ # Whether or not comparing against booleans (e.g. x IS TRUE) is supported
+ IS_BOOL_ALLOWED = True
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -228,6 +223,7 @@ class Generator:
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
+ exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
@@ -235,128 +231,110 @@ class Generator:
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
}
- JOIN_HINTS = True
- TABLE_HINTS = True
- IS_BOOL = True
-
+ # Keywords that can't be used as unquoted identifier names
RESERVED_KEYWORDS: t.Set[str] = set()
- WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
- UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren)
+
+ # Expressions whose comments are separated from them for better formatting
+ WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
+ exp.Select,
+ exp.From,
+ exp.Where,
+ exp.With,
+ )
+
+ # Expressions that can remain unwrapped when appearing in the context of an INTERVAL
+ UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = (
+ exp.Column,
+ exp.Literal,
+ exp.Neg,
+ exp.Paren,
+ )
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
+ # Autofilled
+ INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
+ INVERSE_TIME_TRIE: t.Dict = {}
+ INDEX_OFFSET = 0
+ UNNEST_COLUMN_ONLY = False
+ ALIAS_POST_TABLESAMPLE = False
+ IDENTIFIERS_CAN_START_WITH_DIGIT = False
+ STRICT_STRING_CONCAT = False
+ NORMALIZE_FUNCTIONS: bool | str = "upper"
+ NULL_ORDERING = "nulls_are_small"
+
+ # Delimiters for quotes, identifiers and the corresponding escape characters
+ QUOTE_START = "'"
+ QUOTE_END = "'"
+ IDENTIFIER_START = '"'
+ IDENTIFIER_END = '"'
+ STRING_ESCAPE = "'"
+ IDENTIFIER_ESCAPE = '"'
+
+ # Delimiters for bit, hex, byte and raw literals
+ BIT_START: t.Optional[str] = None
+ BIT_END: t.Optional[str] = None
+ HEX_START: t.Optional[str] = None
+ HEX_END: t.Optional[str] = None
+ BYTE_START: t.Optional[str] = None
+ BYTE_END: t.Optional[str] = None
+ RAW_START: t.Optional[str] = None
+ RAW_END: t.Optional[str] = None
+
__slots__ = (
- "time_mapping",
- "time_trie",
"pretty",
- "quote_start",
- "quote_end",
- "identifier_start",
- "identifier_end",
- "bit_start",
- "bit_end",
- "hex_start",
- "hex_end",
- "byte_start",
- "byte_end",
- "raw_start",
- "raw_end",
"identify",
"normalize",
- "string_escape",
- "identifier_escape",
"pad",
- "index_offset",
- "unnest_column_only",
- "alias_post_tablesample",
- "identifiers_can_start_with_digit",
+ "_indent",
"normalize_functions",
"unsupported_level",
- "unsupported_messages",
- "null_ordering",
"max_unsupported",
- "_indent",
+ "leading_comma",
+ "max_text_width",
+ "comments",
+ "unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
- "_leading_comma",
- "_max_text_width",
- "_comments",
"_cache",
)
def __init__(
self,
- time_mapping=None,
- time_trie=None,
- pretty=None,
- quote_start=None,
- quote_end=None,
- identifier_start=None,
- identifier_end=None,
- bit_start=None,
- bit_end=None,
- hex_start=None,
- hex_end=None,
- byte_start=None,
- byte_end=None,
- raw_start=None,
- raw_end=None,
- identify=False,
- normalize=False,
- string_escape=None,
- identifier_escape=None,
- pad=2,
- indent=2,
- index_offset=0,
- unnest_column_only=False,
- alias_post_tablesample=False,
- identifiers_can_start_with_digit=False,
- normalize_functions="upper",
- unsupported_level=ErrorLevel.WARN,
- null_ordering=None,
- max_unsupported=3,
- leading_comma=False,
- max_text_width=80,
- comments=True,
+ pretty: t.Optional[bool] = None,
+ identify: str | bool = False,
+ normalize: bool = False,
+ pad: int = 2,
+ indent: int = 2,
+ normalize_functions: t.Optional[str | bool] = None,
+ unsupported_level: ErrorLevel = ErrorLevel.WARN,
+ max_unsupported: int = 3,
+ leading_comma: bool = False,
+ max_text_width: int = 80,
+ comments: bool = True,
):
import sqlglot
- self.time_mapping = time_mapping or {}
- self.time_trie = time_trie
self.pretty = pretty if pretty is not None else sqlglot.pretty
- self.quote_start = quote_start or "'"
- self.quote_end = quote_end or "'"
- self.identifier_start = identifier_start or '"'
- self.identifier_end = identifier_end or '"'
- self.bit_start = bit_start
- self.bit_end = bit_end
- self.hex_start = hex_start
- self.hex_end = hex_end
- self.byte_start = byte_start
- self.byte_end = byte_end
- self.raw_start = raw_start
- self.raw_end = raw_end
self.identify = identify
self.normalize = normalize
- self.string_escape = string_escape or "'"
- self.identifier_escape = identifier_escape or '"'
self.pad = pad
- self.index_offset = index_offset
- self.unnest_column_only = unnest_column_only
- self.alias_post_tablesample = alias_post_tablesample
- self.identifiers_can_start_with_digit = identifiers_can_start_with_digit
- self.normalize_functions = normalize_functions
+ self._indent = indent
self.unsupported_level = unsupported_level
- self.unsupported_messages = []
self.max_unsupported = max_unsupported
- self.null_ordering = null_ordering
- self._indent = indent
- self._escaped_quote_end = self.string_escape + self.quote_end
- self._escaped_identifier_end = self.identifier_escape + self.identifier_end
- self._leading_comma = leading_comma
- self._max_text_width = max_text_width
- self._comments = comments
- self._cache = None
+ self.leading_comma = leading_comma
+ self.max_text_width = max_text_width
+ self.comments = comments
+
+ # This is both a Dialect property and a Generator argument, so we prioritize the latter
+ self.normalize_functions = (
+ self.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions
+ )
+
+ self.unsupported_messages: t.List[str] = []
+ self._escaped_quote_end: str = self.STRING_ESCAPE + self.QUOTE_END
+ self._escaped_identifier_end: str = self.IDENTIFIER_ESCAPE + self.IDENTIFIER_END
+ self._cache: t.Optional[t.Dict[int, str]] = None
def generate(
self,
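With the settings promoted from constructor keywords to class attributes, customizing a generator now means subclassing instead of threading arguments through __init__. A minimal sketch, using only attributes introduced above:

    from sqlglot.generator import Generator

    class BacktickGenerator(Generator):
        # Class-level settings replace the old __init__ keyword arguments
        IDENTIFIER_START = "`"
        IDENTIFIER_END = "`"
        NULL_ORDERING = "nulls_are_large"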
@@ -364,17 +342,19 @@ class Generator:
cache: t.Optional[t.Dict[int, str]] = None,
) -> str:
"""
- Generates a SQL string by interpreting the given syntax tree.
+ Generates the SQL string corresponding to the given syntax tree.
- Args
- expression: the syntax tree.
- cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node.
+ Args:
+ expression: The syntax tree.
+ cache: An optional sql string cache. This leverages the hash of an Expression
+ which can be slow to compute, so only use it if you set _hash on each node.
- Returns
- the SQL string.
+ Returns:
+ The SQL string corresponding to `expression`.
"""
if cache is not None:
self._cache = cache
+
self.unsupported_messages = []
sql = self.sql(expression).strip()
self._cache = None
@@ -414,7 +394,11 @@ class Generator:
expression: t.Optional[exp.Expression] = None,
comments: t.Optional[t.List[str]] = None,
) -> str:
- comments = ((expression and expression.comments) if comments is None else comments) if self._comments else None # type: ignore
+ comments = (
+ ((expression and expression.comments) if comments is None else comments) # type: ignore
+ if self.comments
+ else None
+ )
if not comments or isinstance(expression, exp.Binary):
return sql
@@ -454,7 +438,7 @@ class Generator:
return result
def normalize_func(self, name: str) -> str:
- if self.normalize_functions == "upper":
+ if self.normalize_functions == "upper" or self.normalize_functions is True:
return name.upper()
if self.normalize_functions == "lower":
return name.lower()
@@ -522,7 +506,7 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
- sql = self.maybe_comment(sql, expression) if self._comments and comment else sql
+ sql = self.maybe_comment(sql, expression) if self.comments and comment else sql
if self._cache is not None:
self._cache[expression_id] = sql
@@ -770,25 +754,25 @@ class Generator:
def bitstring_sql(self, expression: exp.BitString) -> str:
this = self.sql(expression, "this")
- if self.bit_start:
- return f"{self.bit_start}{this}{self.bit_end}"
+ if self.BIT_START:
+ return f"{self.BIT_START}{this}{self.BIT_END}"
return f"{int(this, 2)}"
def hexstring_sql(self, expression: exp.HexString) -> str:
this = self.sql(expression, "this")
- if self.hex_start:
- return f"{self.hex_start}{this}{self.hex_end}"
+ if self.HEX_START:
+ return f"{self.HEX_START}{this}{self.HEX_END}"
return f"{int(this, 16)}"
def bytestring_sql(self, expression: exp.ByteString) -> str:
this = self.sql(expression, "this")
- if self.byte_start:
- return f"{self.byte_start}{this}{self.byte_end}"
+ if self.BYTE_START:
+ return f"{self.BYTE_START}{this}{self.BYTE_END}"
return this
def rawstring_sql(self, expression: exp.RawString) -> str:
- if self.raw_start:
- return f"{self.raw_start}{expression.name}{self.raw_end}"
+ if self.RAW_START:
+ return f"{self.RAW_START}{expression.name}{self.RAW_END}"
return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\")))
def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
@@ -883,24 +867,27 @@ class Generator:
name = f"{expression.name} " if expression.name else ""
table = self.sql(expression, "table")
table = f"{self.INDEX_ON} {table} " if table else ""
+ using = self.sql(expression, "using")
+ using = f"USING {using} " if using else ""
index = "INDEX " if not table else ""
columns = self.expressions(expression, key="columns", flat=True)
+ columns = f"({columns})" if columns else ""
partition_by = self.expressions(expression, key="partition_by", flat=True)
partition_by = f" PARTITION BY {partition_by}" if partition_by else ""
- return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}"
+ return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{partition_by}"
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
lower = text.lower()
text = lower if self.normalize and not expression.quoted else text
- text = text.replace(self.identifier_end, self._escaped_identifier_end)
+ text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
if (
expression.quoted
or should_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
- or (not self.identifiers_can_start_with_digit and text[:1].isdigit())
+ or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
):
- text = f"{self.identifier_start}{text}{self.identifier_end}"
+ text = f"{self.IDENTIFIER_START}{text}{self.IDENTIFIER_END}"
return text
def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
@@ -1197,7 +1184,7 @@ class Generator:
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
) -> str:
- if self.alias_post_tablesample and expression.this.alias:
+ if self.ALIAS_POST_TABLESAMPLE and expression.this.alias:
table = expression.this.copy()
table.set("alias", None)
this = self.sql(table)
@@ -1372,7 +1359,15 @@ class Generator:
def limit_sql(self, expression: exp.Limit) -> str:
this = self.sql(expression, "this")
- return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}"
+ args = ", ".join(
+ sql
+ for sql in (
+ self.sql(expression, "offset"),
+ self.sql(expression, "expression"),
+ )
+ if sql
+ )
+ return f"{this}{self.seg('LIMIT')} {args}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
@@ -1418,10 +1413,10 @@ class Generator:
def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
- text = text.replace(self.quote_end, self._escaped_quote_end)
+ text = text.replace(self.QUOTE_END, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
- text = f"{self.quote_start}{text}{self.quote_end}"
+ text = f"{self.QUOTE_START}{text}{self.QUOTE_END}"
return text
def loaddata_sql(self, expression: exp.LoadData) -> str:
@@ -1463,9 +1458,9 @@ class Generator:
nulls_first = expression.args.get("nulls_first")
nulls_last = not nulls_first
- nulls_are_large = self.null_ordering == "nulls_are_large"
- nulls_are_small = self.null_ordering == "nulls_are_small"
- nulls_are_last = self.null_ordering == "nulls_are_last"
+ nulls_are_large = self.NULL_ORDERING == "nulls_are_large"
+ nulls_are_small = self.NULL_ORDERING == "nulls_are_small"
+ nulls_are_last = self.NULL_ORDERING == "nulls_are_last"
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
@@ -1521,7 +1516,7 @@ class Generator:
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
- limit = expression.args.get("limit")
+ limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit")
if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
limit = exp.Limit(expression=limit.args.get("count"))
@@ -1540,12 +1535,19 @@ class Generator:
self.sql(expression, "having"),
*self.after_having_modifiers(expression),
self.sql(expression, "order"),
- self.sql(expression, "offset") if fetch else self.sql(limit),
- self.sql(limit) if fetch else self.sql(expression, "offset"),
+ *self.offset_limit_modifiers(expression, fetch, limit),
*self.after_limit_modifiers(expression),
sep="",
)
+ def offset_limit_modifiers(
+ self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
+ ) -> t.List[str]:
+ return [
+ self.sql(expression, "offset") if fetch else self.sql(limit),
+ self.sql(limit) if fetch else self.sql(expression, "offset"),
+ ]
+
def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return [
self.sql(expression, "qualify"),
@@ -1634,7 +1636,7 @@ class Generator:
def unnest_sql(self, expression: exp.Unnest) -> str:
args = self.expressions(expression, flat=True)
alias = expression.args.get("alias")
- if alias and self.unnest_column_only:
+ if alias and self.UNNEST_COLUMN_ONLY:
columns = alias.columns
alias = self.sql(columns[0]) if columns else ""
else:
@@ -1697,7 +1699,7 @@ class Generator:
return f"{this} BETWEEN {low} AND {high}"
def bracket_sql(self, expression: exp.Bracket) -> str:
- expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset)
+ expressions = apply_index_offset(expression.this, expression.expressions, self.INDEX_OFFSET)
expressions_sql = ", ".join(self.sql(e) for e in expressions)
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
@@ -1729,7 +1731,7 @@ class Generator:
statements.append("END")
- if self.pretty and self.text_width(statements) > self._max_text_width:
+ if self.pretty and self.text_width(statements) > self.max_text_width:
return self.indent("\n".join(statements), skip_first=True, skip_last=True)
return " ".join(statements)
@@ -1759,10 +1761,11 @@ class Generator:
else:
return self.func("TRIM", expression.this, expression.expression)
- def concat_sql(self, expression: exp.Concat) -> str:
- if len(expression.expressions) == 1:
- return self.sql(expression.expressions[0])
- return self.function_fallback_sql(expression)
+ def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
+ expressions = expression.expressions
+ if self.STRICT_STRING_CONCAT:
+ expressions = (exp.cast(e, "text") for e in expressions)
+ return self.func("CONCAT", *expressions)
def check_sql(self, expression: exp.Check) -> str:
this = self.sql(expression, key="this")
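SafeConcat (and SafeDPipe below) model dialects whose CONCAT coerces its arguments. When a dialect sets STRICT_STRING_CONCAT, the generator wraps each argument in an explicit cast instead. A sketch with a hypothetical strict generator:

    from sqlglot import exp
    from sqlglot.generator import Generator

    class StrictGenerator(Generator):
        STRICT_STRING_CONCAT = True

    node = exp.SafeConcat(expressions=[exp.column("a"), exp.column("b")])
    assert StrictGenerator().generate(node) == "CONCAT(CAST(a AS TEXT), CAST(b AS TEXT))"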
@@ -1785,9 +1788,7 @@ class Generator:
return f"PRIMARY KEY ({expressions}){options}"
def if_sql(self, expression: exp.If) -> str:
- return self.case_sql(
- exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
- )
+ return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
modifier = expression.args.get("modifier")
@@ -1798,7 +1799,6 @@ class Generator:
return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
def jsonobject_sql(self, expression: exp.JSONObject) -> str:
- expressions = self.expressions(expression)
null_handling = expression.args.get("null_handling")
null_handling = f" {null_handling}" if null_handling else ""
unique_keys = expression.args.get("unique_keys")
@@ -1811,7 +1811,11 @@ class Generator:
format_json = " FORMAT JSON" if expression.args.get("format_json") else ""
encoding = self.sql(expression, "encoding")
encoding = f" ENCODING {encoding}" if encoding else ""
- return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
+ return self.func(
+ "JSON_OBJECT",
+ *expression.expressions,
+ suffix=f"{null_handling}{unique_keys}{return_type}{format_json}{encoding})",
+ )
def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str:
this = self.sql(expression, "this")
@@ -1930,7 +1934,7 @@ class Generator:
for i, e in enumerate(expression.flatten(unnest=False))
)
- sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
+ sep = "\n" if self.text_width(sqls) > self.max_text_width else " "
return f"{sep}{op} ".join(sqls)
def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
@@ -2093,6 +2097,11 @@ class Generator:
def dpipe_sql(self, expression: exp.DPipe) -> str:
return self.binary(expression, "||")
+ def safedpipe_sql(self, expression: exp.SafeDPipe) -> str:
+ if self.STRICT_STRING_CONCAT:
+ return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
+ return self.dpipe_sql(expression)
+
def div_sql(self, expression: exp.Div) -> str:
return self.binary(expression, "/")
@@ -2127,7 +2136,7 @@ class Generator:
return self.binary(expression, "ILIKE ANY")
def is_sql(self, expression: exp.Is) -> str:
- if not self.IS_BOOL and isinstance(expression.expression, exp.Boolean):
+ if not self.IS_BOOL_ALLOWED and isinstance(expression.expression, exp.Boolean):
return self.sql(
expression.this if expression.expression.this else exp.not_(expression.this)
)
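For dialects that can't compare against boolean literals, setting IS_BOOL_ALLOWED = False rewrites the predicate as shown in is_sql. A sketch with a hypothetical generator subclass:

    from sqlglot import exp
    from sqlglot.generator import Generator

    class NoBoolIsGenerator(Generator):
        IS_BOOL_ALLOWED = False

    gen = NoBoolIsGenerator()
    assert gen.generate(exp.Is(this=exp.column("x"), expression=exp.true())) == "x"
    assert gen.generate(exp.Is(this=exp.column("x"), expression=exp.false())) == "NOT x"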
@@ -2197,12 +2206,18 @@ class Generator:
return self.func(expression.sql_name(), *args)
- def func(self, name: str, *args: t.Optional[exp.Expression | str]) -> str:
- return f"{self.normalize_func(name)}({self.format_args(*args)})"
+ def func(
+ self,
+ name: str,
+ *args: t.Optional[exp.Expression | str],
+ prefix: str = "(",
+ suffix: str = ")",
+ ) -> str:
+ return f"{self.normalize_func(name)}{prefix}{self.format_args(*args)}{suffix}"
def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
- if self.pretty and self.text_width(arg_sqls) > self._max_text_width:
+ if self.pretty and self.text_width(arg_sqls) > self.max_text_width:
return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
return ", ".join(arg_sqls)
@@ -2210,7 +2225,9 @@ class Generator:
return sum(len(arg) for arg in args)
def format_time(self, expression: exp.Expression) -> t.Optional[str]:
- return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
+ return format_time(
+ self.sql(expression, "format"), self.INVERSE_TIME_MAPPING, self.INVERSE_TIME_TRIE
+ )
def expressions(
self,
@@ -2242,7 +2259,7 @@ class Generator:
comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
if self.pretty:
- if self._leading_comma:
+ if self.leading_comma:
result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}")
else:
result_sqls.append(
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 4215fee..2f48ab5 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -208,7 +208,7 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) ->
return expression
-def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
+def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]:
"""
Sorts a given directed acyclic graph in topological order.
@@ -220,22 +220,24 @@ def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
"""
result = []
- def visit(node: T, visited: t.Set[T]) -> None:
- if node in result:
- return
- if node in visited:
- raise ValueError("Cycle error")
+ for node, deps in tuple(dag.items()):
+ for dep in deps:
+ if dep not in dag:
+ dag[dep] = set()
+
+ while dag:
+ current = {node for node, deps in dag.items() if not deps}
- visited.add(node)
+ if not current:
+ raise ValueError("Cycle error")
- for dep in dag.get(node, []):
- visit(dep, visited)
+ for node in current:
+ dag.pop(node)
- visited.remove(node)
- result.append(node)
+ for deps in dag.values():
+ deps -= current
- for node in dag:
- visit(node, set())
+ result.extend(sorted(current)) # type: ignore
return result
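The rewrite trades the recursive DFS for an iterative Kahn-style algorithm: repeatedly peel off the nodes with no remaining dependencies (sorted for determinism) and subtract them from everyone else's dependency sets, raising if no progress can be made. Note it now expects set values and mutates the input dag. A worked example:

    from sqlglot.helper import tsort

    assert tsort({"a": {"b", "c"}, "b": {"c"}, "c": set()}) == ["c", "b", "a"]
    # tsort({"a": {"b"}, "b": {"a"}}) raises ValueError("Cycle error")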
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 6238759..39e2c53 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,13 +1,25 @@
+from __future__ import annotations
+
+import typing as t
+
from sqlglot import exp
+from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.schema import ensure_schema
+from sqlglot.schema import Schema, ensure_schema
+
+if t.TYPE_CHECKING:
+ B = t.TypeVar("B", bound=exp.Binary)
-def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
+def annotate_types(
+ expression: E,
+ schema: t.Optional[t.Dict | Schema] = None,
+ annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
+ coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
+) -> E:
"""
- Recursively infer & annotate types in an expression syntax tree against a schema.
- Assumes that we've already executed the optimizer's qualify_columns step.
+ Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot
@@ -18,12 +30,13 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
<Type.DOUBLE: 'DOUBLE'>
Args:
- expression (sqlglot.Expression): Expression to annotate.
- schema (dict|sqlglot.optimizer.Schema): Database schema.
- annotators (dict): Maps expression type to corresponding annotation function.
- coerces_to (dict): Maps expression type to set of types that it can be coerced into.
+ expression: Expression to annotate.
+ schema: Database schema.
+ annotators: Maps expression type to corresponding annotation function.
+ coerces_to: Maps expression type to set of types that it can be coerced into.
+
Returns:
- sqlglot.Expression: expression annotated with types
+ The expression annotated with types.
"""
schema = ensure_schema(schema)
@@ -31,276 +44,241 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
-class TypeAnnotator:
- ANNOTATORS = {
- **{
- expr_type: lambda self, expr: self._annotate_unary(expr)
- for expr_type in subclasses(exp.__name__, exp.Unary)
- },
- **{
- expr_type: lambda self, expr: self._annotate_binary(expr)
- for expr_type in subclasses(exp.__name__, exp.Binary)
- },
- exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
- exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
- exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
- exp.Alias: lambda self, expr: self._annotate_unary(expr),
- exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
- exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
- exp.Literal: lambda self, expr: self._annotate_literal(expr),
- exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
- exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
- exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
- exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.BIGINT
- ),
- exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
- exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
- exp.Sum: lambda self, expr: self._annotate_by_args(
- expr, "this", "expressions", promote=True
- ),
- exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
- exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATETIME
- ),
- exp.CurrentTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATETIME
- ),
- exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATETIME
- ),
- exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.TimestampSub: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATE
- ),
- exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
- exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
- exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
- exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
- exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.GroupConcat: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
- exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
- exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
- exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
- exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
- exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
- exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DOUBLE
- ),
- exp.RegexpLike: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.BOOLEAN
- ),
- exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.StrToTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATE
- ),
- exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.UnixToTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.VariancePop: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DOUBLE
- ),
- exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- }
+def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
+ return lambda self, e: self._annotate_with_type(e, data_type)
- # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
- COERCES_TO = {
- # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
- exp.DataType.Type.TEXT: set(),
- exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
- exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
- exp.DataType.Type.NCHAR: {
- exp.DataType.Type.VARCHAR,
- exp.DataType.Type.NVARCHAR,
+
+class _TypeAnnotator(type):
+ def __new__(cls, clsname, bases, attrs):
+ klass = super().__new__(cls, clsname, bases, attrs)
+
+ # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
+ # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
+ text_precedence = (
exp.DataType.Type.TEXT,
- },
- exp.DataType.Type.CHAR: {
- exp.DataType.Type.NCHAR,
- exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
- exp.DataType.Type.TEXT,
- },
- # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
- exp.DataType.Type.DOUBLE: set(),
- exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
- exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
- exp.DataType.Type.BIGINT: {
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
+ exp.DataType.Type.VARCHAR,
+ exp.DataType.Type.NCHAR,
+ exp.DataType.Type.CHAR,
+ )
+ numeric_precedence = (
exp.DataType.Type.DOUBLE,
+ exp.DataType.Type.FLOAT,
+ exp.DataType.Type.DECIMAL,
+ exp.DataType.Type.BIGINT,
+ exp.DataType.Type.INT,
+ exp.DataType.Type.SMALLINT,
+ exp.DataType.Type.TINYINT,
+ )
+ timelike_precedence = (
+ exp.DataType.Type.TIMESTAMPLTZ,
+ exp.DataType.Type.TIMESTAMPTZ,
+ exp.DataType.Type.TIMESTAMP,
+ exp.DataType.Type.DATETIME,
+ exp.DataType.Type.DATE,
+ )
+
+ for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
+ coerces_to = set()
+ for data_type in type_precedence:
+ klass.COERCES_TO[data_type] = coerces_to.copy()
+ coerces_to |= {data_type}
+
+ return klass
+
+
+class TypeAnnotator(metaclass=_TypeAnnotator):
+ TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
+ exp.DataType.Type.BIGINT: {
+ exp.ApproxDistinct,
+ exp.ArraySize,
+ exp.Count,
+ exp.Length,
+ },
+ exp.DataType.Type.BOOLEAN: {
+ exp.Between,
+ exp.Boolean,
+ exp.In,
+ exp.RegexpLike,
+ },
+ exp.DataType.Type.DATE: {
+ exp.CurrentDate,
+ exp.Date,
+ exp.DateAdd,
+ exp.DateStrToDate,
+ exp.DateSub,
+ exp.DateTrunc,
+ exp.DiToDate,
+ exp.StrToDate,
+ exp.TimeStrToDate,
+ exp.TsOrDsToDate,
+ },
+ exp.DataType.Type.DATETIME: {
+ exp.CurrentDatetime,
+ exp.DatetimeAdd,
+ exp.DatetimeSub,
+ },
+ exp.DataType.Type.DOUBLE: {
+ exp.ApproxQuantile,
+ exp.Avg,
+ exp.Exp,
+ exp.Ln,
+ exp.Log,
+ exp.Log2,
+ exp.Log10,
+ exp.Pow,
+ exp.Quantile,
+ exp.Round,
+ exp.SafeDivide,
+ exp.Sqrt,
+ exp.Stddev,
+ exp.StddevPop,
+ exp.StddevSamp,
+ exp.Variance,
+ exp.VariancePop,
},
exp.DataType.Type.INT: {
- exp.DataType.Type.BIGINT,
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
- exp.DataType.Type.DOUBLE,
+ exp.Ceil,
+ exp.DateDiff,
+ exp.DatetimeDiff,
+ exp.Extract,
+ exp.TimestampDiff,
+ exp.TimeDiff,
+ exp.DateToDi,
+ exp.Floor,
+ exp.Levenshtein,
+ exp.StrPosition,
+ exp.TsOrDiToDi,
},
- exp.DataType.Type.SMALLINT: {
- exp.DataType.Type.INT,
- exp.DataType.Type.BIGINT,
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
- exp.DataType.Type.DOUBLE,
+ exp.DataType.Type.TIMESTAMP: {
+ exp.CurrentTime,
+ exp.CurrentTimestamp,
+ exp.StrToTime,
+ exp.TimeAdd,
+ exp.TimeStrToTime,
+ exp.TimeSub,
+ exp.TimestampAdd,
+ exp.TimestampSub,
+ exp.UnixToTime,
},
exp.DataType.Type.TINYINT: {
- exp.DataType.Type.SMALLINT,
- exp.DataType.Type.INT,
- exp.DataType.Type.BIGINT,
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
- exp.DataType.Type.DOUBLE,
+ exp.Day,
+ exp.Month,
+ exp.Week,
+ exp.Year,
},
- # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
- exp.DataType.Type.TIMESTAMPLTZ: set(),
- exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
- exp.DataType.Type.TIMESTAMP: {
- exp.DataType.Type.TIMESTAMPTZ,
- exp.DataType.Type.TIMESTAMPLTZ,
+ exp.DataType.Type.VARCHAR: {
+ exp.ArrayConcat,
+ exp.Concat,
+ exp.ConcatWs,
+ exp.DateToDateStr,
+ exp.GroupConcat,
+ exp.Initcap,
+ exp.Lower,
+ exp.SafeConcat,
+ exp.Substring,
+ exp.TimeToStr,
+ exp.TimeToTimeStr,
+ exp.Trim,
+ exp.TsOrDsToDateStr,
+ exp.UnixToStr,
+ exp.UnixToTimeStr,
+ exp.Upper,
},
- exp.DataType.Type.DATETIME: {
- exp.DataType.Type.TIMESTAMP,
- exp.DataType.Type.TIMESTAMPTZ,
- exp.DataType.Type.TIMESTAMPLTZ,
+ }
+
+ ANNOTATORS = {
+ **{
+ expr_type: lambda self, e: self._annotate_unary(e)
+ for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
},
- exp.DataType.Type.DATE: {
- exp.DataType.Type.DATETIME,
- exp.DataType.Type.TIMESTAMP,
- exp.DataType.Type.TIMESTAMPTZ,
- exp.DataType.Type.TIMESTAMPLTZ,
+ **{
+ expr_type: lambda self, e: self._annotate_binary(e)
+ for expr_type in subclasses(exp.__name__, exp.Binary)
+ },
+ **{
+ expr_type: _annotate_with_type_lambda(data_type)
+ for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
+ for expr_type in expressions
},
+ exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
+ exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
+ exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
+ exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+ exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
+ exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
+ exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
+ exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
+ exp.Literal: lambda self, e: self._annotate_literal(e),
+ exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
+ exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+ exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+ exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
+ exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
+ exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
+ exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
}
- TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
+ # Specifies what types a given type can be coerced into (autofilled)
+ COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
- def __init__(self, schema=None, annotators=None, coerces_to=None):
+ def __init__(
+ self,
+ schema: Schema,
+ annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
+ coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
+ ) -> None:
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
- def annotate(self, expression):
- if isinstance(expression, self.TRAVERSABLES):
- for scope in traverse_scope(expression):
- selects = {}
- for name, source in scope.sources.items():
- if not isinstance(source, Scope):
- continue
- if isinstance(source.expression, exp.UDTF):
- values = []
-
- if isinstance(source.expression, exp.Lateral):
- if isinstance(source.expression.this, exp.Explode):
- values = [source.expression.this.this]
- else:
- values = source.expression.expressions[0].expressions
-
- if not values:
- continue
-
- selects[name] = {
- alias: column
- for alias, column in zip(
- source.expression.alias_column_names,
- values,
- )
- }
+ def annotate(self, expression: E) -> E:
+ for scope in traverse_scope(expression):
+ selects = {}
+ for name, source in scope.sources.items():
+ if not isinstance(source, Scope):
+ continue
+ if isinstance(source.expression, exp.UDTF):
+ values = []
+
+ if isinstance(source.expression, exp.Lateral):
+ if isinstance(source.expression.this, exp.Explode):
+ values = [source.expression.this.this]
else:
- selects[name] = {
- select.alias_or_name: select for select in source.expression.selects
- }
- # First annotate the current scope's column references
- for col in scope.columns:
- if not col.table:
+ values = source.expression.expressions[0].expressions
+
+ if not values:
continue
- source = scope.sources.get(col.table)
- if isinstance(source, exp.Table):
- col.type = self.schema.get_column_type(source, col)
- elif source and col.table in selects and col.name in selects[col.table]:
- col.type = selects[col.table][col.name].type
- # Then (possibly) annotate the remaining expressions in the scope
- self._maybe_annotate(scope.expression)
+ selects[name] = {
+ alias: column
+ for alias, column in zip(
+ source.expression.alias_column_names,
+ values,
+ )
+ }
+ else:
+ selects[name] = {
+ select.alias_or_name: select for select in source.expression.selects
+ }
+
+ # First annotate the current scope's column references
+ for col in scope.columns:
+ if not col.table:
+ continue
+
+ source = scope.sources.get(col.table)
+ if isinstance(source, exp.Table):
+ col.type = self.schema.get_column_type(source, col)
+ elif source and col.table in selects and col.name in selects[col.table]:
+ col.type = selects[col.table][col.name].type
+
+ # Then (possibly) annotate the remaining expressions in the scope
+ self._maybe_annotate(scope.expression)
+
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
- def _maybe_annotate(self, expression):
+ def _maybe_annotate(self, expression: E) -> E:
if expression.type:
return expression # We've already inferred the expression's type
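The metaclass above replaces the old hand-written COERCES_TO table: for each precedence chain, a type coerces to every type strictly higher in the chain. A sketch of the resulting entries:

    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import TypeAnnotator

    # TINYINT is lowest in the numeric chain, so it coerces to all wider types
    assert exp.DataType.Type.DOUBLE in TypeAnnotator.COERCES_TO[exp.DataType.Type.TINYINT]
    # DOUBLE tops the chain, so it coerces to nothing
    assert TypeAnnotator.COERCES_TO[exp.DataType.Type.DOUBLE] == set()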
@@ -312,13 +290,15 @@ class TypeAnnotator:
else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
)
- def _annotate_args(self, expression):
+ def _annotate_args(self, expression: E) -> E:
for _, value in expression.iter_expressions():
self._maybe_annotate(value)
return expression
- def _maybe_coerce(self, type1, type2):
+ def _maybe_coerce(
+ self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
+ ) -> exp.DataType.Type:
# We propagate the NULL / UNKNOWN types upwards if found
if isinstance(type1, exp.DataType):
type1 = type1.this
@@ -330,9 +310,14 @@ class TypeAnnotator:
if exp.DataType.Type.UNKNOWN in (type1, type2):
return exp.DataType.Type.UNKNOWN
- return type2 if type2 in self.coerces_to.get(type1, {}) else type1
+ return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore
- def _annotate_binary(self, expression):
+ # Note: the following "no_type_check" decorators were added because mypy complains
+ # about assigning Type values to expression.type (since its getter returns Optional[DataType]).
+ # This is a known mypy issue: https://github.com/python/mypy/issues/3004
+
+ @t.no_type_check
+ def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)
left_type = expression.left.type.this
@@ -354,7 +339,8 @@ class TypeAnnotator:
return expression
- def _annotate_unary(self, expression):
+ @t.no_type_check
+ def _annotate_unary(self, expression: E) -> E:
self._annotate_args(expression)
if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
@@ -364,7 +350,8 @@ class TypeAnnotator:
return expression
- def _annotate_literal(self, expression):
+ @t.no_type_check
+ def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
if expression.is_string:
expression.type = exp.DataType.Type.VARCHAR
elif expression.is_int:
@@ -374,13 +361,16 @@ class TypeAnnotator:
return expression
- def _annotate_with_type(self, expression, target_type):
+ @t.no_type_check
+ def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
expression.type = target_type
return self._annotate_args(expression)
- def _annotate_by_args(self, expression, *args, promote=False):
+ @t.no_type_check
+ def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
self._annotate_args(expression)
- expressions = []
+
+ expressions: t.List[exp.Expression] = []
for arg in args:
arg_expr = expression.args.get(arg)
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
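For reference, the reworked annotator is reachable through the module-level annotate_types() helper. A minimal sketch of the two code paths touched above (literal coercion via _maybe_coerce, and schema-driven lookup for qualified columns), assuming the sqlglot 16.x public API; not part of the patch itself, expected output shown in comments is indicative:

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    # Literal coercion: INT + DOUBLE is propagated upwards as DOUBLE.
    expression = annotate_types(parse_one("SELECT 1 + 2.5 AS x"))
    print(expression.selects[0].type)  # DOUBLE

    # Schema-driven lookup for a qualified column reference.
    expression = annotate_types(parse_one("SELECT t.a FROM t"), schema={"t": {"a": "TEXT"}})
    print(expression.selects[0].type)  # TEXT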
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index da2fce8..015b06a 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -26,7 +26,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
- node = exp.Concat(this=node.this, expression=node.expression)
+ node = exp.Concat(expressions=[node.left, node.right])
return node
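A rough usage sketch of the rewritten add_text_to_concat: canonicalize() keys off node.type, so types must be annotated first. The schema below is illustrative and the output indicative:

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.canonicalize import canonicalize

    # '+' over two TEXT columns is rewritten to the variadic CONCAT node.
    expression = annotate_types(
        parse_one("SELECT t.a + t.b FROM t"), schema={"t": {"a": "TEXT", "b": "TEXT"}}
    )
    print(canonicalize(expression).sql())  # SELECT CONCAT(t.a, t.b) FROM t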
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 27de9c7..cd8ba3b 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -32,7 +32,7 @@ def eliminate_joins(expression):
# Reverse the joins so we can remove chains of unused joins
for join in reversed(joins):
- alias = join.this.alias_or_name
+ alias = join.alias_or_name
if _should_eliminate_join(scope, join, alias):
join.pop()
scope.remove_source(alias)
@@ -126,7 +126,7 @@ def join_condition(join):
tuple[list[str], list[str], exp.Expression]:
Tuple of (source key, join key, remaining predicate)
"""
- name = join.this.alias_or_name
+ name = join.alias_or_name
on = (join.args.get("on") or exp.true()).copy()
source_key = []
join_key = []
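Behavior here is unchanged; only the alias lookup moved from join.this.alias_or_name to exp.Join.alias_or_name. A sketch of the rewrite this rule performs, mirroring the module's doctest:

    import sqlglot
    from sqlglot.optimizer.eliminate_joins import eliminate_joins

    # The LEFT JOIN on a distinct key is unused by the projection, so it is dropped.
    sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b"
    print(eliminate_joins(sqlglot.parse_one(sql)).sql())  # SELECT x.a FROM x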
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
index 5dfa4aa..79e3ed5 100644
--- a/sqlglot/optimizer/isolate_table_selects.py
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None):
source.replace(
exp.select("*")
.from_(
- alias(source, source.name or source.alias, table=True),
+ alias(source, source.alias_or_name, table=True),
copy=False,
)
.subquery(source.alias, copy=False)
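A sketch of what isolate_table_selects does with the new alias_or_name call; the schema mapping below is illustrative, since the rule only wraps tables whose columns it can resolve:

    import sqlglot
    from sqlglot.optimizer.isolate_table_selects import isolate_table_selects

    # Each aliased table in a multi-source scope is wrapped in its own subquery.
    expression = sqlglot.parse_one("SELECT * FROM x AS x JOIN y AS y")
    schema = {"x": {"a": "int"}, "y": {"a": "int"}}
    print(isolate_table_selects(expression, schema=schema).sql())
    # SELECT * FROM (SELECT * FROM x AS x) AS x JOIN (SELECT * FROM y AS y) AS y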
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index f9c9664..fefe96e 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -145,7 +145,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
if not isinstance(from_or_join, exp.Join):
return False
- alias = from_or_join.this.alias_or_name
+ alias = from_or_join.alias_or_name
on = from_or_join.args.get("on")
if not on:
@@ -253,10 +253,6 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
"""
new_joins = []
- comma_joins = inner_scope.expression.args.get("from").expressions[1:]
- for subquery in comma_joins:
- new_joins.append(exp.Join(this=subquery, kind="CROSS"))
- outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])
joins = inner_scope.expression.args.get("joins") or []
for join in joins:
@@ -328,13 +324,12 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if source == from_or_join.alias_or_name:
break
- if set(exp.column_table_names(where.this)) <= sources:
+ if exp.column_table_names(where.this) <= sources:
from_or_join.on(where.this, copy=False)
from_or_join.set("on", from_or_join.args.get("on"))
return
expression.where(where.this, copy=False)
- expression.set("where", expression.args.get("where"))
def _merge_order(outer_scope, inner_scope):
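For context, a sketch of the merge this rule performs, mirroring the module's doctest; the comma-join block removed above only affected inner queries with unqualified FROM lists:

    import sqlglot
    from sqlglot.optimizer.merge_subqueries import merge_subqueries

    # The derived table is folded into the outer query.
    expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y")
    print(merge_subqueries(expression).sql())  # SELECT x.a FROM x CROSS JOIN y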
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 4e0c3a1..d51276f 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -1,3 +1,7 @@
+from __future__ import annotations
+
+import typing as t
+
from sqlglot import exp
from sqlglot.helper import tsort
@@ -13,25 +17,28 @@ def optimize_joins(expression):
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
"""
+
for select in expression.find_all(exp.Select):
references = {}
cross_joins = []
for join in select.args.get("joins", []):
- name = join.this.alias_or_name
- tables = other_table_names(join, name)
+ tables = other_table_names(join)
if tables:
for table in tables:
references[table] = references.get(table, []) + [join]
else:
- cross_joins.append((name, join))
+ cross_joins.append((join.alias_or_name, join))
for name, join in cross_joins:
for dep in references.get(name, []):
on = dep.args["on"]
if isinstance(on, exp.Connector):
+ if len(other_table_names(dep)) < 2:
+ continue
+
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
predicate.replace(exp.true())
@@ -47,17 +54,12 @@ def reorder_joins(expression):
Reorder joins by topological sort order based on predicate references.
"""
for from_ in expression.find_all(exp.From):
- head = from_.this
parent = from_.parent
- joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
- dag = {head.alias_or_name: []}
-
- for name, join in joins.items():
- dag[name] = other_table_names(join, name)
-
+ joins = {join.alias_or_name: join for join in parent.args.get("joins", [])}
+ dag = {name: other_table_names(join) for name, join in joins.items()}
parent.set(
"joins",
- [joins[name] for name in tsort(dag) if name != head.alias_or_name],
+ [joins[name] for name in tsort(dag) if name != from_.alias_or_name],
)
return expression
@@ -75,9 +77,6 @@ def normalize(expression):
return expression
-def other_table_names(join, exclude):
- return [
- name
- for name in (exp.column_table_names(join.args.get("on") or exp.true()))
- if name != exclude
- ]
+def other_table_names(join: exp.Join) -> t.Set[str]:
+ on = join.args.get("on")
+ return exp.column_table_names(on, join.alias_or_name) if on else set()
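other_table_names() now returns a set and excludes the join's own alias through the second argument of exp.column_table_names. An illustrative sketch:

    from sqlglot import parse_one
    from sqlglot.optimizer.optimize_joins import other_table_names

    join = parse_one("SELECT * FROM x JOIN y ON x.a = y.a").args["joins"][0]
    # The joined table itself ("y") is excluded from the result.
    print(other_table_names(join))  # {'x'}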
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index dbe33a2..abac63b 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -78,7 +78,7 @@ def optimize(
"schema": schema,
"dialect": dialect,
"isolate_tables": True, # needed for other optimizations to perform well
- "quote_identifiers": False, # this happens in canonicalize
+ "quote_identifiers": False,
**kwargs,
}
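A quick sanity check of the full pipeline with these defaults; output is indicative of the 16.x rule set, where identifier quoting is applied by a later rule rather than by qualify:

    from sqlglot import parse_one
    from sqlglot.optimizer import optimize

    print(optimize(parse_one("SELECT a FROM t"), schema={"t": {"a": "INT"}}).sql())
    # SELECT "t"."a" AS "a" FROM "t" AS "t"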
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index b89a82b..fb1662d 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -41,7 +41,7 @@ def pushdown_predicates(expression):
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
for join in select.args.get("joins") or []:
- name = join.this.alias_or_name
+ name = join.alias_or_name
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
return expression
@@ -93,10 +93,10 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
pushdown_tables = set()
for a in predicates:
- a_tables = set(exp.column_table_names(a))
+ a_tables = exp.column_table_names(a)
for b in predicates:
- a_tables &= set(exp.column_table_names(b))
+ a_tables &= exp.column_table_names(b)
pushdown_tables.update(a_tables)
@@ -147,7 +147,7 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
tables = exp.column_table_names(predicate)
where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
- for table in tables:
+ for table in sorted(tables):
node, source = sources.get(table) or (None, None)
# if the predicate is in a where statement we can try to push it down
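A sketch of the pushdown this rule performs, mirroring the module's doctest; the sorted() call above only makes the traversal order of the table-name set deterministic:

    import sqlglot
    from sqlglot.optimizer.pushdown_predicates import pushdown_predicates

    sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
    print(pushdown_predicates(sqlglot.parse_one(sql)).sql())
    # SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE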
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 4a31171..aba9a7e 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -14,7 +14,7 @@ from sqlglot.schema import Schema, ensure_schema
def qualify_columns(
expression: exp.Expression,
- schema: dict | Schema,
+ schema: t.Dict | Schema,
expand_alias_refs: bool = True,
infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
@@ -93,7 +93,7 @@ def _pop_table_column_aliases(derived_tables):
def _expand_using(scope, resolver):
joins = list(scope.find_all(exp.Join))
- names = {join.this.alias for join in joins}
+ names = {join.alias_or_name for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
# Mapping of automatically joined column names to an ordered set of source names (dict).
@@ -105,7 +105,7 @@ def _expand_using(scope, resolver):
if not using:
continue
- join_table = join.this.alias_or_name
+ join_table = join.alias_or_name
columns = {}
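A sketch of qualify_columns usage, mirroring the module's doctest; the schema argument is the plain dict now annotated as t.Dict above:

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns

    expression = sqlglot.parse_one("SELECT col FROM tbl")
    print(qualify_columns(expression, schema={"tbl": {"col": "INT"}}).sql())
    # SELECT tbl.col AS col FROM tbl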
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index fcc5f26..9c931d6 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -91,11 +91,13 @@ def qualify_tables(
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
- table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name())
+ table_alias = udtf.args.get("alias") or exp.TableAlias(
+ this=exp.to_identifier(next_alias_name())
+ )
udtf.set("alias", table_alias)
if not table_alias.name:
- table_alias.set("this", next_alias_name())
+ table_alias.set("this", exp.to_identifier(next_alias_name()))
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
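A sketch of qualify_tables; the exp.to_identifier wrapping above ensures generated UDTF aliases are proper Identifier nodes rather than bare strings:

    import sqlglot
    from sqlglot.optimizer.qualify_tables import qualify_tables

    expression = sqlglot.parse_one("SELECT 1 FROM tbl")
    print(qualify_tables(expression, db="db").sql())  # SELECT 1 FROM db.tbl AS tbl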
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 9ffb4d6..aa56b83 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -620,7 +620,7 @@ def _traverse_tables(scope):
table_name = expression.name
source_name = expression.alias_or_name
- if table_name in scope.sources:
+ if table_name in scope.sources and not expression.db:
# This is a reference to a parent source (e.g. a CTE), not an actual table, unless
# it is pivoted, because then we get back a new table and hence a new source.
pivots = expression.args.get("pivots")
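An illustrative sketch of the scope fix, under the assumption that a db-qualified table should never be resolved as a same-named CTE:

    from sqlglot import exp, parse_one
    from sqlglot.optimizer.scope import build_scope

    # db.t no longer shadows the CTE named t: it is registered as a real table source.
    root = build_scope(parse_one("WITH t AS (SELECT 1 AS a) SELECT * FROM db.t"))
    print(isinstance(root.sources["t"], exp.Table))  # True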
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 96bd6e3..d6888c7 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -6,7 +6,8 @@ from collections import defaultdict
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
-from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
+from sqlglot.helper import apply_index_offset, ensure_list, seq_get
+from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
@@ -25,13 +26,14 @@ def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
for i in range(0, len(args), 2):
keys.append(args[i])
values.append(args[i + 1])
+
return exp.VarMap(
keys=exp.Array(expressions=keys),
values=exp.Array(expressions=values),
)
-def parse_like(args: t.List) -> exp.Expression:
+def parse_like(args: t.List) -> exp.Escape | exp.Like:
like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0))
return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like
@@ -47,33 +49,26 @@ def binary_range_parser(
class _Parser(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
- klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
- klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS)
+
+ klass.SHOW_TRIE = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
+ klass.SET_TRIE = new_trie(key.split(" ") for key in klass.SET_PARSERS)
return klass
class Parser(metaclass=_Parser):
"""
- Parser consumes a list of tokens produced by the `sqlglot.tokens.Tokenizer` and produces
- a parsed syntax tree.
+ Parser consumes a list of tokens produced by the Tokenizer and produces a parsed syntax tree.
Args:
- error_level: the desired error level.
+ error_level: The desired error level.
Default: ErrorLevel.IMMEDIATE
- error_message_context: determines the amount of context to capture from a
+ error_message_context: Determines the amount of context to capture from a
query string when displaying the error message (in number of characters).
- Default: 50.
- index_offset: Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list.
- Default: 0
- alias_post_tablesample: If the table alias comes after tablesample.
- Default: False
+ Default: 100
max_errors: Maximum number of error messages to include in a raised ParseError.
This is only relevant if error_level is ErrorLevel.RAISE.
Default: 3
- null_ordering: Indicates the default null ordering method to use if not explicitly set.
- Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
- Default: "nulls_are_small"
"""
FUNCTIONS: t.Dict[str, t.Callable] = {
@@ -83,7 +78,6 @@ class Parser(metaclass=_Parser):
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
- "IFNULL": exp.Coalesce.from_arg_list,
"LIKE": parse_like,
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
@@ -108,8 +102,6 @@ class Parser(metaclass=_Parser):
TokenType.CURRENT_USER: exp.CurrentUser,
}
- JOIN_HINTS: t.Set[str] = set()
-
NESTED_TYPE_TOKENS = {
TokenType.ARRAY,
TokenType.MAP,
@@ -117,6 +109,10 @@ class Parser(metaclass=_Parser):
TokenType.STRUCT,
}
+ ENUM_TYPE_TOKENS = {
+ TokenType.ENUM,
+ }
+
TYPE_TOKENS = {
TokenType.BIT,
TokenType.BOOLEAN,
@@ -188,6 +184,7 @@ class Parser(metaclass=_Parser):
TokenType.VARIANT,
TokenType.OBJECT,
TokenType.INET,
+ TokenType.ENUM,
*NESTED_TYPE_TOKENS,
}
@@ -198,7 +195,10 @@ class Parser(metaclass=_Parser):
TokenType.SOME: exp.Any,
}
- RESERVED_KEYWORDS = {*Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT}
+ RESERVED_KEYWORDS = {
+ *Tokenizer.SINGLE_TOKENS.values(),
+ TokenType.SELECT,
+ }
DB_CREATABLES = {
TokenType.DATABASE,
@@ -216,6 +216,7 @@ class Parser(metaclass=_Parser):
*DB_CREATABLES,
}
+ # Tokens that can represent identifiers
ID_VAR_TOKENS = {
TokenType.VAR,
TokenType.ANTI,
@@ -224,6 +225,7 @@ class Parser(metaclass=_Parser):
TokenType.AUTO_INCREMENT,
TokenType.BEGIN,
TokenType.CACHE,
+ TokenType.CASE,
TokenType.COLLATE,
TokenType.COMMAND,
TokenType.COMMENT,
@@ -274,6 +276,7 @@ class Parser(metaclass=_Parser):
TokenType.TRUE,
TokenType.UNIQUE,
TokenType.UNPIVOT,
+ TokenType.UPDATE,
TokenType.VOLATILE,
TokenType.WINDOW,
*CREATABLES,
@@ -409,6 +412,8 @@ class Parser(metaclass=_Parser):
TokenType.ANTI,
}
+ JOIN_HINTS: t.Set[str] = set()
+
LAMBDAS = {
TokenType.ARROW: lambda self, expressions: self.expression(
exp.Lambda,
@@ -420,7 +425,7 @@ class Parser(metaclass=_Parser):
),
TokenType.FARROW: lambda self, expressions: self.expression(
exp.Kwarg,
- this=exp.Var(this=expressions[0].name),
+ this=exp.var(expressions[0].name),
expression=self._parse_conjunction(),
),
}
@@ -515,7 +520,7 @@ class Parser(metaclass=_Parser):
TokenType.USE: lambda self: self.expression(
exp.Use,
kind=self._match_texts(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"))
- and exp.Var(this=self._prev.text),
+ and exp.var(self._prev.text),
this=self._parse_table(schema=False),
),
}
@@ -634,6 +639,7 @@ class Parser(metaclass=_Parser):
"TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property),
"TEMP": lambda self: self.expression(exp.TemporaryProperty),
"TEMPORARY": lambda self: self.expression(exp.TemporaryProperty),
+ "TO": lambda self: self._parse_to_table(),
"TRANSIENT": lambda self: self.expression(exp.TransientProperty),
"TTL": lambda self: self._parse_ttl(),
"USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
@@ -710,6 +716,7 @@ class Parser(metaclass=_Parser):
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
+ "CONCAT": lambda self: self._parse_concat(),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"DECODE": lambda self: self._parse_decode(),
"EXTRACT": lambda self: self._parse_extract(),
@@ -755,8 +762,11 @@ class Parser(metaclass=_Parser):
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
- TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
+ DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN}
+ PRE_VOLATILE_TOKENS = {TokenType.CREATE, TokenType.REPLACE, TokenType.UNIQUE}
+
+ TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
TRANSACTION_CHARACTERISTICS = {
"ISOLATION LEVEL REPEATABLE READ",
"ISOLATION LEVEL READ COMMITTED",
@@ -778,6 +788,8 @@ class Parser(metaclass=_Parser):
STRICT_CAST = True
+ CONCAT_NULL_OUTPUTS_STRING = False # A NULL arg in CONCAT yields NULL by default
+
CONVERT_TYPE_FIRST = False
PREFIXED_PIVOT_COLUMNS = False
@@ -789,40 +801,39 @@ class Parser(metaclass=_Parser):
__slots__ = (
"error_level",
"error_message_context",
+ "max_errors",
"sql",
"errors",
- "index_offset",
- "unnest_column_only",
- "alias_post_tablesample",
- "max_errors",
- "null_ordering",
"_tokens",
"_index",
"_curr",
"_next",
"_prev",
"_prev_comments",
- "_show_trie",
- "_set_trie",
)
+ # Autofilled
+ INDEX_OFFSET: int = 0
+ UNNEST_COLUMN_ONLY: bool = False
+ ALIAS_POST_TABLESAMPLE: bool = False
+ STRICT_STRING_CONCAT = False
+ NULL_ORDERING: str = "nulls_are_small"
+ SHOW_TRIE: t.Dict = {}
+ SET_TRIE: t.Dict = {}
+ FORMAT_MAPPING: t.Dict[str, str] = {}
+ FORMAT_TRIE: t.Dict = {}
+ TIME_MAPPING: t.Dict[str, str] = {}
+ TIME_TRIE: t.Dict = {}
+
def __init__(
self,
error_level: t.Optional[ErrorLevel] = None,
error_message_context: int = 100,
- index_offset: int = 0,
- unnest_column_only: bool = False,
- alias_post_tablesample: bool = False,
max_errors: int = 3,
- null_ordering: t.Optional[str] = None,
):
self.error_level = error_level or ErrorLevel.IMMEDIATE
self.error_message_context = error_message_context
- self.index_offset = index_offset
- self.unnest_column_only = unnest_column_only
- self.alias_post_tablesample = alias_post_tablesample
self.max_errors = max_errors
- self.null_ordering = null_ordering
self.reset()
def reset(self):
@@ -843,11 +854,11 @@ class Parser(metaclass=_Parser):
per parsed SQL statement.
Args:
- raw_tokens: the list of tokens.
- sql: the original SQL string, used to produce helpful debug messages.
+ raw_tokens: The list of tokens.
+ sql: The original SQL string, used to produce helpful debug messages.
Returns:
- The list of syntax trees.
+ The list of the produced syntax trees.
"""
return self._parse(
parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
@@ -865,23 +876,25 @@ class Parser(metaclass=_Parser):
of them, stopping at the first for which the parsing succeeds.
Args:
- expression_types: the expression type(s) to try and parse the token list into.
- raw_tokens: the list of tokens.
- sql: the original SQL string, used to produce helpful debug messages.
+ expression_types: The expression type(s) to try and parse the token list into.
+ raw_tokens: The list of tokens.
+ sql: The original SQL string, used to produce helpful debug messages.
Returns:
The target Expression.
"""
errors = []
- for expression_type in ensure_collection(expression_types):
+ for expression_type in ensure_list(expression_types):
parser = self.EXPRESSION_PARSERS.get(expression_type)
if not parser:
raise TypeError(f"No parser registered for {expression_type}")
+
try:
return self._parse(parser, raw_tokens, sql)
except ParseError as e:
e.errors[0]["into_expression"] = expression_type
errors.append(e)
+
raise ParseError(
f"Failed to parse '{sql or raw_tokens}' into {expression_types}",
errors=merge_errors(errors),
@@ -895,6 +908,7 @@ class Parser(metaclass=_Parser):
) -> t.List[t.Optional[exp.Expression]]:
self.reset()
self.sql = sql or ""
+
total = len(raw_tokens)
chunks: t.List[t.List[Token]] = [[]]
@@ -922,9 +936,7 @@ class Parser(metaclass=_Parser):
return expressions
def check_errors(self) -> None:
- """
- Logs or raises any found errors, depending on the chosen error level setting.
- """
+ """Logs or raises any found errors, depending on the chosen error level setting."""
if self.error_level == ErrorLevel.WARN:
for error in self.errors:
logger.error(str(error))
@@ -969,39 +981,38 @@ class Parser(metaclass=_Parser):
Creates a new, validated Expression.
Args:
- exp_class: the expression class to instantiate.
- comments: an optional list of comments to attach to the expression.
- kwargs: the arguments to set for the expression along with their respective values.
+ exp_class: The expression class to instantiate.
+ comments: An optional list of comments to attach to the expression.
+ kwargs: The arguments to set for the expression along with their respective values.
Returns:
The target expression.
"""
instance = exp_class(**kwargs)
instance.add_comments(comments) if comments else self._add_comments(instance)
- self.validate_expression(instance)
- return instance
+ return self.validate_expression(instance)
def _add_comments(self, expression: t.Optional[exp.Expression]) -> None:
if expression and self._prev_comments:
expression.add_comments(self._prev_comments)
self._prev_comments = None
- def validate_expression(
- self, expression: exp.Expression, args: t.Optional[t.List] = None
- ) -> None:
+ def validate_expression(self, expression: E, args: t.Optional[t.List] = None) -> E:
"""
- Validates an already instantiated expression, making sure that all its mandatory arguments
- are set.
+ Validates an Expression, making sure that all its mandatory arguments are set.
Args:
- expression: the expression to validate.
- args: an optional list of items that was used to instantiate the expression, if it's a Func.
+ expression: The expression to validate.
+ args: An optional list of items that was used to instantiate the expression, if it's a Func.
+
+ Returns:
+ The validated expression.
"""
- if self.error_level == ErrorLevel.IGNORE:
- return
+ if self.error_level != ErrorLevel.IGNORE:
+ for error_message in expression.error_messages(args):
+ self.raise_error(error_message)
- for error_message in expression.error_messages(args):
- self.raise_error(error_message)
+ return expression
def _find_sql(self, start: Token, end: Token) -> str:
return self.sql[start.start : end.end + 1]
@@ -1010,6 +1021,7 @@ class Parser(metaclass=_Parser):
self._index += times
self._curr = seq_get(self._tokens, self._index)
self._next = seq_get(self._tokens, self._index + 1)
+
if self._index > 0:
self._prev = self._tokens[self._index - 1]
self._prev_comments = self._prev.comments
@@ -1031,7 +1043,6 @@ class Parser(metaclass=_Parser):
self._match(TokenType.ON)
kind = self._match_set(self.CREATABLES) and self._prev
-
if not kind:
return self._parse_as_command(start)
@@ -1050,6 +1061,12 @@ class Parser(metaclass=_Parser):
exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists
)
+ def _parse_to_table(
+ self,
+ ) -> exp.ToTableProperty:
+ table = self._parse_table_parts(schema=True)
+ return self.expression(exp.ToTableProperty, this=table)
+
# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl
def _parse_ttl(self) -> exp.Expression:
def _parse_ttl_action() -> t.Optional[exp.Expression]:
@@ -1102,10 +1119,11 @@ class Parser(metaclass=_Parser):
expression = self._parse_set_operations(expression) if expression else self._parse_select()
return self._parse_query_modifiers(expression)
- def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]:
+ def _parse_drop(self) -> exp.Drop | exp.Command:
start = self._prev
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match_text_seq("MATERIALIZED")
+
kind = self._match_set(self.CREATABLES) and self._prev.text
if not kind:
return self._parse_as_command(start)
@@ -1129,21 +1147,23 @@ class Parser(metaclass=_Parser):
and self._match(TokenType.EXISTS)
)
- def _parse_create(self) -> t.Optional[exp.Expression]:
+ def _parse_create(self) -> exp.Create | exp.Command:
+ # Note: this can't be None because we've matched a statement parser
start = self._prev
- replace = self._prev.text.upper() == "REPLACE" or self._match_pair(
+ replace = start.text.upper() == "REPLACE" or self._match_pair(
TokenType.OR, TokenType.REPLACE
)
unique = self._match(TokenType.UNIQUE)
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
- self._match(TokenType.TABLE)
+ self._advance()
properties = None
create_token = self._match_set(self.CREATABLES) and self._prev
if not create_token:
- properties = self._parse_properties() # exp.Properties.Location.POST_CREATE
+ # exp.Properties.Location.POST_CREATE
+ properties = self._parse_properties()
create_token = self._match_set(self.CREATABLES) and self._prev
if not properties or not create_token:
@@ -1157,7 +1177,7 @@ class Parser(metaclass=_Parser):
begin = None
clone = None
- def extend_props(temp_props: t.Optional[exp.Expression]) -> None:
+ def extend_props(temp_props: t.Optional[exp.Properties]) -> None:
nonlocal properties
if properties and temp_props:
properties.expressions.extend(temp_props.expressions)
@@ -1166,6 +1186,8 @@ class Parser(metaclass=_Parser):
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=create_token.token_type)
+
+ # exp.Properties.Location.POST_SCHEMA ("schema" here is the UDF's type signature)
extend_props(self._parse_properties())
self._match(TokenType.ALIAS)
@@ -1190,13 +1212,8 @@ class Parser(metaclass=_Parser):
extend_props(self._parse_properties())
self._match(TokenType.ALIAS)
-
- # exp.Properties.Location.POST_ALIAS
- if not (
- self._match(TokenType.SELECT, advance=False)
- or self._match(TokenType.WITH, advance=False)
- or self._match(TokenType.L_PAREN, advance=False)
- ):
+ if not self._match_set(self.DDL_SELECT_TOKENS, advance=False):
+ # exp.Properties.Location.POST_ALIAS
extend_props(self._parse_properties())
expression = self._parse_ddl_select()
@@ -1206,7 +1223,7 @@ class Parser(metaclass=_Parser):
while True:
index = self._parse_index()
- # exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX
+ # exp.Properties.Location.POST_EXPRESSION and POST_INDEX
extend_props(self._parse_properties())
if not index:
@@ -1296,7 +1313,7 @@ class Parser(metaclass=_Parser):
return None
- def _parse_stored(self) -> exp.Expression:
+ def _parse_stored(self) -> exp.FileFormatProperty:
self._match(TokenType.ALIAS)
input_format = self._parse_string() if self._match_text_seq("INPUTFORMAT") else None
@@ -1311,14 +1328,13 @@ class Parser(metaclass=_Parser):
else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
)
- def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
+ def _parse_property_assignment(self, exp_class: t.Type[E]) -> E:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
return self.expression(exp_class, this=self._parse_field())
- def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Expression]:
+ def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Properties]:
properties = []
-
while True:
if before:
prop = self._parse_property_before()
@@ -1335,29 +1351,25 @@ class Parser(metaclass=_Parser):
return None
- def _parse_fallback(self, no: bool = False) -> exp.Expression:
+ def _parse_fallback(self, no: bool = False) -> exp.FallbackProperty:
return self.expression(
exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION")
)
- def _parse_volatile_property(self) -> exp.Expression:
+ def _parse_volatile_property(self) -> exp.VolatileProperty | exp.StabilityProperty:
if self._index >= 2:
pre_volatile_token = self._tokens[self._index - 2]
else:
pre_volatile_token = None
- if pre_volatile_token and pre_volatile_token.token_type in (
- TokenType.CREATE,
- TokenType.REPLACE,
- TokenType.UNIQUE,
- ):
+ if pre_volatile_token and pre_volatile_token.token_type in self.PRE_VOLATILE_TOKENS:
return exp.VolatileProperty()
return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE"))
def _parse_with_property(
self,
- ) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]:
+ ) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]:
self._match(TokenType.WITH)
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_property)
@@ -1376,7 +1388,7 @@ class Parser(metaclass=_Parser):
return self._parse_withisolatedloading()
# https://dev.mysql.com/doc/refman/8.0/en/create-view.html
- def _parse_definer(self) -> t.Optional[exp.Expression]:
+ def _parse_definer(self) -> t.Optional[exp.DefinerProperty]:
self._match(TokenType.EQ)
user = self._parse_id_var()
@@ -1388,18 +1400,18 @@ class Parser(metaclass=_Parser):
return exp.DefinerProperty(this=f"{user}@{host}")
- def _parse_withjournaltable(self) -> exp.Expression:
+ def _parse_withjournaltable(self) -> exp.WithJournalTableProperty:
self._match(TokenType.TABLE)
self._match(TokenType.EQ)
return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts())
- def _parse_log(self, no: bool = False) -> exp.Expression:
+ def _parse_log(self, no: bool = False) -> exp.LogProperty:
return self.expression(exp.LogProperty, no=no)
- def _parse_journal(self, **kwargs) -> exp.Expression:
+ def _parse_journal(self, **kwargs) -> exp.JournalProperty:
return self.expression(exp.JournalProperty, **kwargs)
- def _parse_checksum(self) -> exp.Expression:
+ def _parse_checksum(self) -> exp.ChecksumProperty:
self._match(TokenType.EQ)
on = None
@@ -1407,53 +1419,47 @@ class Parser(metaclass=_Parser):
on = True
elif self._match_text_seq("OFF"):
on = False
- default = self._match(TokenType.DEFAULT)
- return self.expression(
- exp.ChecksumProperty,
- on=on,
- default=default,
- )
+ return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT))
- def _parse_cluster(self) -> t.Optional[exp.Expression]:
+ def _parse_cluster(self) -> t.Optional[exp.Cluster]:
if not self._match_text_seq("BY"):
self._retreat(self._index - 1)
return None
- return self.expression(
- exp.Cluster,
- expressions=self._parse_csv(self._parse_ordered),
- )
- def _parse_freespace(self) -> exp.Expression:
+ return self.expression(exp.Cluster, expressions=self._parse_csv(self._parse_ordered))
+
+ def _parse_freespace(self) -> exp.FreespaceProperty:
self._match(TokenType.EQ)
return self.expression(
exp.FreespaceProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT)
)
- def _parse_mergeblockratio(self, no: bool = False, default: bool = False) -> exp.Expression:
+ def _parse_mergeblockratio(
+ self, no: bool = False, default: bool = False
+ ) -> exp.MergeBlockRatioProperty:
if self._match(TokenType.EQ):
return self.expression(
exp.MergeBlockRatioProperty,
this=self._parse_number(),
percent=self._match(TokenType.PERCENT),
)
- return self.expression(
- exp.MergeBlockRatioProperty,
- no=no,
- default=default,
- )
+
+ return self.expression(exp.MergeBlockRatioProperty, no=no, default=default)
def _parse_datablocksize(
self,
default: t.Optional[bool] = None,
minimum: t.Optional[bool] = None,
maximum: t.Optional[bool] = None,
- ) -> exp.Expression:
+ ) -> exp.DataBlocksizeProperty:
self._match(TokenType.EQ)
size = self._parse_number()
+
units = None
if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")):
units = self._prev.text
+
return self.expression(
exp.DataBlocksizeProperty,
size=size,
@@ -1463,12 +1469,13 @@ class Parser(metaclass=_Parser):
maximum=maximum,
)
- def _parse_blockcompression(self) -> exp.Expression:
+ def _parse_blockcompression(self) -> exp.BlockCompressionProperty:
self._match(TokenType.EQ)
always = self._match_text_seq("ALWAYS")
manual = self._match_text_seq("MANUAL")
never = self._match_text_seq("NEVER")
default = self._match_text_seq("DEFAULT")
+
autotemp = None
if self._match_text_seq("AUTOTEMP"):
autotemp = self._parse_schema()
@@ -1482,7 +1489,7 @@ class Parser(metaclass=_Parser):
autotemp=autotemp,
)
- def _parse_withisolatedloading(self) -> exp.Expression:
+ def _parse_withisolatedloading(self) -> exp.IsolatedLoadingProperty:
no = self._match_text_seq("NO")
concurrent = self._match_text_seq("CONCURRENT")
self._match_text_seq("ISOLATED", "LOADING")
@@ -1498,7 +1505,7 @@ class Parser(metaclass=_Parser):
for_none=for_none,
)
- def _parse_locking(self) -> exp.Expression:
+ def _parse_locking(self) -> exp.LockingProperty:
if self._match(TokenType.TABLE):
kind = "TABLE"
elif self._match(TokenType.VIEW):
@@ -1553,14 +1560,14 @@ class Parser(metaclass=_Parser):
return self._parse_csv(self._parse_conjunction)
return []
- def _parse_partitioned_by(self) -> exp.Expression:
+ def _parse_partitioned_by(self) -> exp.PartitionedByProperty:
self._match(TokenType.EQ)
return self.expression(
exp.PartitionedByProperty,
this=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
- def _parse_withdata(self, no: bool = False) -> exp.Expression:
+ def _parse_withdata(self, no: bool = False) -> exp.WithDataProperty:
if self._match_text_seq("AND", "STATISTICS"):
statistics = True
elif self._match_text_seq("AND", "NO", "STATISTICS"):
@@ -1570,52 +1577,50 @@ class Parser(metaclass=_Parser):
return self.expression(exp.WithDataProperty, no=no, statistics=statistics)
- def _parse_no_property(self) -> t.Optional[exp.Property]:
+ def _parse_no_property(self) -> t.Optional[exp.NoPrimaryIndexProperty]:
if self._match_text_seq("PRIMARY", "INDEX"):
return exp.NoPrimaryIndexProperty()
return None
- def _parse_on_property(self) -> t.Optional[exp.Property]:
+ def _parse_on_property(self) -> t.Optional[exp.Expression]:
if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"):
return exp.OnCommitProperty()
elif self._match_text_seq("COMMIT", "DELETE", "ROWS"):
return exp.OnCommitProperty(delete=True)
return None
- def _parse_distkey(self) -> exp.Expression:
+ def _parse_distkey(self) -> exp.DistKeyProperty:
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
- def _parse_create_like(self) -> t.Optional[exp.Expression]:
+ def _parse_create_like(self) -> t.Optional[exp.LikeProperty]:
table = self._parse_table(schema=True)
+
options = []
while self._match_texts(("INCLUDING", "EXCLUDING")):
this = self._prev.text.upper()
- id_var = self._parse_id_var()
+ id_var = self._parse_id_var()
if not id_var:
return None
options.append(
- self.expression(
- exp.Property,
- this=this,
- value=exp.Var(this=id_var.this.upper()),
- )
+ self.expression(exp.Property, this=this, value=exp.var(id_var.this.upper()))
)
+
return self.expression(exp.LikeProperty, this=table, expressions=options)
- def _parse_sortkey(self, compound: bool = False) -> exp.Expression:
+ def _parse_sortkey(self, compound: bool = False) -> exp.SortKeyProperty:
return self.expression(
- exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound
+ exp.SortKeyProperty, this=self._parse_wrapped_id_vars(), compound=compound
)
- def _parse_character_set(self, default: bool = False) -> exp.Expression:
+ def _parse_character_set(self, default: bool = False) -> exp.CharacterSetProperty:
self._match(TokenType.EQ)
return self.expression(
exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
)
- def _parse_returns(self) -> exp.Expression:
+ def _parse_returns(self) -> exp.ReturnsProperty:
value: t.Optional[exp.Expression]
is_table = self._match(TokenType.TABLE)
@@ -1629,19 +1634,18 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
else:
- value = self._parse_schema(exp.Var(this="TABLE"))
+ value = self._parse_schema(exp.var("TABLE"))
else:
value = self._parse_types()
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
- def _parse_describe(self) -> exp.Expression:
+ def _parse_describe(self) -> exp.Describe:
kind = self._match_set(self.CREATABLES) and self._prev.text
this = self._parse_table()
-
return self.expression(exp.Describe, this=this, kind=kind)
- def _parse_insert(self) -> exp.Expression:
+ def _parse_insert(self) -> exp.Insert:
overwrite = self._match(TokenType.OVERWRITE)
local = self._match_text_seq("LOCAL")
alternative = None
@@ -1673,11 +1677,11 @@ class Parser(metaclass=_Parser):
alternative=alternative,
)
- def _parse_on_conflict(self) -> t.Optional[exp.Expression]:
+ def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]:
conflict = self._match_text_seq("ON", "CONFLICT")
duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY")
- if not (conflict or duplicate):
+ if not conflict and not duplicate:
return None
nothing = None
@@ -1707,18 +1711,20 @@ class Parser(metaclass=_Parser):
constraint=constraint,
)
- def _parse_returning(self) -> t.Optional[exp.Expression]:
+ def _parse_returning(self) -> t.Optional[exp.Returning]:
if not self._match(TokenType.RETURNING):
return None
return self.expression(exp.Returning, expressions=self._parse_csv(self._parse_column))
- def _parse_row(self) -> t.Optional[exp.Expression]:
+ def _parse_row(self) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]:
if not self._match(TokenType.FORMAT):
return None
return self._parse_row_format()
- def _parse_row_format(self, match_row: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_row_format(
+ self, match_row: bool = False
+ ) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]:
if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None
@@ -1744,7 +1750,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore
- def _parse_load(self) -> exp.Expression:
+ def _parse_load(self) -> exp.LoadData | exp.Command:
if self._match_text_seq("DATA"):
local = self._match_text_seq("LOCAL")
self._match_text_seq("INPATH")
@@ -1764,7 +1770,7 @@ class Parser(metaclass=_Parser):
)
return self._parse_as_command(self._prev)
- def _parse_delete(self) -> exp.Expression:
+ def _parse_delete(self) -> exp.Delete:
self._match(TokenType.FROM)
return self.expression(
@@ -1775,7 +1781,7 @@ class Parser(metaclass=_Parser):
returning=self._parse_returning(),
)
- def _parse_update(self) -> exp.Expression:
+ def _parse_update(self) -> exp.Update:
return self.expression(
exp.Update,
**{ # type: ignore
@@ -1787,22 +1793,20 @@ class Parser(metaclass=_Parser):
},
)
- def _parse_uncache(self) -> exp.Expression:
+ def _parse_uncache(self) -> exp.Uncache:
if not self._match(TokenType.TABLE):
self.raise_error("Expecting TABLE after UNCACHE")
return self.expression(
- exp.Uncache,
- exists=self._parse_exists(),
- this=self._parse_table(schema=True),
+ exp.Uncache, exists=self._parse_exists(), this=self._parse_table(schema=True)
)
- def _parse_cache(self) -> exp.Expression:
+ def _parse_cache(self) -> exp.Cache:
lazy = self._match_text_seq("LAZY")
self._match(TokenType.TABLE)
table = self._parse_table(schema=True)
- options = []
+ options = []
if self._match_text_seq("OPTIONS"):
self._match_l_paren()
k = self._parse_string()
@@ -1820,7 +1824,7 @@ class Parser(metaclass=_Parser):
expression=self._parse_select(nested=True),
)
- def _parse_partition(self) -> t.Optional[exp.Expression]:
+ def _parse_partition(self) -> t.Optional[exp.Partition]:
if not self._match(TokenType.PARTITION):
return None
@@ -1828,7 +1832,7 @@ class Parser(metaclass=_Parser):
exp.Partition, expressions=self._parse_wrapped_csv(self._parse_conjunction)
)
- def _parse_value(self) -> exp.Expression:
+ def _parse_value(self) -> exp.Tuple:
if self._match(TokenType.L_PAREN):
expressions = self._parse_csv(self._parse_conjunction)
self._match_r_paren()
@@ -1926,7 +1930,7 @@ class Parser(metaclass=_Parser):
return self._parse_set_operations(this)
- def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.With]:
if not skip_with_token and not self._match(TokenType.WITH):
return None
@@ -1946,22 +1950,19 @@ class Parser(metaclass=_Parser):
exp.With, comments=comments, expressions=expressions, recursive=recursive
)
- def _parse_cte(self) -> exp.Expression:
+ def _parse_cte(self) -> exp.CTE:
alias = self._parse_table_alias()
if not alias or not alias.this:
self.raise_error("Expected CTE to have alias")
self._match(TokenType.ALIAS)
-
return self.expression(
- exp.CTE,
- this=self._parse_wrapped(self._parse_statement),
- alias=alias,
+ exp.CTE, this=self._parse_wrapped(self._parse_statement), alias=alias
)
def _parse_table_alias(
self, alias_tokens: t.Optional[t.Collection[TokenType]] = None
- ) -> t.Optional[exp.Expression]:
+ ) -> t.Optional[exp.TableAlias]:
any_token = self._match(TokenType.ALIAS)
alias = (
self._parse_id_var(any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
@@ -1982,9 +1983,10 @@ class Parser(metaclass=_Parser):
def _parse_subquery(
self, this: t.Optional[exp.Expression], parse_alias: bool = True
- ) -> t.Optional[exp.Expression]:
+ ) -> t.Optional[exp.Subquery]:
if not this:
return None
+
return self.expression(
exp.Subquery,
this=this,
@@ -2000,19 +2002,25 @@ class Parser(metaclass=_Parser):
expression = parser(self)
if expression:
+ if key == "limit":
+ offset = expression.args.pop("offset", None)
+ if offset:
+ this.set("offset", exp.Offset(expression=offset))
this.set(key, expression)
return this
- def _parse_hint(self) -> t.Optional[exp.Expression]:
+ def _parse_hint(self) -> t.Optional[exp.Hint]:
if self._match(TokenType.HINT):
hints = self._parse_csv(self._parse_function)
+
if not self._match_pair(TokenType.STAR, TokenType.SLASH):
self.raise_error("Expected */ after HINT")
+
return self.expression(exp.Hint, expressions=hints)
return None
- def _parse_into(self) -> t.Optional[exp.Expression]:
+ def _parse_into(self) -> t.Optional[exp.Into]:
if not self._match(TokenType.INTO):
return None
@@ -2039,7 +2047,7 @@ class Parser(metaclass=_Parser):
this=self._parse_query_modifiers(this) if modifiers else this,
)
- def _parse_match_recognize(self) -> t.Optional[exp.Expression]:
+ def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]:
if not self._match(TokenType.MATCH_RECOGNIZE):
return None
@@ -2052,7 +2060,7 @@ class Parser(metaclass=_Parser):
)
if self._match_text_seq("ONE", "ROW", "PER", "MATCH"):
- rows = exp.Var(this="ONE ROW PER MATCH")
+ rows = exp.var("ONE ROW PER MATCH")
elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"):
text = "ALL ROWS PER MATCH"
if self._match_text_seq("SHOW", "EMPTY", "MATCHES"):
@@ -2061,7 +2069,7 @@ class Parser(metaclass=_Parser):
text += f" OMIT EMPTY MATCHES"
elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"):
text += f" WITH UNMATCHED ROWS"
- rows = exp.Var(this=text)
+ rows = exp.var(text)
else:
rows = None
@@ -2075,7 +2083,7 @@ class Parser(metaclass=_Parser):
text += f" TO FIRST {self._advance_any().text}" # type: ignore
elif self._match_text_seq("TO", "LAST"):
text += f" TO LAST {self._advance_any().text}" # type: ignore
- after = exp.Var(this=text)
+ after = exp.var(text)
else:
after = None
@@ -2093,11 +2101,14 @@ class Parser(metaclass=_Parser):
paren += 1
if self._curr.token_type == TokenType.R_PAREN:
paren -= 1
+
end = self._prev
self._advance()
+
if paren > 0:
self.raise_error("Expecting )", self._curr)
- pattern = exp.Var(this=self._find_sql(start, end))
+
+ pattern = exp.var(self._find_sql(start, end))
else:
pattern = None
@@ -2127,7 +2138,7 @@ class Parser(metaclass=_Parser):
alias=self._parse_table_alias(),
)
- def _parse_lateral(self) -> t.Optional[exp.Expression]:
+ def _parse_lateral(self) -> t.Optional[exp.Lateral]:
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY)
@@ -2150,24 +2161,19 @@ class Parser(metaclass=_Parser):
expression=self._parse_function() or self._parse_id_var(any_token=False),
)
- table_alias: t.Optional[exp.Expression]
-
if view:
table = self._parse_id_var(any_token=False)
columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else []
- table_alias = self.expression(exp.TableAlias, this=table, columns=columns)
+ table_alias: t.Optional[exp.TableAlias] = self.expression(
+ exp.TableAlias, this=table, columns=columns
+ )
+ elif isinstance(this, exp.Subquery) and this.alias:
+ # Ensures parity between the Subquery's and the Lateral's "alias" args
+ table_alias = this.args["alias"].copy()
else:
table_alias = self._parse_table_alias()
- expression = self.expression(
- exp.Lateral,
- this=this,
- view=view,
- outer=outer,
- alias=table_alias,
- )
-
- return expression
+ return self.expression(exp.Lateral, this=this, view=view, outer=outer, alias=table_alias)
def _parse_join_parts(
self,
@@ -2178,7 +2184,7 @@ class Parser(metaclass=_Parser):
self._match_set(self.JOIN_KINDS) and self._prev,
)
- def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]:
if self._match(TokenType.COMMA):
return self.expression(exp.Join, this=self._parse_table())
@@ -2223,7 +2229,7 @@ class Parser(metaclass=_Parser):
def _parse_index(
self,
index: t.Optional[exp.Expression] = None,
- ) -> t.Optional[exp.Expression]:
+ ) -> t.Optional[exp.Index]:
if index:
unique = None
primary = None
@@ -2236,11 +2242,15 @@ class Parser(metaclass=_Parser):
unique = self._match(TokenType.UNIQUE)
primary = self._match_text_seq("PRIMARY")
amp = self._match_text_seq("AMP")
+
if not self._match(TokenType.INDEX):
return None
+
index = self._parse_id_var()
table = None
+ using = self._parse_field() if self._match(TokenType.USING) else None
+
if self._match(TokenType.L_PAREN, advance=False):
columns = self._parse_wrapped_csv(self._parse_ordered)
else:
@@ -2250,6 +2260,7 @@ class Parser(metaclass=_Parser):
exp.Index,
this=index,
table=table,
+ using=using,
columns=columns,
unique=unique,
primary=primary,
@@ -2259,7 +2270,7 @@ class Parser(metaclass=_Parser):
def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
return (
- (not schema and self._parse_function())
+ (not schema and self._parse_function(optional_parens=False))
or self._parse_id_var(any_token=False)
or self._parse_string_as_identifier()
or self._parse_placeholder()
@@ -2314,7 +2325,7 @@ class Parser(metaclass=_Parser):
if schema:
return self._parse_schema(this=this)
- if self.alias_post_tablesample:
+ if self.ALIAS_POST_TABLESAMPLE:
table_sample = self._parse_table_sample()
alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
@@ -2331,7 +2342,7 @@ class Parser(metaclass=_Parser):
)
self._match_r_paren()
- if not self.alias_post_tablesample:
+ if not self.ALIAS_POST_TABLESAMPLE:
table_sample = self._parse_table_sample()
if table_sample:
@@ -2340,46 +2351,47 @@ class Parser(metaclass=_Parser):
return this
- def _parse_unnest(self) -> t.Optional[exp.Expression]:
+ def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
if not self._match(TokenType.UNNEST):
return None
expressions = self._parse_wrapped_csv(self._parse_type)
ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
- alias = self._parse_table_alias()
- if alias and self.unnest_column_only:
+ alias = self._parse_table_alias() if with_alias else None
+
+ if alias and self.UNNEST_COLUMN_ONLY:
if alias.args.get("columns"):
self.raise_error("Unexpected extra column alias in unnest.")
+
alias.set("columns", [alias.this])
alias.set("this", None)
offset = None
if self._match_pair(TokenType.WITH, TokenType.OFFSET):
self._match(TokenType.ALIAS)
- offset = self._parse_id_var() or exp.Identifier(this="offset")
+ offset = self._parse_id_var() or exp.to_identifier("offset")
return self.expression(
- exp.Unnest,
- expressions=expressions,
- ordinality=ordinality,
- alias=alias,
- offset=offset,
+ exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias, offset=offset
)
- def _parse_derived_table_values(self) -> t.Optional[exp.Expression]:
+ def _parse_derived_table_values(self) -> t.Optional[exp.Values]:
is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
if not is_derived and not self._match(TokenType.VALUES):
return None
expressions = self._parse_csv(self._parse_value)
+ alias = self._parse_table_alias()
if is_derived:
self._match_r_paren()
- return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias())
+ return self.expression(
+ exp.Values, expressions=expressions, alias=alias or self._parse_table_alias()
+ )
- def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.TableSample]:
if not self._match(TokenType.TABLE_SAMPLE) and not (
as_modifier and self._match_text_seq("USING", "SAMPLE")
):
@@ -2456,7 +2468,7 @@ class Parser(metaclass=_Parser):
exp.Pivot, this=this, expressions=expressions, using=using, group=group
)
- def _parse_pivot(self) -> t.Optional[exp.Expression]:
+ def _parse_pivot(self) -> t.Optional[exp.Pivot]:
index = self._index
if self._match(TokenType.PIVOT):
@@ -2519,7 +2531,7 @@ class Parser(metaclass=_Parser):
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
return [agg.alias for agg in aggregations]
- def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Where]:
if not skip_where_token and not self._match(TokenType.WHERE):
return None
@@ -2527,7 +2539,7 @@ class Parser(metaclass=_Parser):
exp.Where, comments=self._prev_comments, this=self._parse_conjunction()
)
- def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Group]:
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
return None
@@ -2578,12 +2590,12 @@ class Parser(metaclass=_Parser):
return self._parse_column()
- def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Having]:
if not skip_having_token and not self._match(TokenType.HAVING):
return None
return self.expression(exp.Having, this=self._parse_conjunction())
- def _parse_qualify(self) -> t.Optional[exp.Expression]:
+ def _parse_qualify(self) -> t.Optional[exp.Qualify]:
if not self._match(TokenType.QUALIFY):
return None
return self.expression(exp.Qualify, this=self._parse_conjunction())
@@ -2598,16 +2610,15 @@ class Parser(metaclass=_Parser):
exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
)
- def _parse_sort(
- self, exp_class: t.Type[exp.Expression], *texts: str
- ) -> t.Optional[exp.Expression]:
+ def _parse_sort(self, exp_class: t.Type[E], *texts: str) -> t.Optional[E]:
if not self._match_text_seq(*texts):
return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
- def _parse_ordered(self) -> exp.Expression:
+ def _parse_ordered(self) -> exp.Ordered:
this = self._parse_conjunction()
self._match(TokenType.ASC)
+
is_desc = self._match(TokenType.DESC)
is_nulls_first = self._match_text_seq("NULLS", "FIRST")
is_nulls_last = self._match_text_seq("NULLS", "LAST")
@@ -2615,13 +2626,14 @@ class Parser(metaclass=_Parser):
asc = not desc
nulls_first = is_nulls_first or False
explicitly_null_ordered = is_nulls_first or is_nulls_last
+
if (
not explicitly_null_ordered
and (
- (asc and self.null_ordering == "nulls_are_small")
- or (desc and self.null_ordering != "nulls_are_small")
+ (asc and self.NULL_ORDERING == "nulls_are_small")
+ or (desc and self.NULL_ORDERING != "nulls_are_small")
)
- and self.null_ordering != "nulls_are_last"
+ and self.NULL_ORDERING != "nulls_are_last"
):
nulls_first = True
@@ -2632,9 +2644,15 @@ class Parser(metaclass=_Parser):
) -> t.Optional[exp.Expression]:
if self._match(TokenType.TOP if top else TokenType.LIMIT):
limit_paren = self._match(TokenType.L_PAREN)
- limit_exp = self.expression(
- exp.Limit, this=this, expression=self._parse_number() if top else self._parse_term()
- )
+ expression = self._parse_number() if top else self._parse_term()
+
+ if self._match(TokenType.COMMA):
+ offset = expression
+ expression = self._parse_term()
+ else:
+ offset = None
+
+ limit_exp = self.expression(exp.Limit, this=this, expression=expression, offset=offset)
if limit_paren:
self._match_r_paren()
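Together with the offset-popping logic added to _parse_query_modifiers earlier in this file's diff, the comma branch lets MySQL's two-argument LIMIT round-trip. An indicative sketch:

    import sqlglot

    # MySQL "LIMIT offset, count" is split into separate LIMIT and OFFSET clauses.
    print(sqlglot.transpile("SELECT a FROM t LIMIT 1, 10", read="mysql", write="postgres")[0])
    # SELECT a FROM t LIMIT 10 OFFSET 1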
@@ -2667,17 +2685,15 @@ class Parser(metaclass=_Parser):
return this
def _parse_offset(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
- if not self._match_set((TokenType.OFFSET, TokenType.COMMA)):
+ if not self._match(TokenType.OFFSET):
return this
count = self._parse_number()
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
- def _parse_locks(self) -> t.List[exp.Expression]:
- # Lists are invariant, so we need to use a type hint here
- locks: t.List[exp.Expression] = []
-
+ def _parse_locks(self) -> t.List[exp.Lock]:
+ locks = []
while True:
if self._match_text_seq("FOR", "UPDATE"):
update = True
@@ -2768,6 +2784,7 @@ class Parser(metaclass=_Parser):
def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
index = self._index - 1
negate = self._match(TokenType.NOT)
+
if self._match_text_seq("DISTINCT", "FROM"):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
@@ -2781,7 +2798,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Not, this=this) if negate else this
def _parse_in(self, this: t.Optional[exp.Expression], alias: bool = False) -> exp.In:
- unnest = self._parse_unnest()
+ unnest = self._parse_unnest(with_alias=False)
if unnest:
this = self.expression(exp.In, this=this, unnest=unnest)
elif self._match(TokenType.L_PAREN):
@@ -2798,7 +2815,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_between(self, this: exp.Expression) -> exp.Expression:
+ def _parse_between(self, this: exp.Expression) -> exp.Between:
low = self._parse_bitwise()
self._match(TokenType.AND)
high = self._parse_bitwise()
@@ -2809,7 +2826,7 @@ class Parser(metaclass=_Parser):
return this
return self.expression(exp.Escape, this=this, expression=self._parse_string())
- def _parse_interval(self) -> t.Optional[exp.Expression]:
+ def _parse_interval(self) -> t.Optional[exp.Interval]:
if not self._match(TokenType.INTERVAL):
return None
@@ -2840,9 +2857,7 @@ class Parser(metaclass=_Parser):
while True:
if self._match_set(self.BITWISE):
this = self.expression(
- self.BITWISE[self._prev.token_type],
- this=this,
- expression=self._parse_term(),
+ self.BITWISE[self._prev.token_type], this=this, expression=self._parse_term()
)
elif self._match_pair(TokenType.LT, TokenType.LT):
this = self.expression(
@@ -2890,7 +2905,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_type_size(self) -> t.Optional[exp.Expression]:
+ def _parse_type_size(self) -> t.Optional[exp.DataTypeSize]:
this = self._parse_type()
if not this:
return None
@@ -2926,6 +2941,8 @@ class Parser(metaclass=_Parser):
expressions = self._parse_csv(
lambda: self._parse_types(check_func=check_func, schema=schema)
)
+ elif type_token in self.ENUM_TYPE_TOKENS:
+ expressions = self._parse_csv(self._parse_primary)
else:
expressions = self._parse_csv(self._parse_type_size)
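With ENUM in ENUM_TYPE_TOKENS, the parenthesized values are parsed as primaries (string literals) instead of type sizes. An indicative sketch, relying on transpile's default identity round-trip:

    import sqlglot

    # The ENUM member list survives a MySQL round-trip.
    print(sqlglot.transpile("CREATE TABLE t (c ENUM('a', 'b'))", read="mysql")[0])
    # CREATE TABLE t (c ENUM('a', 'b'))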
@@ -2943,11 +2960,7 @@ class Parser(metaclass=_Parser):
)
while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
- this = exp.DataType(
- this=exp.DataType.Type.ARRAY,
- expressions=[this],
- nested=True,
- )
+ this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True)
return this
@@ -2973,23 +2986,14 @@ class Parser(metaclass=_Parser):
value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
- if self._match_text_seq("WITH", "TIME", "ZONE") or type_token == TokenType.TIMESTAMPTZ:
+ if self._match_text_seq("WITH", "TIME", "ZONE"):
+ maybe_func = False
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
- elif (
- self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE")
- or type_token == TokenType.TIMESTAMPLTZ
- ):
+ elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"):
+ maybe_func = False
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
elif self._match_text_seq("WITHOUT", "TIME", "ZONE"):
- if type_token == TokenType.TIME:
- value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions)
- else:
- value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
-
- maybe_func = maybe_func and value is None
-
- if value is None:
- value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
+ maybe_func = False
elif type_token == TokenType.INTERVAL:
unit = self._parse_var()
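
Net effect of the simplification: an explicit WITH [LOCAL] TIME ZONE pins the zoned type, WITHOUT TIME ZONE keeps whatever type token was read, and all three rule out a trailing function call. For example:

    import sqlglot

    print(sqlglot.parse_one("CAST(x AS TIMESTAMP WITH TIME ZONE)").sql())
    # CAST(x AS TIMESTAMPTZ)
    print(sqlglot.parse_one("CAST(x AS TIMESTAMP WITHOUT TIME ZONE)").sql())
    # CAST(x AS TIMESTAMP)
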
@@ -3037,7 +3041,7 @@ class Parser(metaclass=_Parser):
return self._parse_bracket(this)
return self._parse_column_ops(this)
- def _parse_column_ops(self, this: exp.Expression) -> exp.Expression:
+ def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
this = self._parse_bracket(this)
while self._match_set(self.COLUMN_OPERATORS):
@@ -3057,7 +3061,7 @@ class Parser(metaclass=_Parser):
else exp.Literal.string(value)
)
else:
- field = self._parse_field(anonymous_func=True)
+ field = self._parse_field(anonymous_func=True, any_token=True)
if isinstance(field, exp.Func):
# bigquery allows function calls like x.y.count(...)
@@ -3089,8 +3093,10 @@ class Parser(metaclass=_Parser):
expressions = [primary]
while self._match(TokenType.STRING):
expressions.append(exp.Literal.string(self._prev.text))
+
if len(expressions) > 1:
return self.expression(exp.Concat, expressions=expressions)
+
return primary
if self._match_pair(TokenType.DOT, TokenType.NUMBER):
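
This is the SQL-standard implicit concatenation of adjacent string literals. Roughly:

    import sqlglot

    # Two adjacent literals fold into a single Concat node
    print(sqlglot.parse_one("SELECT 'foo' 'bar'").sql())
    # SELECT CONCAT('foo', 'bar')
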
@@ -3118,8 +3124,8 @@ class Parser(metaclass=_Parser):
if this:
this.add_comments(comments)
-        self._match_r_paren(expression=this)
+
+        self._match_r_paren(expression=this)
return this
return None
@@ -3137,18 +3143,21 @@ class Parser(metaclass=_Parser):
)
def _parse_function(
- self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
+ self,
+ functions: t.Optional[t.Dict[str, t.Callable]] = None,
+ anonymous: bool = False,
+ optional_parens: bool = True,
) -> t.Optional[exp.Expression]:
if not self._curr:
return None
token_type = self._curr.token_type
- if self._match_set(self.NO_PAREN_FUNCTION_PARSERS):
+ if optional_parens and self._match_set(self.NO_PAREN_FUNCTION_PARSERS):
return self.NO_PAREN_FUNCTION_PARSERS[token_type](self)
if not self._next or self._next.token_type != TokenType.L_PAREN:
- if token_type in self.NO_PAREN_FUNCTIONS:
+ if optional_parens and token_type in self.NO_PAREN_FUNCTIONS:
self._advance()
return self.expression(self.NO_PAREN_FUNCTIONS[token_type])
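
`optional_parens` lets call sites disable the two no-paren paths, so a bare word is not misread as a CURRENT_DATE-style function where that would be wrong. The default behavior, for reference:

    import sqlglot
    from sqlglot import exp

    # CURRENT_DATE sits in NO_PAREN_FUNCTIONS and parses without parentheses
    assert isinstance(sqlglot.parse_one("SELECT CURRENT_DATE").selects[0], exp.CurrentDate)
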
@@ -3182,8 +3191,7 @@ class Parser(metaclass=_Parser):
args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
if function and not anonymous:
- this = function(args)
- self.validate_expression(this, args)
+ this = self.validate_expression(function(args), args)
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
@@ -3210,14 +3218,14 @@ class Parser(metaclass=_Parser):
exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True
)
- def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]:
+ def _parse_introducer(self, token: Token) -> exp.Introducer | exp.Identifier:
literal = self._parse_primary()
if literal:
return self.expression(exp.Introducer, this=token.text, expression=literal)
return self.expression(exp.Identifier, this=token.text)
- def _parse_session_parameter(self) -> exp.Expression:
+ def _parse_session_parameter(self) -> exp.SessionParameter:
kind = None
this = self._parse_id_var() or self._parse_primary()
@@ -3255,7 +3263,7 @@ class Parser(metaclass=_Parser):
if isinstance(this, exp.EQ):
left = this.this
if isinstance(left, exp.Column):
- left.replace(exp.Var(this=left.text("this")))
+ left.replace(exp.var(left.text("this")))
return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this)))
@@ -3279,6 +3287,7 @@ class Parser(metaclass=_Parser):
lambda: self._parse_constraint()
or self._parse_column_def(self._parse_field(any_token=True))
)
+
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
@@ -3286,6 +3295,7 @@ class Parser(metaclass=_Parser):
# column defs are not really columns, they're identifiers
if isinstance(this, exp.Column):
this = this.this
+
kind = self._parse_types(schema=True)
if self._match_text_seq("FOR", "ORDINALITY"):
@@ -3303,7 +3313,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
- def _parse_auto_increment(self) -> exp.Expression:
+ def _parse_auto_increment(
+ self,
+ ) -> exp.GeneratedAsIdentityColumnConstraint | exp.AutoIncrementColumnConstraint:
start = None
increment = None
@@ -3321,7 +3333,7 @@ class Parser(metaclass=_Parser):
return exp.AutoIncrementColumnConstraint()
- def _parse_compress(self) -> exp.Expression:
+ def _parse_compress(self) -> exp.CompressColumnConstraint:
if self._match(TokenType.L_PAREN, advance=False):
return self.expression(
exp.CompressColumnConstraint, this=self._parse_wrapped_csv(self._parse_bitwise)
@@ -3329,7 +3341,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise())
- def _parse_generated_as_identity(self) -> exp.Expression:
+ def _parse_generated_as_identity(self) -> exp.GeneratedAsIdentityColumnConstraint:
if self._match_text_seq("BY", "DEFAULT"):
on_null = self._match_pair(TokenType.ON, TokenType.NULL)
this = self.expression(
@@ -3364,11 +3376,13 @@ class Parser(metaclass=_Parser):
return this
- def _parse_inline(self) -> t.Optional[exp.Expression]:
+ def _parse_inline(self) -> exp.InlineLengthColumnConstraint:
self._match_text_seq("LENGTH")
return self.expression(exp.InlineLengthColumnConstraint, this=self._parse_bitwise())
- def _parse_not_constraint(self) -> t.Optional[exp.Expression]:
+ def _parse_not_constraint(
+ self,
+ ) -> t.Optional[exp.NotNullColumnConstraint | exp.CaseSpecificColumnConstraint]:
if self._match_text_seq("NULL"):
return self.expression(exp.NotNullColumnConstraint)
if self._match_text_seq("CASESPECIFIC"):
@@ -3417,7 +3431,7 @@ class Parser(metaclass=_Parser):
return self.CONSTRAINT_PARSERS[constraint](self)
- def _parse_unique(self) -> exp.Expression:
+ def _parse_unique(self) -> exp.UniqueColumnConstraint:
self._match_text_seq("KEY")
return self.expression(
exp.UniqueColumnConstraint, this=self._parse_schema(self._parse_id_var(any_token=False))
@@ -3460,7 +3474,7 @@ class Parser(metaclass=_Parser):
return options
- def _parse_references(self, match: bool = True) -> t.Optional[exp.Expression]:
+ def _parse_references(self, match: bool = True) -> t.Optional[exp.Reference]:
if match and not self._match(TokenType.REFERENCES):
return None
@@ -3473,7 +3487,7 @@ class Parser(metaclass=_Parser):
options = self._parse_key_constraint_options()
return self.expression(exp.Reference, this=this, expressions=expressions, options=options)
- def _parse_foreign_key(self) -> exp.Expression:
+ def _parse_foreign_key(self) -> exp.ForeignKey:
expressions = self._parse_wrapped_id_vars()
reference = self._parse_references()
options = {}
@@ -3501,7 +3515,7 @@ class Parser(metaclass=_Parser):
def _parse_primary_key(
self, wrapped_optional: bool = False, in_props: bool = False
- ) -> exp.Expression:
+ ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey:
desc = (
self._match_set((TokenType.ASC, TokenType.DESC))
and self._prev.token_type == TokenType.DESC
@@ -3514,15 +3528,7 @@ class Parser(metaclass=_Parser):
options = self._parse_key_constraint_options()
return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
- @t.overload
- def _parse_bracket(self, this: exp.Expression) -> exp.Expression:
- ...
-
- @t.overload
def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
- ...
-
- def _parse_bracket(self, this):
if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
return this
@@ -3541,7 +3547,7 @@ class Parser(metaclass=_Parser):
elif not this or this.name.upper() == "ARRAY":
this = self.expression(exp.Array, expressions=expressions)
else:
- expressions = apply_index_offset(this, expressions, -self.index_offset)
+ expressions = apply_index_offset(this, expressions, -self.INDEX_OFFSET)
this = self.expression(exp.Bracket, this=this, expressions=expressions)
if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
@@ -3582,8 +3588,7 @@ class Parser(metaclass=_Parser):
def _parse_if(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.L_PAREN):
args = self._parse_csv(self._parse_conjunction)
- this = exp.If.from_arg_list(args)
- self.validate_expression(this, args)
+ this = self.validate_expression(exp.If.from_arg_list(args), args)
self._match_r_paren()
else:
index = self._index - 1
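
The parenthesized branch funnels IF through `from_arg_list` plus the newly chained `validate_expression`. A sketch with hypothetical columns:

    import sqlglot
    from sqlglot import exp

    ast = sqlglot.parse_one("SELECT IF(x > 0, 'pos', 'neg') FROM t")
    print(repr(ast.find(exp.If)))  # If(this=x > 0, true='pos', false='neg')
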
@@ -3601,7 +3606,7 @@ class Parser(metaclass=_Parser):
return self._parse_window(this)
- def _parse_extract(self) -> exp.Expression:
+ def _parse_extract(self) -> exp.Extract:
this = self._parse_function() or self._parse_var() or self._parse_type()
if self._match(TokenType.FROM):
@@ -3630,9 +3635,37 @@ class Parser(metaclass=_Parser):
elif to.this == exp.DataType.Type.CHAR:
if self._match(TokenType.CHARACTER_SET):
to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
+ elif to.this in exp.DataType.TEMPORAL_TYPES and self._match(TokenType.FORMAT):
+ fmt = self._parse_string()
+
+ return self.expression(
+ exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
+ this=this,
+ format=exp.Literal.string(
+ format_time(
+ fmt.this if fmt else "",
+ self.FORMAT_MAPPING or self.TIME_MAPPING,
+ self.FORMAT_TRIE or self.TIME_TRIE,
+ )
+ ),
+ )
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ def _parse_concat(self) -> t.Optional[exp.Expression]:
+ args = self._parse_csv(self._parse_conjunction)
+ if self.CONCAT_NULL_OUTPUTS_STRING:
+ args = [exp.func("COALESCE", arg, exp.Literal.string("")) for arg in args]
+
+ # Some dialects (e.g. Trino) don't allow a single-argument CONCAT call, so when
+ # we find such a call we replace it with its argument.
+ if len(args) == 1:
+ return args[0]
+
+ return self.expression(
+ exp.Concat if self.STRICT_STRING_CONCAT else exp.SafeConcat, expressions=args
+ )
+
def _parse_string_agg(self) -> exp.Expression:
expression: t.Optional[exp.Expression]
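
Two behaviors from this hunk are worth pinning down with examples (hypothetical column names, and assuming the Teradata reader for the FORMAT clause, since that is where `FORMAT_MAPPING` is populated):

    import sqlglot

    # CAST ... FORMAT on a temporal type lowers to a string-to-date/time conversion
    ast = sqlglot.parse_one("CAST(x AS DATE FORMAT 'YYYY-MM-DD')", read="teradata")
    print(repr(ast))  # StrToDate with format '%Y-%m-%d'

    # A single-argument CONCAT collapses to its argument, per the comment above
    print(sqlglot.parse_one("SELECT CONCAT(a) FROM t").sql())
    # SELECT a FROM t
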
@@ -3654,9 +3687,7 @@ class Parser(metaclass=_Parser):
# the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them.
if not self._match_text_seq("WITHIN", "GROUP"):
self._retreat(index)
- this = exp.GroupConcat.from_arg_list(args)
- self.validate_expression(this, args)
- return this
+ return self.validate_expression(exp.GroupConcat.from_arg_list(args), args)
self._match_l_paren() # The corresponding match_r_paren will be called in parse_function (caller)
order = self._parse_order(this=expression)
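
Without a WITHIN GROUP suffix, STRING_AGG is normalized to GroupConcat, which is what makes the MySQL / SQLite transpilation mentioned in the comment straightforward:

    import sqlglot
    from sqlglot import exp

    ast = sqlglot.parse_one("SELECT STRING_AGG(x, ',') FROM t")
    assert ast.find(exp.GroupConcat) is not None  # the separator is the second arg
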
@@ -3679,7 +3710,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
- def _parse_decode(self) -> t.Optional[exp.Expression]:
+ def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]:
"""
There are generally two variants of the DECODE function:
@@ -3726,18 +3757,20 @@ class Parser(metaclass=_Parser):
return exp.Case(ifs=ifs, default=expressions[-1] if len(expressions) % 2 == 1 else None)
- def _parse_json_key_value(self) -> t.Optional[exp.Expression]:
+ def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]:
self._match_text_seq("KEY")
key = self._parse_field()
self._match(TokenType.COLON)
self._match_text_seq("VALUE")
value = self._parse_field()
+
if not key and not value:
return None
return self.expression(exp.JSONKeyValue, this=key, expression=value)
- def _parse_json_object(self) -> exp.Expression:
- expressions = self._parse_csv(self._parse_json_key_value)
+ def _parse_json_object(self) -> exp.JSONObject:
+ star = self._parse_star()
+ expressions = [star] if star else self._parse_csv(self._parse_json_key_value)
null_handling = None
if self._match_text_seq("NULL", "ON", "NULL"):
@@ -3767,7 +3800,7 @@ class Parser(metaclass=_Parser):
encoding=encoding,
)
- def _parse_logarithm(self) -> exp.Expression:
+ def _parse_logarithm(self) -> exp.Func:
# Default argument order is base, expression
args = self._parse_csv(self._parse_range)
@@ -3780,7 +3813,7 @@ class Parser(metaclass=_Parser):
exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0)
)
- def _parse_match_against(self) -> exp.Expression:
+ def _parse_match_against(self) -> exp.MatchAgainst:
expressions = self._parse_csv(self._parse_column)
self._match_text_seq(")", "AGAINST", "(")
@@ -3803,15 +3836,16 @@ class Parser(metaclass=_Parser):
)
# https://learn.microsoft.com/en-us/sql/t-sql/functions/openjson-transact-sql?view=sql-server-ver16
- def _parse_open_json(self) -> exp.Expression:
+ def _parse_open_json(self) -> exp.OpenJSON:
this = self._parse_bitwise()
path = self._match(TokenType.COMMA) and self._parse_string()
- def _parse_open_json_column_def() -> exp.Expression:
+ def _parse_open_json_column_def() -> exp.OpenJSONColumnDef:
this = self._parse_field(any_token=True)
kind = self._parse_types()
path = self._parse_string()
as_json = self._match_pair(TokenType.ALIAS, TokenType.JSON)
+
return self.expression(
exp.OpenJSONColumnDef, this=this, kind=kind, path=path, as_json=as_json
)
@@ -3823,7 +3857,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.OpenJSON, this=this, path=path, expressions=expressions)
- def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
+ def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition:
args = self._parse_csv(self._parse_bitwise)
if self._match(TokenType.IN):
@@ -3838,17 +3872,15 @@ class Parser(metaclass=_Parser):
needle = seq_get(args, 0)
haystack = seq_get(args, 1)
- this = exp.StrPosition(this=haystack, substr=needle, position=seq_get(args, 2))
-
- self.validate_expression(this, args)
-
- return this
+ return self.expression(
+ exp.StrPosition, this=haystack, substr=needle, position=seq_get(args, 2)
+ )
- def _parse_join_hint(self, func_name: str) -> exp.Expression:
+ def _parse_join_hint(self, func_name: str) -> exp.JoinHint:
args = self._parse_csv(self._parse_table)
return exp.JoinHint(this=func_name.upper(), expressions=args)
- def _parse_substring(self) -> exp.Expression:
+ def _parse_substring(self) -> exp.Substring:
# Postgres supports the form: substring(string [from int] [for int])
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
@@ -3859,12 +3891,9 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.FOR):
args.append(self._parse_bitwise())
- this = exp.Substring.from_arg_list(args)
- self.validate_expression(this, args)
-
- return this
+ return self.validate_expression(exp.Substring.from_arg_list(args), args)
- def _parse_trim(self) -> exp.Expression:
+ def _parse_trim(self) -> exp.Trim:
# https://www.w3resource.com/sql/character-functions/trim.php
# https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html
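
The FROM/FOR arguments above are appended positionally, so the Postgres form normalizes to the ordinary three-argument call. A quick check:

    import sqlglot

    print(sqlglot.parse_one("SELECT SUBSTRING(x FROM 2 FOR 3)", read="postgres").sql())
    # SELECT SUBSTRING(x, 2, 3)
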
@@ -3885,11 +3914,7 @@ class Parser(metaclass=_Parser):
collation = self._parse_bitwise()
return self.expression(
- exp.Trim,
- this=this,
- position=position,
- expression=expression,
- collation=collation,
+ exp.Trim, this=this, position=position, expression=expression, collation=collation
)
def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
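
And for the TRIM variants handled above, the position keyword and the characters to strip land in separate args:

    import sqlglot
    from sqlglot import exp

    trim = sqlglot.parse_one("SELECT TRIM(BOTH 'x' FROM y)").find(exp.Trim)
    print(trim.args.get("position"))  # BOTH
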
@@ -4047,7 +4072,7 @@ class Parser(metaclass=_Parser):
return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
return self._parse_placeholder()
- def _parse_string_as_identifier(self) -> t.Optional[exp.Expression]:
+ def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]:
return exp.to_identifier(self._match(TokenType.STRING) and self._prev.text, quoted=True)
def _parse_number(self) -> t.Optional[exp.Expression]:
@@ -4097,7 +4122,7 @@ class Parser(metaclass=_Parser):
return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
return None
- def _parse_parameter(self) -> exp.Expression:
+ def _parse_parameter(self) -> exp.Parameter:
wrapped = self._match(TokenType.L_BRACE)
this = self._parse_var() or self._parse_identifier() or self._parse_primary()
self._match(TokenType.R_BRACE)
@@ -4183,7 +4208,7 @@ class Parser(metaclass=_Parser):
self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False))
)
- def _parse_transaction(self) -> exp.Expression:
+ def _parse_transaction(self) -> exp.Transaction:
this = None
if self._match_texts(self.TRANSACTION_KIND):
this = self._prev.text
@@ -4203,7 +4228,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Transaction, this=this, modes=modes)
- def _parse_commit_or_rollback(self) -> exp.Expression:
+ def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback:
chain = None
savepoint = None
is_rollback = self._prev.token_type == TokenType.ROLLBACK
@@ -4220,6 +4245,7 @@ class Parser(metaclass=_Parser):
if is_rollback:
return self.expression(exp.Rollback, savepoint=savepoint)
+
return self.expression(exp.Commit, chain=chain)
def _parse_add_column(self) -> t.Optional[exp.Expression]:
@@ -4243,19 +4269,19 @@ class Parser(metaclass=_Parser):
return expression
- def _parse_drop_column(self) -> t.Optional[exp.Expression]:
+ def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]:
drop = self._match(TokenType.DROP) and self._parse_drop()
if drop and not isinstance(drop, exp.Command):
drop.set("kind", drop.args.get("kind", "COLUMN"))
return drop
# https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html
- def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression:
+ def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.DropPartition:
return self.expression(
exp.DropPartition, expressions=self._parse_csv(self._parse_partition), exists=exists
)
- def _parse_add_constraint(self) -> t.Optional[exp.Expression]:
+ def _parse_add_constraint(self) -> exp.AddConstraint:
this = None
kind = self._prev.token_type
@@ -4288,7 +4314,7 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return self._parse_csv(self._parse_add_column)
- def _parse_alter_table_alter(self) -> exp.Expression:
+ def _parse_alter_table_alter(self) -> exp.AlterColumn:
self._match(TokenType.COLUMN)
column = self._parse_field(any_token=True)
@@ -4316,11 +4342,11 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return self._parse_csv(self._parse_drop_column)
- def _parse_alter_table_rename(self) -> exp.Expression:
+ def _parse_alter_table_rename(self) -> exp.RenameTable:
self._match_text_seq("TO")
return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
- def _parse_alter(self) -> t.Optional[exp.Expression]:
+ def _parse_alter(self) -> exp.AlterTable | exp.Command:
start = self._prev
if not self._match(TokenType.TABLE):
@@ -4345,7 +4371,7 @@ class Parser(metaclass=_Parser):
)
return self._parse_as_command(start)
- def _parse_merge(self) -> exp.Expression:
+ def _parse_merge(self) -> exp.Merge:
self._match(TokenType.INTO)
target = self._parse_table()
@@ -4412,7 +4438,7 @@ class Parser(metaclass=_Parser):
)
def _parse_show(self) -> t.Optional[exp.Expression]:
- parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore
+ parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE)
if parser:
return parser(self)
self._advance()
@@ -4433,17 +4459,9 @@ class Parser(metaclass=_Parser):
return None
right = self._parse_statement() or self._parse_id_var()
- this = self.expression(
- exp.EQ,
- this=left,
- expression=right,
- )
+ this = self.expression(exp.EQ, this=left, expression=right)
- return self.expression(
- exp.SetItem,
- this=this,
- kind=kind,
- )
+ return self.expression(exp.SetItem, this=this, kind=kind)
def _parse_set_transaction(self, global_: bool = False) -> exp.Expression:
self._match_text_seq("TRANSACTION")
@@ -4458,10 +4476,10 @@ class Parser(metaclass=_Parser):
)
def _parse_set_item(self) -> t.Optional[exp.Expression]:
- parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore
+ parser = self._find_parser(self.SET_PARSERS, self.SET_TRIE)
return parser(self) if parser else self._parse_set_item_assignment(kind=None)
- def _parse_set(self) -> exp.Expression:
+ def _parse_set(self) -> exp.Set | exp.Command:
index = self._index
set_ = self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
@@ -4471,10 +4489,10 @@ class Parser(metaclass=_Parser):
return set_
- def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Expression]:
+ def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Var]:
for option in options:
if self._match_text_seq(*option.split(" ")):
- return exp.Var(this=option)
+ return exp.var(option)
return None
def _parse_as_command(self, start: Token) -> exp.Command:
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index eccad35..4ed7449 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -302,7 +302,7 @@ class Join(Step):
for join in joins:
source_key, join_key, condition = join_condition(join)
- step.joins[join.this.alias_or_name] = {
+ step.joins[join.alias_or_name] = {
"side": join.side, # type: ignore
"join_key": join_key,
"source_key": source_key,
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index f1c4a09..f73adee 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -285,8 +285,6 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
elif isinstance(column_type, str):
return self._to_data_type(column_type.upper(), dialect=dialect)
- raise SchemaError(f"Unknown column type '{column_type}'")
-
return exp.DataType.build("unknown")
def _normalize(self, schema: t.Dict) -> t.Dict:
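
With the hard failure gone, a column type the schema cannot interpret degrades to the UNKNOWN data type instead of raising. The happy path, for contrast (a hedged sketch using hypothetical table data):

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"t": {"x": "INT"}})
    print(schema.get_column_type("t", exp.column("x")).sql())  # INT
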
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index a30ec24..42628b9 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -144,6 +144,7 @@ class TokenType(AutoName):
VARIANT = auto()
OBJECT = auto()
INET = auto()
+ ENUM = auto()
# keywords
ALIAS = auto()
@@ -346,6 +347,7 @@ class Token:
col: The column that the token ends on.
start: The start index of the token.
end: The ending index of the token.
+ comments: The comments to attach to the token.
"""
self.token_type = token_type
self.text = text
@@ -391,12 +393,15 @@ class _Tokenizer(type):
klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES)
- klass._COMMENTS = dict(
- (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
- for comment in klass.COMMENTS
- )
+ klass._COMMENTS = {
+ **dict(
+ (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
+ for comment in klass.COMMENTS
+ ),
+ "{#": "#}", # Ensure Jinja comments are tokenized correctly in all dialects
+ }
- klass.KEYWORD_TRIE = new_trie(
+ klass._KEYWORD_TRIE = new_trie(
key.upper()
for key in (
*klass.KEYWORDS,
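
Folding `{#`/`#}` into every dialect's comment map means Jinja comments no longer derail tokenization, even for dialects that override COMMENTS. A quick check against the base tokenizer:

    from sqlglot.tokens import Tokenizer

    tokens = Tokenizer().tokenize("SELECT 1 {# templated #}")
    # The Jinja comment produces no tokens of its own; it is attached as a comment
    print([(t.token_type.name, t.text) for t in tokens])
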
@@ -456,20 +461,22 @@ class Tokenizer(metaclass=_Tokenizer):
STRING_ESCAPES = ["'"]
VAR_SINGLE_TOKENS: t.Set[str] = set()
+ # Autofilled
+ IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False
+
_COMMENTS: t.Dict[str, str] = {}
_FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
_IDENTIFIERS: t.Dict[str, str] = {}
_IDENTIFIER_ESCAPES: t.Set[str] = set()
_QUOTES: t.Dict[str, str] = {}
_STRING_ESCAPES: t.Set[str] = set()
+ _KEYWORD_TRIE: t.Dict = {}
- KEYWORDS: t.Dict[t.Optional[str], TokenType] = {
+ KEYWORDS: t.Dict[str, TokenType] = {
**{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")},
**{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")},
- "{{+": TokenType.BLOCK_START,
- "{{-": TokenType.BLOCK_START,
- "+}}": TokenType.BLOCK_END,
- "-}}": TokenType.BLOCK_END,
+ **{f"{{{{{postfix}": TokenType.BLOCK_START for postfix in ("+", "-")},
+ **{f"{prefix}}}}}": TokenType.BLOCK_END for prefix in ("+", "-")},
"/*+": TokenType.HINT,
"==": TokenType.EQ,
"::": TokenType.DCOLON,
@@ -594,6 +601,7 @@ class Tokenizer(metaclass=_Tokenizer):
"RECURSIVE": TokenType.RECURSIVE,
"REGEXP": TokenType.RLIKE,
"REPLACE": TokenType.REPLACE,
+ "RETURNING": TokenType.RETURNING,
"REFERENCES": TokenType.REFERENCES,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,
@@ -732,8 +740,7 @@ class Tokenizer(metaclass=_Tokenizer):
NUMERIC_LITERALS: t.Dict[str, str] = {}
ENCODE: t.Optional[str] = None
- COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
- KEYWORD_TRIE: t.Dict = {} # autofilled
+ COMMENTS = ["--", ("/*", "*/")]
__slots__ = (
"sql",
@@ -748,7 +755,6 @@ class Tokenizer(metaclass=_Tokenizer):
"_end",
"_peek",
"_prev_token_line",
- "identifiers_can_start_with_digit",
)
def __init__(self) -> None:
@@ -894,7 +900,7 @@ class Tokenizer(metaclass=_Tokenizer):
char = chars
prev_space = False
skip = False
- trie = self.KEYWORD_TRIE
+ trie = self._KEYWORD_TRIE
single_token = char in self.SINGLE_TOKENS
while chars:
@@ -994,7 +1000,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
elif self._peek == "." and not decimal:
after = self.peek(1)
- if after.isdigit() or not after.strip():
+ if after.isdigit() or not after.isalpha():
decimal = True
self._advance()
else:
@@ -1013,13 +1019,13 @@ class Tokenizer(metaclass=_Tokenizer):
literal += self._peek.upper()
self._advance()
- token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
+ token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal, ""))
if token_type:
self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal)
- elif self.identifiers_can_start_with_digit: # type: ignore
+ elif self.IDENTIFIERS_CAN_START_WITH_DIGIT:
return self._add(TokenType.VAR)
self._add(TokenType.NUMBER, number_text)
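
For context, NUMERIC_LITERALS drives suffixed literals such as Hive's `10L`: the tokenizer emits NUMBER, `::`, and the mapped type token, which the parser reads back as a cast. The `.get(literal, "")` change merely avoids looking up a None key. Assuming the Hive reader:

    import sqlglot

    print(sqlglot.transpile("SELECT 10L", read="hive")[0])
    # SELECT CAST(10 AS BIGINT)
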