summary | refs | log | tree | commit | diff | stats
path: root/sqlglot
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org>, 2023-07-06 07:28:12 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org>, 2023-07-06 07:28:12 +0000
commit: 374a0f6318bcf423b1b784d30b25a8327c65cb24 (patch)
tree: 9303a1cbdba85b5d9781ebef32eb1902d3790c99 /sqlglot
parent: Releasing debian version 16.7.7-1. (diff)
download: sqlglot-374a0f6318bcf423b1b784d30b25a8327c65cb24.tar.xz
download: sqlglot-374a0f6318bcf423b1b784d30b25a8327c65cb24.zip
Merging upstream version 17.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r-- sqlglot/__init__.py 9
-rw-r--r-- sqlglot/dialects/bigquery.py 36
-rw-r--r-- sqlglot/dialects/clickhouse.py 17
-rw-r--r-- sqlglot/dialects/dialect.py 1
-rw-r--r-- sqlglot/dialects/drill.py 1
-rw-r--r-- sqlglot/dialects/duckdb.py 21
-rw-r--r-- sqlglot/dialects/hive.py 15
-rw-r--r-- sqlglot/dialects/mysql.py 29
-rw-r--r-- sqlglot/dialects/oracle.py 1
-rw-r--r-- sqlglot/dialects/postgres.py 26
-rw-r--r-- sqlglot/dialects/presto.py 1
-rw-r--r-- sqlglot/dialects/redshift.py 1
-rw-r--r-- sqlglot/dialects/snowflake.py 1
-rw-r--r-- sqlglot/dialects/spark2.py 3
-rw-r--r-- sqlglot/dialects/sqlite.py 1
-rw-r--r-- sqlglot/dialects/tableau.py 1
-rw-r--r-- sqlglot/dialects/teradata.py 3
-rw-r--r-- sqlglot/dialects/tsql.py 7
-rw-r--r-- sqlglot/executor/context.py 2
-rw-r--r-- sqlglot/executor/env.py 19
-rw-r--r-- sqlglot/executor/python.py 4
-rw-r--r-- sqlglot/expressions.py 29
-rw-r--r-- sqlglot/generator.py 91
-rw-r--r-- sqlglot/optimizer/annotate_types.py 1
-rw-r--r-- sqlglot/optimizer/qualify_columns.py 19
-rw-r--r-- sqlglot/optimizer/simplify.py 26
-rw-r--r-- sqlglot/parser.py 158
-rw-r--r-- sqlglot/planner.py 29
-rw-r--r-- sqlglot/tokens.py 3
-rw-r--r-- sqlglot/transforms.py 23
30 files changed, 396 insertions, 182 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 739ec29..42801ac 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -94,7 +94,11 @@ def parse_one(sql: str, **opts) -> Expression:
def parse_one(
- sql: str, read: DialectType = None, into: t.Optional[exp.IntoType] = None, **opts
+ sql: str,
+ read: DialectType = None,
+ dialect: DialectType = None,
+ into: t.Optional[exp.IntoType] = None,
+ **opts,
) -> Expression:
"""
Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.
@@ -102,6 +106,7 @@ def parse_one(
Args:
sql: the SQL code string to parse.
read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
+ dialect: the SQL dialect (alias for read)
into: the SQLGlot Expression to parse into.
**opts: other `sqlglot.parser.Parser` options.
@@ -109,7 +114,7 @@ def parse_one(
The syntax tree for the first parsed statement.
"""
- dialect = Dialect.get_or_raise(read)()
+ dialect = Dialect.get_or_raise(read or dialect)()
if into:
result = dialect.parse_into(into, sql, **opts)
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index c9d6c79..82162b4 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -163,6 +163,17 @@ def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
return expression
+def _parse_timestamp(args: t.List) -> exp.StrToTime:
+ this = format_time_lambda(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
+ this.set("zone", seq_get(args, 2))
+ return this
+
+
+def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts:
+ expr_type = exp.DateFromParts if len(args) == 3 else exp.Date
+ return expr_type.from_arg_list(args)
+
+
class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True
@@ -203,8 +214,10 @@ class BigQuery(Dialect):
while isinstance(parent, exp.Dot):
parent = parent.parent
- if not (isinstance(parent, exp.Table) and parent.db) and not expression.meta.get(
- "is_table"
+ if (
+ not isinstance(parent, exp.UserDefinedFunction)
+ and not (isinstance(parent, exp.Table) and parent.db)
+ and not expression.meta.get("is_table")
):
expression.set("this", expression.this.lower())
@@ -251,6 +264,7 @@ class BigQuery(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
+ "DATE": _parse_date,
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"DATE_TRUNC": lambda args: exp.DateTrunc(
@@ -264,9 +278,7 @@ class BigQuery(Dialect):
"PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
),
- "PARSE_TIMESTAMP": lambda args: format_time_lambda(exp.StrToTime, "bigquery")(
- [seq_get(args, 1), seq_get(args, 0)]
- ),
+ "PARSE_TIMESTAMP": _parse_timestamp,
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0),
@@ -356,9 +368,11 @@ class BigQuery(Dialect):
EXPLICIT_UNION = True
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
+ QUERY_HINTS = False
TABLE_HINTS = False
LIMIT_FETCH = "LIMIT"
RENAME_TABLE_WITH_DB = False
+ ESCAPE_LINE_BREAK = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -367,6 +381,7 @@ class BigQuery(Dialect):
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
+ exp.DateFromParts: rename_func("DATE"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
@@ -397,7 +412,9 @@ class BigQuery(Dialect):
]
),
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
- exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
+ exp.StrToTime: lambda self, e: self.func(
+ "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
+ ),
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
@@ -548,10 +565,15 @@ class BigQuery(Dialect):
}
def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
- if not isinstance(expression.parent, exp.Cast):
+ parent = expression.parent
+
+ # BigQuery allows CAST(.. AS {STRING|TIMESTAMP} [FORMAT <fmt> [AT TIME ZONE <tz>]]).
+ # Only the TIMESTAMP one should use the below conversion, when AT TIME ZONE is included.
+ if not isinstance(parent, exp.Cast) or not parent.to.is_type("text"):
return self.func(
"TIMESTAMP", self.func("DATETIME", expression.this, expression.args.get("zone"))
)
+
return super().attimezone_sql(expression)
def trycast_sql(self, expression: exp.TryCast) -> str:
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index efaf34c..9126c4b 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -109,10 +109,11 @@ class ClickHouse(Dialect):
QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
- "settings": lambda self: self._parse_csv(self._parse_conjunction)
- if self._match(TokenType.SETTINGS)
- else None,
- "format": lambda self: self._parse_id_var() if self._match(TokenType.FORMAT) else None,
+ TokenType.SETTINGS: lambda self: (
+ "settings",
+ self._advance() or self._parse_csv(self._parse_conjunction),
+ ),
+ TokenType.FORMAT: lambda self: ("format", self._advance() or self._parse_id_var()),
}
def _parse_conjunction(self) -> t.Optional[exp.Expression]:
@@ -155,9 +156,12 @@ class ClickHouse(Dialect):
return this
def _parse_table(
- self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
+ self,
+ schema: bool = False,
+ joins: bool = False,
+ alias_tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expression]:
- this = super()._parse_table(schema=schema, alias_tokens=alias_tokens)
+ this = super()._parse_table(schema=schema, joins=joins, alias_tokens=alias_tokens)
if self._match(TokenType.FINAL):
this = self.expression(exp.Final, this=this)
@@ -273,6 +277,7 @@ class ClickHouse(Dialect):
return None
class Generator(generator.Generator):
+ QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index d258826..4fc93bf 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -98,7 +98,6 @@ class _Dialect(type):
klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
- klass.RAW_START, klass.RAW_END = get_start_end(TokenType.RAW_STRING)
dialect_properties = {
**{
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 3cca986..26d09ce 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -96,6 +96,7 @@ class Drill(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
+ QUERY_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 093a01c..d7e5a43 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
+ date_trunc_to_time,
datestrtodate_sql,
format_time_lambda,
no_comment_column_constraint_sql,
@@ -38,6 +39,21 @@ def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.Dat
return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
+# BigQuery -> DuckDB conversion for the DATE function
+def _date_sql(self: generator.Generator, expression: exp.Date) -> str:
+ result = f"CAST({self.sql(expression, 'this')} AS DATE)"
+ zone = self.sql(expression, "zone")
+
+ if zone:
+ date_str = self.func("STRFTIME", result, "'%d/%m/%Y'")
+ date_str = f"{date_str} || ' ' || {zone}"
+
+ # This will create a TIMESTAMP with time zone information
+ result = self.func("STRPTIME", date_str, "'%d/%m/%Y %Z'")
+
+ return result
+
+
def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
@@ -131,6 +147,8 @@ class DuckDB(Dialect):
"ARRAY_REVERSE_SORT": _sort_array_reverse,
"DATEDIFF": _parse_date_diff,
"DATE_DIFF": _parse_date_diff,
+ "DATE_TRUNC": date_trunc_to_time,
+ "DATETRUNC": date_trunc_to_time,
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(this=seq_get(args, 0), expression=exp.Literal.number(1000))
@@ -167,6 +185,7 @@ class DuckDB(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
+ QUERY_HINTS = False
LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
RENAME_TABLE_WITH_DB = False
@@ -188,7 +207,9 @@ class DuckDB(Dialect):
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.DataType: _datatype_sql,
+ exp.Date: _date_sql,
exp.DateAdd: _date_delta_sql,
+ exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateSub: _date_delta_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 6bca610..1abc0f4 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -273,13 +273,6 @@ class Hive(Dialect):
),
}
- QUERY_MODIFIER_PARSERS = {
- **parser.Parser.QUERY_MODIFIER_PARSERS,
- "cluster": lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
- "distribute": lambda self: self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
- "sort": lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
- }
-
def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
@@ -319,6 +312,7 @@ class Hive(Dialect):
TABLESAMPLE_SIZE_IS_PERCENT = True
JOIN_HINTS = False
TABLE_HINTS = False
+ QUERY_HINTS = False
INDEX_ON = "ON TABLE"
TYPE_MAPPING = {
@@ -429,10 +423,3 @@ class Hive(Dialect):
expression = exp.DataType.build(expression.this)
return super().datatype_sql(expression)
-
- def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
- return super().after_having_modifiers(expression) + [
- self.sql(expression, "distribute"),
- self.sql(expression, "sort"),
- self.sql(expression, "cluster"),
- ]
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 5f743ee..bae0e50 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -123,14 +123,15 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"CHARSET": TokenType.CHARACTER_SET,
+ "ENUM": TokenType.ENUM,
"FORCE": TokenType.FORCE,
"IGNORE": TokenType.IGNORE,
"LONGBLOB": TokenType.LONGBLOB,
"LONGTEXT": TokenType.LONGTEXT,
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
+ "MEMBER OF": TokenType.MEMBER_OF,
"SEPARATOR": TokenType.SEPARATOR,
- "ENUM": TokenType.ENUM,
"START": TokenType.BEGIN,
"SIGNED": TokenType.BIGINT,
"SIGNED INTEGER": TokenType.BIGINT,
@@ -185,11 +186,26 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
class Parser(parser.Parser):
- FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE}
+ FUNC_TOKENS = {
+ *parser.Parser.FUNC_TOKENS,
+ TokenType.DATABASE,
+ TokenType.SCHEMA,
+ TokenType.VALUES,
+ }
+
TABLE_ALIAS_TOKENS = (
parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS
)
+ RANGE_PARSERS = {
+ **parser.Parser.RANGE_PARSERS,
+ TokenType.MEMBER_OF: lambda self, this: self.expression(
+ exp.JSONArrayContains,
+ this=this,
+ expression=self._parse_wrapped(self._parse_expression),
+ ),
+ }
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
@@ -207,6 +223,10 @@ class MySQL(Dialect):
this=self._parse_lambda(),
separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
),
+ # https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
+ "VALUES": lambda self: self.expression(
+ exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()]
+ ),
}
STATEMENT_PARSERS = {
@@ -399,6 +419,8 @@ class MySQL(Dialect):
NULL_ORDERING_SUPPORTED = False
JOIN_HINTS = False
TABLE_HINTS = True
+ DUPLICATE_KEY_UPDATE_WITH_SET = False
+ QUERY_HINT_SEP = " "
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -445,6 +467,9 @@ class MySQL(Dialect):
LIMIT_FETCH = "LIMIT"
+ def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
+ return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"
+
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
"""(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead."""
if expression.to.this == exp.DataType.Type.BIGINT:
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 8d35e92..2b77ef9 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -121,7 +121,6 @@ class Oracle(Dialect):
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
exp.Group: transforms.preprocess([transforms.unalias_group]),
- exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
exp.Coalesce: rename_func("NVL"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 766b584..6d78a07 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -183,6 +183,29 @@ def _to_timestamp(args: t.List) -> exp.Expression:
return format_time_lambda(exp.StrToTime, "postgres")(args)
+def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
+ """Remove table refs from columns in when statements."""
+ if isinstance(expression, exp.Merge):
+ alias = expression.this.args.get("alias")
+
+ normalize = lambda identifier: Postgres.normalize_identifier(identifier).name
+
+ targets = {normalize(expression.this.this)}
+
+ if alias:
+ targets.add(normalize(alias.this))
+
+ for when in expression.expressions:
+ when.transform(
+ lambda node: exp.column(node.name)
+ if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
+ else node,
+ copy=False,
+ )
+
+ return expression
+
+
class Postgres(Dialect):
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_large"
@@ -315,6 +338,7 @@ class Postgres(Dialect):
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
+ QUERY_HINTS = False
PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
@@ -352,7 +376,7 @@ class Postgres(Dialect):
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
- exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
+ exp.Merge: transforms.preprocess([_remove_target_from_merge]),
exp.Pivot: no_pivot_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 24c439b..1721588 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -232,6 +232,7 @@ class Presto(Dialect):
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
+ QUERY_HINTS = False
IS_BOOL_ALLOWED = False
STRUCT_DELIMITER = ("(", ")")
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 87be42c..09edd55 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -86,6 +86,7 @@ class Redshift(Postgres):
class Generator(Postgres.Generator):
LOCKING_READS_SUPPORTED = False
RENAME_TABLE_WITH_DB = False
+ QUERY_HINTS = False
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 34e4dd0..19924cd 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -326,6 +326,7 @@ class Snowflake(Dialect):
SINGLE_STRING_INTERVAL = True
JOIN_HINTS = False
TABLE_HINTS = False
+ QUERY_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 3720b8d..afe2482 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -173,6 +173,8 @@ class Spark2(Hive):
return pivot_column_names(aggregations, dialect="spark")
class Generator(Hive.Generator):
+ QUERY_HINTS = True
+
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "BYTE",
@@ -203,7 +205,6 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
- exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 519e62a..5ded6df 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -77,6 +77,7 @@ class SQLite(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
+ QUERY_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 67ef76b..33ec7e1 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -8,6 +8,7 @@ class Tableau(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
+ QUERY_HINTS = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index d9a5417..4e8ffb4 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -121,7 +121,7 @@ class Teradata(Dialect):
exp.Update,
**{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
- "from": self._parse_from(modifiers=True),
+ "from": self._parse_from(joins=True),
"expressions": self._match(TokenType.SET)
and self._parse_csv(self._parse_equality),
"where": self._parse_where(),
@@ -140,6 +140,7 @@ class Teradata(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
+ QUERY_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index f671630..92bb755 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -60,10 +60,10 @@ def _format_time_lambda(
assert len(args) == 2
return exp_class(
- this=args[1],
+ this=exp.cast(args[1], "datetime"),
format=exp.Literal.string(
format_time(
- args[0].name,
+ args[0].name.lower(),
{**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
if full_format_mapping
else TSQL.TIME_MAPPING,
@@ -467,6 +467,8 @@ class TSQL(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
+ LIMIT_IS_TOP = True
+ QUERY_HINTS = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -482,6 +484,7 @@ class TSQL(Dialect):
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
+ exp.Extract: rename_func("DATEPART"),
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.Max: max_or_greatest,
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index 630cb65..d7952c1 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -41,11 +41,13 @@ class Context:
def table(self) -> Table:
if self._table is None:
self._table = list(self.tables.values())[0]
+
for other in self.tables.values():
if self._table.columns != other.columns:
raise Exception(f"Columns are different.")
if len(self._table.rows) != len(other.rows):
raise Exception(f"Rows are different.")
+
return self._table
def add_columns(self, *columns: str) -> None:
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index 5300224..9f63100 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -100,9 +100,19 @@ def substring(this, start=None, length=None):
@null_if_any
def cast(this, to):
if to == exp.DataType.Type.DATE:
- return datetime.date.fromisoformat(this)
- if to == exp.DataType.Type.DATETIME:
- return datetime.datetime.fromisoformat(this)
+ if isinstance(this, datetime.datetime):
+ return this.date()
+ if isinstance(this, datetime.date):
+ return this
+ if isinstance(this, str):
+ return datetime.date.fromisoformat(this)
+ if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP):
+ if isinstance(this, datetime.datetime):
+ return this
+ if isinstance(this, datetime.date):
+ return datetime.datetime(this.year, this.month, this.day)
+ if isinstance(this, str):
+ return datetime.datetime.fromisoformat(this)
if to == exp.DataType.Type.BOOLEAN:
return bool(this)
if to in exp.DataType.TEXT_TYPES:
@@ -111,7 +121,7 @@ def cast(this, to):
return float(this)
if to in exp.DataType.NUMERIC_TYPES:
return int(this)
- raise NotImplementedError(f"Casting to '{to}' not implemented.")
+ raise NotImplementedError(f"Casting {this} to '{to}' not implemented.")
def ordered(this, desc, nulls_first):
@@ -153,6 +163,7 @@ ENV = {
"CONCAT": null_if_any(lambda *args: "".join(args)),
"SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)),
"CONCATWS": null_if_any(lambda this, *args: this.join(args)),
+ "DATEDIFF": null_if_any(lambda this, expression, *_: (this - expression).days),
"DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)),
"DIV": null_if_any(lambda e, this: e / this),
"DOT": null_if_any(lambda e, this: e[this]),
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 34a380e..d2ae79d 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -417,7 +417,9 @@ class Python(Dialect):
exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}",
exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')",
- exp.Is: lambda self, e: self.binary(e, "is"),
+ exp.Is: lambda self, e: self.binary(e, "==")
+ if isinstance(e.this, exp.Literal)
+ else self.binary(e, "is"),
exp.Lambda: _lambda_sql,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index cdb93db..fdf02c8 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -102,13 +102,10 @@ class Expression(metaclass=_Expression):
@property
def hashable_args(self) -> t.Any:
- args = (self.args.get(k) for k in self.arg_types)
-
- return tuple(
- (tuple(_norm_arg(a) for a in arg) if arg else None)
- if type(arg) is list
- else (_norm_arg(arg) if arg is not None and arg is not False else None)
- for arg in args
+ return frozenset(
+ (k, tuple(_norm_arg(a) for a in v) if type(v) is list else _norm_arg(v))
+ for k, v in self.args.items()
+ if not (v is None or v is False or (type(v) is list and not v))
)
def __hash__(self) -> int:
@@ -1304,6 +1301,7 @@ class Delete(Expression):
"where": False,
"returning": False,
"limit": False,
+ "tables": False, # Multiple-Table Syntax (MySQL)
}
def delete(
@@ -1490,9 +1488,7 @@ class Identifier(Expression):
@property
def hashable_args(self) -> t.Any:
- if self.quoted and any(char.isupper() for char in self.this):
- return (self.this, self.quoted)
- return self.this.lower()
+ return (self.this, self.quoted)
@property
def output_name(self) -> str:
@@ -1525,6 +1521,7 @@ class Insert(Expression):
"partition": False,
"alternative": False,
"where": False,
+ "ignore": False,
}
def with_(
@@ -1620,6 +1617,7 @@ class Group(Expression):
"cube": False,
"rollup": False,
"totals": False,
+ "all": False,
}
@@ -4135,9 +4133,9 @@ class DateToDi(Func):
pass
+# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#date
class Date(Func):
- arg_types = {"expressions": True}
- is_var_len_args = True
+ arg_types = {"this": True, "zone": False}
class Day(Func):
@@ -4245,6 +4243,11 @@ class JSONFormat(Func):
_sql_names = ["JSON_FORMAT"]
+# https://dev.mysql.com/doc/refman/8.0/en/json-search-functions.html#operator_member-of
+class JSONArrayContains(Binary, Predicate, Func):
+ _sql_names = ["JSON_ARRAY_CONTAINS"]
+
+
class Least(Func):
arg_types = {"expressions": False}
is_var_len_args = True
@@ -4475,7 +4478,7 @@ class StrToDate(Func):
class StrToTime(Func):
- arg_types = {"this": True, "format": True}
+ arg_types = {"this": True, "format": True, "zone": False}
# Spark allows unix_timestamp()
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index a41af12..1ce2aaa 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -140,9 +140,21 @@ class Generator:
# Whether or not table hints should be generated
TABLE_HINTS = True
+ # Whether or not query hints should be generated
+ QUERY_HINTS = True
+
+ # What kind of separator to use for query hints
+ QUERY_HINT_SEP = ", "
+
# Whether or not comparing against booleans (e.g. x IS TRUE) is supported
IS_BOOL_ALLOWED = True
+ # Whether or not to include the "SET" keyword in the "INSERT ... ON DUPLICATE KEY UPDATE" statement
+ DUPLICATE_KEY_UPDATE_WITH_SET = True
+
+ # Whether or not to generate the limit as TOP <value> instead of LIMIT <value>
+ LIMIT_IS_TOP = False
+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
@@ -268,6 +280,7 @@ class Generator:
STRICT_STRING_CONCAT = False
NORMALIZE_FUNCTIONS: bool | str = "upper"
NULL_ORDERING = "nulls_are_small"
+ ESCAPE_LINE_BREAK = False
can_identify: t.Callable[[str, str | bool], bool]
@@ -286,8 +299,6 @@ class Generator:
HEX_END: t.Optional[str] = None
BYTE_START: t.Optional[str] = None
BYTE_END: t.Optional[str] = None
- RAW_START: t.Optional[str] = None
- RAW_END: t.Optional[str] = None
__slots__ = (
"pretty",
@@ -486,7 +497,10 @@ class Generator:
return expression
if key:
- return self.sql(expression.args.get(key))
+ value = expression.args.get(key)
+ if value:
+ return self.sql(value)
+ return ""
if self._cache is not None:
expression_id = hash(expression)
@@ -779,10 +793,7 @@ class Generator:
return this
def rawstring_sql(self, expression: exp.RawString) -> str:
- string = expression.this
- if self.RAW_START:
- return f"{self.RAW_START}{self.escape_str(expression.this)}{self.RAW_END}"
- string = self.escape_str(string.replace("\\", "\\\\"))
+ string = self.escape_str(expression.this.replace("\\", "\\\\"))
return f"{self.QUOTE_START}{string}{self.QUOTE_END}"
def datatypesize_sql(self, expression: exp.DataTypeSize) -> str:
@@ -818,15 +829,14 @@ class Generator:
def delete_sql(self, expression: exp.Delete) -> str:
this = self.sql(expression, "this")
this = f" FROM {this}" if this else ""
- using_sql = (
- f" USING {self.expressions(expression, key='using', sep=', USING ')}"
- if expression.args.get("using")
- else ""
- )
- where_sql = self.sql(expression, "where")
+ using = self.sql(expression, "using")
+ using = f" USING {using}" if using else ""
+ where = self.sql(expression, "where")
returning = self.sql(expression, "returning")
limit = self.sql(expression, "limit")
- sql = f"DELETE{this}{using_sql}{where_sql}{returning}{limit}"
+ tables = self.expressions(expression, key="tables")
+ tables = f" {tables}" if tables else ""
+ sql = f"DELETE{tables}{this}{using}{where}{returning}{limit}"
return self.prepend_ctes(expression, sql)
def drop_sql(self, expression: exp.Drop) -> str:
@@ -867,9 +877,11 @@ class Generator:
return f"{this} FILTER({where})"
def hint_sql(self, expression: exp.Hint) -> str:
- if self.sql(expression, "this"):
+ if not self.QUERY_HINTS:
self.unsupported("Hints are not supported")
- return ""
+ return ""
+
+ return f" /*+ {self.expressions(expression, sep=self.QUERY_HINT_SEP).strip()} */"
def index_sql(self, expression: exp.Index) -> str:
unique = "UNIQUE " if expression.args.get("unique") else ""
@@ -1109,6 +1121,8 @@ class Generator:
alternative = expression.args.get("alternative")
alternative = f" OR {alternative}" if alternative else ""
+ ignore = " IGNORE" if expression.args.get("ignore") else ""
+
this = f"{this} {self.sql(expression, 'this')}"
exists = " IF EXISTS" if expression.args.get("exists") else ""
@@ -1120,7 +1134,7 @@ class Generator:
expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}"
conflict = self.sql(expression, "conflict")
returning = self.sql(expression, "returning")
- sql = f"INSERT{alternative}{this}{exists}{partition_sql}{where}{expression_sql}{conflict}{returning}"
+ sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}{conflict}{returning}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression: exp.Intersect) -> str:
@@ -1147,8 +1161,9 @@ class Generator:
do = "" if expression.args.get("duplicate") else " DO "
nothing = "NOTHING" if expression.args.get("nothing") else ""
expressions = self.expressions(expression, flat=True)
+ set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else ""
if expressions:
- expressions = f"UPDATE SET {expressions}"
+ expressions = f"UPDATE {set_keyword}{expressions}"
return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}"
def returning_sql(self, expression: exp.Returning) -> str:
@@ -1195,7 +1210,7 @@ class Generator:
hints = f" {hints}" if hints and self.TABLE_HINTS else ""
pivots = self.expressions(expression, key="pivots", sep=" ", flat=True)
pivots = f" {pivots}" if pivots else ""
- joins = self.expressions(expression, key="joins", sep="")
+ joins = self.expressions(expression, key="joins", sep="", skip_first=True)
laterals = self.expressions(expression, key="laterals", sep="")
system_time = expression.args.get("system_time")
system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
@@ -1287,6 +1302,10 @@ class Generator:
def group_sql(self, expression: exp.Group) -> str:
group_by = self.op_expressions("GROUP BY", expression)
+
+ if expression.args.get("all"):
+ return f"{group_by} ALL"
+
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
grouping_sets = (
f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
@@ -1379,7 +1398,7 @@ class Generator:
alias = f" AS {alias}" if alias else ""
return f"LATERAL {this}{alias}"
- def limit_sql(self, expression: exp.Limit) -> str:
+ def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
this = self.sql(expression, "this")
args = ", ".join(
sql
@@ -1389,7 +1408,7 @@ class Generator:
)
if sql
)
- return f"{this}{self.seg('LIMIT')} {args}"
+ return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
@@ -1441,7 +1460,9 @@ class Generator:
def escape_str(self, text: str) -> str:
text = text.replace(self.QUOTE_END, self._escaped_quote_end)
- if self.pretty:
+ if self.ESCAPE_LINE_BREAK:
+ text = text.replace("\n", "\\n")
+ elif self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
return text
@@ -1544,6 +1565,9 @@ class Generator:
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit")
+ # If the limit is generated as TOP, we need to ensure it's not generated twice
+ with_offset_limit_modifiers = not isinstance(limit, exp.Limit) or not self.LIMIT_IS_TOP
+
if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
limit = exp.Limit(expression=limit.args.get("count"))
elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
@@ -1551,6 +1575,12 @@ class Generator:
fetch = isinstance(limit, exp.Fetch)
+ offset_limit_modifiers = (
+ self.offset_limit_modifiers(expression, fetch, limit)
+ if with_offset_limit_modifiers
+ else []
+ )
+
return csv(
*sqls,
*[self.sql(join) for join in expression.args.get("joins") or []],
@@ -1561,7 +1591,7 @@ class Generator:
self.sql(expression, "having"),
*self.after_having_modifiers(expression),
self.sql(expression, "order"),
- *self.offset_limit_modifiers(expression, fetch, limit),
+ *offset_limit_modifiers,
*self.after_limit_modifiers(expression),
sep="",
)
@@ -1580,6 +1610,9 @@ class Generator:
self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
if expression.args.get("windows")
else "",
+ self.sql(expression, "distribute"),
+ self.sql(expression, "sort"),
+ self.sql(expression, "cluster"),
]
def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
@@ -1592,6 +1625,13 @@ class Generator:
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
kind = self.sql(expression, "kind").upper()
+ limit = expression.args.get("limit")
+ top = (
+ self.limit_sql(limit, top=True)
+ if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP
+ else ""
+ )
+
expressions = self.expressions(expression)
if kind:
@@ -1618,7 +1658,7 @@ class Generator:
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
- f"SELECT{hint}{distinct}{kind}{expressions}",
+ f"SELECT{top}{hint}{distinct}{kind}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
@@ -2288,6 +2328,7 @@ class Generator:
sqls: t.Optional[t.List[str]] = None,
flat: bool = False,
indent: bool = True,
+ skip_first: bool = False,
sep: str = ", ",
prefix: str = "",
) -> str:
@@ -2321,7 +2362,7 @@ class Generator:
result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
- return self.indent(result_sql, skip_first=False) if indent else result_sql
+ return self.indent(result_sql, skip_first=skip_first) if indent else result_sql
def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str:
flat = flat or isinstance(expression.parent, exp.Properties)
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 0fc5f4c..e7cb80b 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -105,6 +105,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.CurrentDate,
exp.Date,
exp.DateAdd,
+ exp.DateFromParts,
exp.DateStrToDate,
exp.DateSub,
exp.DateTrunc,
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 8c3f599..435585c 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -220,25 +220,6 @@ def _expand_group_by(scope: Scope):
group.set("expressions", _expand_positional_references(scope, group.expressions))
expression.set("group", group)
- # group by expressions cannot be simplified, for example
- # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
- # the projection must exactly match the group by key
- groups = set(group.expressions)
- group.meta["final"] = True
-
- for e in expression.selects:
- for node, *_ in e.walk():
- if node in groups:
- e.meta["final"] = True
- break
-
- having = expression.args.get("having")
- if having:
- for node, *_ in having.walk():
- if node in groups:
- having.meta["final"] = True
- break
-
def _expand_order_by(scope: Scope, resolver: Resolver):
order = scope.expression.args.get("order")
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 1a2d82c..e247f58 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -8,6 +8,9 @@ from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing
+# Final means that an expression should not be simplified
+FINAL = "final"
+
def simplify(expression):
"""
@@ -27,8 +30,29 @@ def simplify(expression):
generate = cached_generator()
+ # group by expressions cannot be simplified, for example
+ # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
+ # the projection must exactly match the group by key
+ for group in expression.find_all(exp.Group):
+ select = group.parent
+ groups = set(group.expressions)
+ group.meta[FINAL] = True
+
+ for e in select.selects:
+ for node, *_ in e.walk():
+ if node in groups:
+ e.meta[FINAL] = True
+ break
+
+ having = select.args.get("having")
+ if having:
+ for node, *_ in having.walk():
+ if node in groups:
+ having.meta[FINAL] = True
+ break
+
def _simplify(expression, root=True):
- if expression.meta.get("final"):
+ if expression.meta.get(FINAL):
return expression
node = expression
node = rewrite_between(node)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 79e7cac..f7fd6ba 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -737,19 +737,29 @@ class Parser(metaclass=_Parser):
}
QUERY_MODIFIER_PARSERS = {
- "joins": lambda self: list(iter(self._parse_join, None)),
- "laterals": lambda self: list(iter(self._parse_lateral, None)),
- "match": lambda self: self._parse_match_recognize(),
- "where": lambda self: self._parse_where(),
- "group": lambda self: self._parse_group(),
- "having": lambda self: self._parse_having(),
- "qualify": lambda self: self._parse_qualify(),
- "windows": lambda self: self._parse_window_clause(),
- "order": lambda self: self._parse_order(),
- "limit": lambda self: self._parse_limit(),
- "offset": lambda self: self._parse_offset(),
- "locks": lambda self: self._parse_locks(),
- "sample": lambda self: self._parse_table_sample(as_modifier=True),
+ TokenType.MATCH_RECOGNIZE: lambda self: ("match", self._parse_match_recognize()),
+ TokenType.WHERE: lambda self: ("where", self._parse_where()),
+ TokenType.GROUP_BY: lambda self: ("group", self._parse_group()),
+ TokenType.HAVING: lambda self: ("having", self._parse_having()),
+ TokenType.QUALIFY: lambda self: ("qualify", self._parse_qualify()),
+ TokenType.WINDOW: lambda self: ("windows", self._parse_window_clause()),
+ TokenType.ORDER_BY: lambda self: ("order", self._parse_order()),
+ TokenType.LIMIT: lambda self: ("limit", self._parse_limit()),
+ TokenType.FETCH: lambda self: ("limit", self._parse_limit()),
+ TokenType.OFFSET: lambda self: ("offset", self._parse_offset()),
+ TokenType.FOR: lambda self: ("locks", self._parse_locks()),
+ TokenType.LOCK: lambda self: ("locks", self._parse_locks()),
+ TokenType.TABLE_SAMPLE: lambda self: ("sample", self._parse_table_sample(as_modifier=True)),
+ TokenType.USING: lambda self: ("sample", self._parse_table_sample(as_modifier=True)),
+ TokenType.CLUSTER_BY: lambda self: (
+ "cluster",
+ self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
+ ),
+ TokenType.DISTRIBUTE_BY: lambda self: (
+ "distribute",
+ self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
+ ),
+ TokenType.SORT_BY: lambda self: ("sort", self._parse_sort(exp.Sort, TokenType.SORT_BY)),
}
SET_PARSERS = {
@@ -1679,6 +1689,7 @@ class Parser(metaclass=_Parser):
def _parse_insert(self) -> exp.Insert:
overwrite = self._match(TokenType.OVERWRITE)
+ ignore = self._match(TokenType.IGNORE)
local = self._match_text_seq("LOCAL")
alternative = None
@@ -1709,6 +1720,7 @@ class Parser(metaclass=_Parser):
returning=self._parse_returning(),
overwrite=overwrite,
alternative=alternative,
+ ignore=ignore,
)
def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]:
@@ -1734,7 +1746,8 @@ class Parser(metaclass=_Parser):
nothing = True
else:
self._match(TokenType.UPDATE)
- expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
+ self._match(TokenType.SET)
+ expressions = self._parse_csv(self._parse_equality)
return self.expression(
exp.OnConflict,
@@ -1805,12 +1818,17 @@ class Parser(metaclass=_Parser):
return self._parse_as_command(self._prev)
def _parse_delete(self) -> exp.Delete:
- self._match(TokenType.FROM)
+ # This handles MySQL's "Multiple-Table Syntax"
+ # https://dev.mysql.com/doc/refman/8.0/en/delete.html
+ tables = None
+ if not self._match(TokenType.FROM, advance=False):
+ tables = self._parse_csv(self._parse_table) or None
return self.expression(
exp.Delete,
- this=self._parse_table(),
- using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
+ tables=tables,
+ this=self._match(TokenType.FROM) and self._parse_table(joins=True),
+ using=self._match(TokenType.USING) and self._parse_table(joins=True),
where=self._parse_where(),
returning=self._parse_returning(),
limit=self._parse_limit(),
@@ -1822,7 +1840,7 @@ class Parser(metaclass=_Parser):
**{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
- "from": self._parse_from(modifiers=True),
+ "from": self._parse_from(joins=True),
"where": self._parse_where(),
"returning": self._parse_returning(),
"limit": self._parse_limit(),
@@ -1875,7 +1893,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Tuple, expressions=expressions)
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
- # Source: https://prestodb.io/docs/current/sql/values.html
+ # https://prestodb.io/docs/current/sql/values.html
return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
def _parse_select(
@@ -1917,7 +1935,7 @@ class Parser(metaclass=_Parser):
self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
limit = self._parse_limit(top=True)
- expressions = self._parse_csv(self._parse_expression)
+ expressions = self._parse_expressions()
this = self.expression(
exp.Select,
@@ -2034,20 +2052,31 @@ class Parser(metaclass=_Parser):
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
if isinstance(this, self.MODIFIABLES):
- for key, parser in self.QUERY_MODIFIER_PARSERS.items():
- expression = parser(self)
-
- if expression:
- if key == "limit":
- offset = expression.args.pop("offset", None)
- if offset:
- this.set("offset", exp.Offset(expression=offset))
- this.set(key, expression)
+ for join in iter(self._parse_join, None):
+ this.append("joins", join)
+ for lateral in iter(self._parse_lateral, None):
+ this.append("laterals", lateral)
+
+ while True:
+ if self._match_set(self.QUERY_MODIFIER_PARSERS, advance=False):
+ parser = self.QUERY_MODIFIER_PARSERS[self._curr.token_type]
+ key, expression = parser(self)
+
+ if expression:
+ this.set(key, expression)
+ if key == "limit":
+ offset = expression.args.pop("offset", None)
+ if offset:
+ this.set("offset", exp.Offset(expression=offset))
+ continue
+ break
return this
def _parse_hint(self) -> t.Optional[exp.Hint]:
if self._match(TokenType.HINT):
- hints = self._parse_csv(self._parse_function)
+ hints = []
+ for hint in iter(lambda: self._parse_csv(self._parse_function), []):
+ hints.extend(hint)
if not self._match_pair(TokenType.STAR, TokenType.SLASH):
self.raise_error("Expected */ after HINT")
@@ -2069,18 +2098,13 @@ class Parser(metaclass=_Parser):
)
def _parse_from(
- self, modifiers: bool = False, skip_from_token: bool = False
+ self, joins: bool = False, skip_from_token: bool = False
) -> t.Optional[exp.From]:
if not skip_from_token and not self._match(TokenType.FROM):
return None
- comments = self._prev_comments
- this = self._parse_table()
-
return self.expression(
- exp.From,
- comments=comments,
- this=self._parse_query_modifiers(this) if modifiers else this,
+ exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins)
)
def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]:
@@ -2091,9 +2115,7 @@ class Parser(metaclass=_Parser):
partition = self._parse_partition_by()
order = self._parse_order()
- measures = (
- self._parse_csv(self._parse_expression) if self._match_text_seq("MEASURES") else None
- )
+ measures = self._parse_expressions() if self._match_text_seq("MEASURES") else None
if self._match_text_seq("ONE", "ROW", "PER", "MATCH"):
rows = exp.var("ONE ROW PER MATCH")
@@ -2259,6 +2281,18 @@ class Parser(metaclass=_Parser):
kwargs["on"] = self._parse_conjunction()
elif self._match(TokenType.USING):
kwargs["using"] = self._parse_wrapped_id_vars()
+ elif not (kind and kind.token_type == TokenType.CROSS):
+ index = self._index
+ joins = self._parse_joins()
+
+ if joins and self._match(TokenType.ON):
+ kwargs["on"] = self._parse_conjunction()
+ elif joins and self._match(TokenType.USING):
+ kwargs["using"] = self._parse_wrapped_id_vars()
+ else:
+ joins = None
+ self._retreat(index)
+ kwargs["this"].set("joins", joins)
return self.expression(exp.Join, **kwargs)
@@ -2363,7 +2397,10 @@ class Parser(metaclass=_Parser):
)
def _parse_table(
- self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
+ self,
+ schema: bool = False,
+ joins: bool = False,
+ alias_tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()
if lateral:
@@ -2407,6 +2444,10 @@ class Parser(metaclass=_Parser):
table_sample.set("this", this)
this = table_sample
+ if joins:
+ for join in iter(self._parse_join, None):
+ this.append("joins", join)
+
return this
def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
@@ -2507,8 +2548,11 @@ class Parser(metaclass=_Parser):
kind=kind,
)
- def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]:
- return list(iter(self._parse_pivot, None))
+ def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]:
+ return list(iter(self._parse_pivot, None)) or None
+
+ def _parse_joins(self) -> t.Optional[t.List[exp.Join]]:
+ return list(iter(self._parse_join, None)) or None
# https://duckdb.org/docs/sql/statements/pivot
def _parse_simplified_pivot(self) -> exp.Pivot:
@@ -2603,6 +2647,9 @@ class Parser(metaclass=_Parser):
elements = defaultdict(list)
+ if self._match(TokenType.ALL):
+ return self.expression(exp.Group, all=True)
+
while True:
expressions = self._parse_csv(self._parse_conjunction)
if expressions:
@@ -3171,7 +3218,7 @@ class Parser(metaclass=_Parser):
if query:
expressions = [query]
else:
- expressions = self._parse_csv(self._parse_expression)
+ expressions = self._parse_expressions()
this = self._parse_query_modifiers(seq_get(expressions, 0))
@@ -3536,11 +3583,7 @@ class Parser(metaclass=_Parser):
return None
expressions = None
- this = self._parse_id_var()
-
- if self._match(TokenType.L_PAREN, advance=False):
- expressions = self._parse_wrapped_id_vars()
-
+ this = self._parse_table(schema=True)
options = self._parse_key_constraint_options()
return self.expression(exp.Reference, this=this, expressions=expressions, options=options)
@@ -3706,21 +3749,27 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.CHARACTER_SET):
to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
elif self._match(TokenType.FORMAT):
- fmt = self._parse_at_time_zone(self._parse_string())
+ fmt_string = self._parse_string()
+ fmt = self._parse_at_time_zone(fmt_string)
if to.this in exp.DataType.TEMPORAL_TYPES:
- return self.expression(
+ this = self.expression(
exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime,
this=this,
format=exp.Literal.string(
format_time(
- fmt.this if fmt else "",
+ fmt_string.this if fmt_string else "",
self.FORMAT_MAPPING or self.TIME_MAPPING,
self.FORMAT_TRIE or self.TIME_TRIE,
)
),
)
+ if isinstance(fmt, exp.AtTimeZone) and isinstance(this, exp.StrToTime):
+ this.set("zone", fmt.args["zone"])
+
+ return this
+
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt)
def _parse_concat(self) -> t.Optional[exp.Expression]:
@@ -4223,7 +4272,7 @@ class Parser(metaclass=_Parser):
return None
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_expression)
- return self._parse_csv(self._parse_expression)
+ return self._parse_expressions()
def _parse_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
@@ -4273,6 +4322,9 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return parse_result
+ def _parse_expressions(self) -> t.List[t.Optional[exp.Expression]]:
+ return self._parse_csv(self._parse_expression)
+
def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]:
return self._parse_select() or self._parse_set_operations(
self._parse_expression() if alias else self._parse_conjunction()
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index f246702..07ee739 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -23,9 +23,11 @@ class Plan:
while nodes:
node = nodes.pop()
dag[node] = set()
+
for dep in node.dependencies:
dag[node].add(dep)
nodes.add(dep)
+
self._dag = dag
return self._dag
@@ -128,15 +130,22 @@ class Step:
agg_funcs = tuple(expression.find_all(exp.AggFunc))
if agg_funcs:
aggregations.add(expression)
+
for agg in agg_funcs:
for operand in agg.unnest_operands():
if isinstance(operand, exp.Column):
continue
if operand not in operands:
operands[operand] = next_operand_name()
+
operand.replace(exp.column(operands[operand], quoted=True))
+
return bool(agg_funcs)
+ def set_ops_and_aggs(step):
+ step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
+ step.aggregations = list(aggregations)
+
for e in expression.expressions:
if e.find(exp.AggFunc):
projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
@@ -164,10 +173,7 @@ class Step:
else:
aggregate.condition = having.this
- aggregate.operands = tuple(
- alias(operand, alias_) for operand, alias_ in operands.items()
- )
- aggregate.aggregations = list(aggregations)
+ set_ops_and_aggs(aggregate)
# give aggregates names and replace projections with references to them
aggregate.group = {
@@ -178,13 +184,14 @@ class Step:
for k, v in aggregate.group.items():
intermediate[v] = k
if isinstance(v, exp.Column):
- intermediate[v.alias_or_name] = k
+ intermediate[v.name] = k
for projection in projections:
for node, *_ in projection.walk():
name = intermediate.get(node)
if name:
node.replace(exp.column(name, step.name))
+
if aggregate.condition:
for node, *_ in aggregate.condition.walk():
name = intermediate.get(node) or intermediate.get(node.name)
@@ -197,6 +204,13 @@ class Step:
order = expression.args.get("order")
if order:
+ if isinstance(step, Aggregate):
+ for i, ordered in enumerate(order.expressions):
+ if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)):
+ ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True))
+
+ set_ops_and_aggs(aggregate)
+
sort = Sort()
sort.name = step.name
sort.key = order.expressions
@@ -340,7 +354,10 @@ class Join(Step):
def _to_s(self, indent: str) -> t.List[str]:
lines = []
for name, join in self.joins.items():
- lines.append(f"{indent}{name}: {join['side']}")
+ lines.append(f"{indent}{name}: {join['side'] or 'INNER'}")
+ join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or []))
+ if join_key:
+ lines.append(f"{indent}Key: {join_key}")
if join.get("condition"):
lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore
return lines
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 79f7a65..999bde2 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -239,6 +239,7 @@ class TokenType(AutoName):
LOCK = auto()
MAP = auto()
MATCH_RECOGNIZE = auto()
+ MEMBER_OF = auto()
MERGE = auto()
MOD = auto()
NATURAL = auto()
@@ -944,8 +945,6 @@ class Tokenizer(metaclass=_Tokenizer):
char = ""
chars = " "
- word = None if not single_token and chars[-1] not in self.WHITE_SPACE else word
-
if not word:
if self._char in self.SINGLE_TOKENS:
self._add(self.SINGLE_TOKENS[self._char], text=self._char)
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 33a1bc0..1e6cfc8 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -124,7 +124,9 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
other expressions. This transforms removes the precision from parameterized types in expressions.
"""
for node in expression.find_all(exp.DataType):
- node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])
+ node.set(
+ "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeSize)]
+ )
return expression
@@ -215,25 +217,6 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
return expression
-def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
- """Remove table refs from columns in when statements."""
- if isinstance(expression, exp.Merge):
- alias = expression.this.args.get("alias")
- targets = {expression.this.this}
- if alias:
- targets.add(alias.this)
-
- for when in expression.expressions:
- when.transform(
- lambda node: exp.column(node.name)
- if isinstance(node, exp.Column) and node.args.get("table") in targets
- else node,
- copy=False,
- )
-
- return expression
-
-
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
if (
isinstance(expression, exp.WithinGroup)