summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-04-03 07:31:54 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-04-03 07:31:54 +0000
commitb38d717d5933fdae3fe85c87df7aee9a251fb58e (patch)
tree6db21a44ffea4c832dcab29688bfaf1c1dc124f9 /sqlglot
parentReleasing debian version 11.4.1-1. (diff)
downloadsqlglot-b38d717d5933fdae3fe85c87df7aee9a251fb58e.tar.xz
sqlglot-b38d717d5933fdae3fe85c87df7aee9a251fb58e.zip
Merging upstream version 11.4.5.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/__init__.py2
-rw-r--r--sqlglot/dialects/bigquery.py20
-rw-r--r--sqlglot/dialects/clickhouse.py2
-rw-r--r--sqlglot/dialects/databricks.py2
-rw-r--r--sqlglot/dialects/dialect.py5
-rw-r--r--sqlglot/dialects/drill.py5
-rw-r--r--sqlglot/dialects/duckdb.py3
-rw-r--r--sqlglot/dialects/hive.py26
-rw-r--r--sqlglot/dialects/mysql.py16
-rw-r--r--sqlglot/dialects/oracle.py51
-rw-r--r--sqlglot/dialects/postgres.py2
-rw-r--r--sqlglot/dialects/snowflake.py11
-rw-r--r--sqlglot/dialects/sqlite.py5
-rw-r--r--sqlglot/dialects/teradata.py3
-rw-r--r--sqlglot/dialects/tsql.py11
-rw-r--r--sqlglot/diff.py24
-rw-r--r--sqlglot/executor/__init__.py2
-rw-r--r--sqlglot/executor/python.py42
-rw-r--r--sqlglot/expressions.py163
-rw-r--r--sqlglot/generator.py80
-rw-r--r--sqlglot/helper.py11
-rw-r--r--sqlglot/optimizer/annotate_types.py12
-rw-r--r--sqlglot/optimizer/canonicalize.py2
-rw-r--r--sqlglot/optimizer/eliminate_joins.py5
-rw-r--r--sqlglot/optimizer/eliminate_subqueries.py2
-rw-r--r--sqlglot/optimizer/lower_identities.py8
-rw-r--r--sqlglot/optimizer/merge_subqueries.py5
-rw-r--r--sqlglot/optimizer/normalize.py104
-rw-r--r--sqlglot/optimizer/optimize_joins.py2
-rw-r--r--sqlglot/optimizer/optimizer.py4
-rw-r--r--sqlglot/optimizer/qualify_columns.py88
-rw-r--r--sqlglot/optimizer/scope.py2
-rw-r--r--sqlglot/optimizer/simplify.py166
-rw-r--r--sqlglot/parser.py109
-rw-r--r--sqlglot/planner.py2
-rw-r--r--sqlglot/schema.py4
-rw-r--r--sqlglot/tokens.py15
37 files changed, 650 insertions, 366 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 10046d1..b53b261 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -47,7 +47,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
-__version__ = "11.4.1"
+__version__ = "11.4.5"
pretty = False
"""Whether to format generated SQL by default."""
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 6a43846..a3f9e6d 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
inline_array_sql,
+ max_or_greatest,
min_or_least,
no_ilike_sql,
rename_func,
@@ -212,6 +213,9 @@ class BigQuery(Dialect):
),
}
+ LOG_BASE_FIRST = False
+ LOG_DEFAULTS_TO_LN = True
+
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
@@ -227,6 +231,7 @@ class BigQuery(Dialect):
exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess(
[_unqualify_unnest], transforms.delegate("select_sql")
@@ -253,17 +258,19 @@ class BigQuery(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
- exp.DataType.Type.TINYINT: "INT64",
- exp.DataType.Type.SMALLINT: "INT64",
- exp.DataType.Type.INT: "INT64",
exp.DataType.Type.BIGINT: "INT64",
+ exp.DataType.Type.BOOLEAN: "BOOL",
+ exp.DataType.Type.CHAR: "STRING",
exp.DataType.Type.DECIMAL: "NUMERIC",
- exp.DataType.Type.FLOAT: "FLOAT64",
exp.DataType.Type.DOUBLE: "FLOAT64",
- exp.DataType.Type.BOOLEAN: "BOOL",
+ exp.DataType.Type.FLOAT: "FLOAT64",
+ exp.DataType.Type.INT: "INT64",
+ exp.DataType.Type.NCHAR: "STRING",
+ exp.DataType.Type.NVARCHAR: "STRING",
+ exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.TEXT: "STRING",
+ exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.VARCHAR: "STRING",
- exp.DataType.Type.NVARCHAR: "STRING",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
@@ -271,6 +278,7 @@ class BigQuery(Dialect):
}
EXPLICIT_UNION = True
+ LIMIT_FETCH = "LIMIT"
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index b54a77d..89e2296 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -68,6 +68,8 @@ class ClickHouse(Dialect):
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
+ LOG_DEFAULTS_TO_LN = True
+
def _parse_in(
self, this: t.Optional[exp.Expression], is_global: bool = False
) -> exp.Expression:
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 4268f1b..2f93ee7 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -16,6 +16,8 @@ class Databricks(Spark):
"DATEDIFF": parse_date_delta(exp.DateDiff),
}
+ LOG_DEFAULTS_TO_LN = True
+
class Generator(Spark.Generator):
TRANSFORMS = {
**Spark.Generator.TRANSFORMS, # type: ignore
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index b267521..839589d 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -430,6 +430,11 @@ def min_or_least(self: Generator, expression: exp.Min) -> str:
return rename_func(name)(self, expression)
+def max_or_greatest(self: Generator, expression: exp.Max) -> str:
+ name = "GREATEST" if expression.expressions else "MAX"
+ return rename_func(name)(self, expression)
+
+
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
cond = expression.this
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index dc0e519..a33aadc 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-import re
import typing as t
from sqlglot import exp, generator, parser, tokens
@@ -102,6 +101,8 @@ class Drill(Dialect):
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
}
+ LOG_DEFAULTS_TO_LN = True
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -154,4 +155,4 @@ class Drill(Dialect):
}
def normalize_func(self, name: str) -> str:
- return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
+ return name if exp.SAFE_IDENTIFIER_RE.match(name) else f"`{name}`"
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f1d2266..c034208 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -80,6 +80,7 @@ class DuckDB(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "~": TokenType.RLIKE,
":=": TokenType.EQ,
"ATTACH": TokenType.COMMAND,
"BINARY": TokenType.VARBINARY,
@@ -212,5 +213,7 @@ class DuckDB(Dialect):
"except": "EXCLUDE",
}
+ LIMIT_FETCH = "LIMIT"
+
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE")
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 0110eee..68137ae 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
if_sql,
locate_to_strposition,
+ max_or_greatest,
min_or_least,
no_ilike_sql,
no_recursive_cte_sql,
@@ -34,6 +35,13 @@ DATE_DELTA_INTERVAL = {
"DAY": ("DATE_ADD", 1),
}
+TIME_DIFF_FACTOR = {
+ "MILLISECOND": " * 1000",
+ "SECOND": "",
+ "MINUTE": " / 60",
+ "HOUR": " / 3600",
+}
+
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
@@ -51,6 +59,14 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
+
+ factor = TIME_DIFF_FACTOR.get(unit)
+ if factor is not None:
+ left = self.sql(expression, "this")
+ right = self.sql(expression, "expression")
+ sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})"
+ return f"({sec_diff}){factor}" if factor else sec_diff
+
sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
@@ -237,11 +253,6 @@ class Hive(Dialect):
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": locate_to_strposition,
- "LOG": (
- lambda args: exp.Log.from_arg_list(args)
- if len(args) > 1
- else exp.Ln.from_arg_list(args)
- ),
"MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
@@ -261,6 +272,8 @@ class Hive(Dialect):
),
}
+ LOG_DEFAULTS_TO_LN = True
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -293,6 +306,7 @@ class Hive(Dialect):
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.Map: var_map_sql,
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,
@@ -338,6 +352,8 @@ class Hive(Dialect):
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
}
+ LIMIT_FETCH = "LIMIT"
+
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 1e2cfa3..5dfa811 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -3,7 +3,9 @@ from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
+ arrow_json_extract_scalar_sql,
locate_to_strposition,
+ max_or_greatest,
min_or_least,
no_ilike_sql,
no_paren_current_date_sql,
@@ -288,6 +290,8 @@ class MySQL(Dialect):
"SWAPS",
}
+ LOG_DEFAULTS_TO_LN = True
+
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
@@ -303,7 +307,13 @@ class MySQL(Dialect):
db = None
else:
position = None
- db = self._parse_id_var() if self._match_text_seq("FROM") else None
+ db = None
+
+ if self._match(TokenType.FROM):
+ db = self._parse_id_var()
+ elif self._match(TokenType.DOT):
+ db = target_id
+ target_id = self._parse_id_var()
channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None
@@ -384,6 +394,8 @@ class MySQL(Dialect):
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql,
+ exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
@@ -415,6 +427,8 @@ class MySQL(Dialect):
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
}
+ LIMIT_FETCH = "LIMIT"
+
def show_sql(self, expression):
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 7028a04..fad6c4a 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -4,7 +4,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
-from sqlglot.helper import csv, seq_get
+from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
@@ -13,10 +13,6 @@ PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
}
-def _limit_sql(self, expression):
- return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression))
-
-
def _parse_xml_table(self) -> exp.XMLTable:
this = self._parse_string()
@@ -89,6 +85,20 @@ class Oracle(Dialect):
column.set("join_mark", self._match(TokenType.JOIN_MARKER))
return column
+ def _parse_hint(self) -> t.Optional[exp.Expression]:
+ if self._match(TokenType.HINT):
+ start = self._curr
+ while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH):
+ self._advance()
+
+ if not self._curr:
+ self.raise_error("Expected */ after HINT")
+
+ end = self._tokens[self._index - 3]
+ return exp.Hint(expressions=[self._find_sql(start, end)])
+
+ return None
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
@@ -110,41 +120,20 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
+ exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
- exp.Limit: _limit_sql,
- exp.Trim: trim_sql,
exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
+ exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
- exp.Substring: rename_func("SUBSTR"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+ exp.Trim: trim_sql,
+ exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
}
- def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
- return csv(
- *sqls,
- *[self.sql(sql) for sql in expression.args.get("joins") or []],
- self.sql(expression, "match"),
- *[self.sql(sql) for sql in expression.args.get("laterals") or []],
- self.sql(expression, "where"),
- self.sql(expression, "group"),
- self.sql(expression, "having"),
- self.sql(expression, "qualify"),
- self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
- if expression.args.get("windows")
- else "",
- self.sql(expression, "distribute"),
- self.sql(expression, "sort"),
- self.sql(expression, "cluster"),
- self.sql(expression, "order"),
- self.sql(expression, "offset"), # offset before limit in oracle
- self.sql(expression, "limit"),
- self.sql(expression, "lock"),
- sep="",
- )
+ LIMIT_FETCH = "FETCH"
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 5f556a5..31b7e45 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
format_time_lambda,
+ max_or_greatest,
min_or_least,
no_paren_current_date_sql,
no_tablesample_sql,
@@ -315,6 +316,7 @@ class Postgres(Dialect):
exp.DateDiff: _date_diff_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 799e9a6..c50961c 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
+ max_or_greatest,
min_or_least,
rename_func,
timestamptrunc_sql,
@@ -275,6 +276,9 @@ class Snowflake(Dialect):
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
+ exp.DateDiff: lambda self, e: self.func(
+ "DATEDIFF", e.text("unit"), e.expression, e.this
+ ),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
@@ -296,6 +300,7 @@ class Snowflake(Dialect):
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
}
@@ -314,12 +319,6 @@ class Snowflake(Dialect):
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
}
- def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
- return self.binary(expression, "ILIKE ANY")
-
- def likeany_sql(self, expression: exp.LikeAny) -> str:
- return self.binary(expression, "LIKE ANY")
-
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index ab78b6e..4091dbb 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -82,6 +82,8 @@ class SQLite(Dialect):
exp.TryCast: no_trycast_sql,
}
+ LIMIT_FETCH = "LIMIT"
+
def cast_sql(self, expression: exp.Cast) -> str:
if expression.to.this == exp.DataType.Type.DATE:
return self.func("DATE", expression.this)
@@ -115,9 +117,6 @@ class SQLite(Dialect):
return f"CAST({sql} AS INTEGER)"
- def fetch_sql(self, expression: exp.Fetch) -> str:
- return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
-
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def groupconcat_sql(self, expression):
this = expression.this
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 8bd0a0c..3d43793 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
-from sqlglot.dialects.dialect import Dialect, min_or_least
+from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.tokens import TokenType
@@ -128,6 +128,7 @@ class Teradata(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 7b52047..8e9b6c3 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -6,6 +6,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
+ max_or_greatest,
min_or_least,
parse_date_delta,
rename_func,
@@ -269,7 +270,6 @@ class TSQL(Dialect):
# TSQL allows @, # to appear as a variable/identifier prefix
SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
- SINGLE_TOKENS.pop("@")
SINGLE_TOKENS.pop("#")
class Parser(parser.Parser):
@@ -313,6 +313,9 @@ class TSQL(Dialect):
TokenType.END: lambda self: self._parse_command(),
}
+ LOG_BASE_FIRST = False
+ LOG_DEFAULTS_TO_LN = True
+
def _parse_system_time(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("FOR", "SYSTEM_TIME"):
return None
@@ -435,11 +438,17 @@ class TSQL(Dialect):
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
}
TRANSFORMS.pop(exp.ReturnsProperty)
+ LIMIT_FETCH = "FETCH"
+
+ def offset_sql(self, expression: exp.Offset) -> str:
+ return f"{super().offset_sql(expression)} ROWS"
+
def systemtime_sql(self, expression: exp.SystemTime) -> str:
kind = expression.args["kind"]
if kind == "ALL":
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index dddb9ad..86665e0 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -12,7 +12,7 @@ from dataclasses import dataclass
from heapq import heappop, heappush
from sqlglot import Dialect, expressions as exp
-from sqlglot.helper import ensure_collection
+from sqlglot.helper import ensure_list
@dataclass(frozen=True)
@@ -151,8 +151,8 @@ class ChangeDistiller:
self._source = source
self._target = target
- self._source_index = {id(n[0]): n[0] for n in source.bfs()}
- self._target_index = {id(n[0]): n[0] for n in target.bfs()}
+ self._source_index = {id(n): n for n, *_ in self._source.bfs()}
+ self._target_index = {id(n): n for n, *_ in self._target.bfs()}
self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
@@ -199,10 +199,10 @@ class ChangeDistiller:
matching_set = leaves_matching_set.copy()
ordered_unmatched_source_nodes = {
- id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes
+ id(n): None for n, *_ in self._source.bfs() if id(n) in self._unmatched_source_nodes
}
ordered_unmatched_target_nodes = {
- id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes
+ id(n): None for n, *_ in self._target.bfs() if id(n) in self._unmatched_target_nodes
}
for source_node_id in ordered_unmatched_source_nodes:
@@ -304,18 +304,18 @@ class ChangeDistiller:
def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
has_child_exprs = False
- for a in expression.args.values():
- for node in ensure_collection(a):
- if isinstance(node, exp.Expression):
- has_child_exprs = True
- yield from _get_leaves(node)
+ for _, node in expression.iter_expressions():
+ has_child_exprs = True
+ yield from _get_leaves(node)
if not has_child_exprs:
yield expression
def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
- if type(source) is type(target):
+ if type(source) is type(target) and (
+ not isinstance(source, exp.Identifier) or type(source.parent) is type(target.parent)
+ ):
if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side")
@@ -331,7 +331,7 @@ def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
args: t.List[t.Union[exp.Expression, t.List]] = []
if expression:
for a in expression.args.values():
- args.extend(ensure_collection(a))
+ args.extend(ensure_list(a))
return [a for a in args if isinstance(a, exp.Expression)]
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index a676e7d..a67c155 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -57,7 +57,7 @@ def execute(
for name, table in tables_.mapping.items()
}
- schema = ensure_schema(schema)
+ schema = ensure_schema(schema, dialect=read)
if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
raise ExecuteError("Tables must support the same table args as schema")
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index d417328..b71cc6a 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -94,13 +94,10 @@ class PythonExecutor:
if source and isinstance(source, exp.Expression):
source = source.name or source.alias
- condition = self.generate(step.condition)
- projections = self.generate_tuple(step.projections)
-
if source is None:
context, table_iter = self.static()
elif source in context:
- if not projections and not condition:
+ if not step.projections and not step.condition:
return self.context({step.name: context.tables[source]})
table_iter = context.table_iter(source)
elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
@@ -109,10 +106,12 @@ class PythonExecutor:
else:
context, table_iter = self.scan_table(step)
- if projections:
- sink = self.table(step.projections)
- else:
- sink = self.table(context.columns)
+ return self.context({step.name: self._project_and_filter(context, step, table_iter)})
+
+ def _project_and_filter(self, context, step, table_iter):
+ sink = self.table(step.projections if step.projections else context.columns)
+ condition = self.generate(step.condition)
+ projections = self.generate_tuple(step.projections)
for reader in table_iter:
if len(sink) >= step.limit:
@@ -126,7 +125,7 @@ class PythonExecutor:
else:
sink.append(reader.row)
- return self.context({step.name: sink})
+ return sink
def static(self):
return self.context({}), [RowReader(())]
@@ -185,27 +184,16 @@ class PythonExecutor:
if condition:
source_context.filter(condition)
- condition = self.generate(step.condition)
- projections = self.generate_tuple(step.projections)
-
- if not condition and not projections:
+ if not step.condition and not step.projections:
return source_context
- sink = self.table(step.projections if projections else source_context.columns)
-
- for reader, ctx in source_context:
- if condition and not ctx.eval(condition):
- continue
-
- if projections:
- sink.append(ctx.eval_tuple(projections))
- else:
- sink.append(reader.row)
-
- if len(sink) >= step.limit:
- break
+ sink = self._project_and_filter(
+ source_context,
+ step,
+ (reader for reader, _ in iter(source_context)),
+ )
- if projections:
+ if step.projections:
return self.context({step.name: sink})
else:
return self.context(
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index b9da4cc..f4aae47 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -26,6 +26,7 @@ from sqlglot.helper import (
AutoName,
camel_to_snake_case,
ensure_collection,
+ ensure_list,
seq_get,
split_num_words,
subclasses,
@@ -84,7 +85,7 @@ class Expression(metaclass=_Expression):
key = "expression"
arg_types = {"this": True}
- __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta")
+ __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta", "_hash")
def __init__(self, **args: t.Any):
self.args: t.Dict[str, t.Any] = args
@@ -93,23 +94,31 @@ class Expression(metaclass=_Expression):
self.comments: t.Optional[t.List[str]] = None
self._type: t.Optional[DataType] = None
self._meta: t.Optional[t.Dict[str, t.Any]] = None
+ self._hash: t.Optional[int] = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
def __eq__(self, other) -> bool:
- return type(self) is type(other) and _norm_args(self) == _norm_args(other)
+ return type(self) is type(other) and hash(self) == hash(other)
- def __hash__(self) -> int:
- return hash(
- (
- self.key,
- tuple(
- (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
- ),
- )
+ @property
+ def hashable_args(self) -> t.Any:
+ args = (self.args.get(k) for k in self.arg_types)
+
+ return tuple(
+ (tuple(_norm_arg(a) for a in arg) if arg else None)
+ if type(arg) is list
+ else (_norm_arg(arg) if arg is not None and arg is not False else None)
+ for arg in args
)
+ def __hash__(self) -> int:
+ if self._hash is not None:
+ return self._hash
+
+ return hash((self.__class__, self.hashable_args))
+
@property
def this(self):
"""
@@ -247,9 +256,6 @@ class Expression(metaclass=_Expression):
"""
new = deepcopy(self)
new.parent = self.parent
- for item, parent, _ in new.bfs():
- if isinstance(item, Expression) and parent:
- item.parent = parent
return new
def append(self, arg_key, value):
@@ -277,12 +283,12 @@ class Expression(metaclass=_Expression):
self._set_parent(arg_key, value)
def _set_parent(self, arg_key, value):
- if isinstance(value, Expression):
+ if hasattr(value, "parent"):
value.parent = self
value.arg_key = arg_key
- elif isinstance(value, list):
+ elif type(value) is list:
for v in value:
- if isinstance(v, Expression):
+ if hasattr(v, "parent"):
v.parent = self
v.arg_key = arg_key
@@ -295,6 +301,17 @@ class Expression(metaclass=_Expression):
return self.parent.depth + 1
return 0
+ def iter_expressions(self) -> t.Iterator[t.Tuple[str, Expression]]:
+ """Yields the key and expression for all arguments, exploding list args."""
+ for k, vs in self.args.items():
+ if type(vs) is list:
+ for v in vs:
+ if hasattr(v, "parent"):
+ yield k, v
+ else:
+ if hasattr(vs, "parent"):
+ yield k, vs
+
def find(self, *expression_types: t.Type[E], bfs=True) -> E | None:
"""
Returns the first node in this tree which matches at least one of
@@ -319,7 +336,7 @@ class Expression(metaclass=_Expression):
Returns:
The generator object.
"""
- for expression, _, _ in self.walk(bfs=bfs):
+ for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
@@ -345,6 +362,11 @@ class Expression(metaclass=_Expression):
"""
return self.find_ancestor(Select)
+ @property
+ def same_parent(self):
+ """Returns if the parent is the same class as itself."""
+ return type(self.parent) is self.__class__
+
def root(self) -> Expression:
"""
Returns the root expression of this tree.
@@ -385,10 +407,8 @@ class Expression(metaclass=_Expression):
if prune and prune(self, parent, key):
return
- for k, v in self.args.items():
- for node in ensure_collection(v):
- if isinstance(node, Expression):
- yield from node.dfs(self, k, prune)
+ for k, v in self.iter_expressions():
+ yield from v.dfs(self, k, prune)
def bfs(self, prune=None):
"""
@@ -407,18 +427,15 @@ class Expression(metaclass=_Expression):
if prune and prune(item, parent, key):
continue
- if isinstance(item, Expression):
- for k, v in item.args.items():
- for node in ensure_collection(v):
- if isinstance(node, Expression):
- queue.append((node, item, k))
+ for k, v in item.iter_expressions():
+ queue.append((v, item, k))
def unnest(self):
"""
Returns the first non parenthesis child or self.
"""
expression = self
- while isinstance(expression, Paren):
+ while type(expression) is Paren:
expression = expression.this
return expression
@@ -434,7 +451,7 @@ class Expression(metaclass=_Expression):
"""
Returns unnested operands as a tuple.
"""
- return tuple(arg.unnest() for arg in self.args.values() if arg)
+ return tuple(arg.unnest() for _, arg in self.iter_expressions())
def flatten(self, unnest=True):
"""
@@ -442,8 +459,8 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
- for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
- if not isinstance(node, self.__class__):
+ for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__):
+ if not type(node) is self.__class__:
yield node.unnest() if unnest else node
def __str__(self):
@@ -477,7 +494,7 @@ class Expression(metaclass=_Expression):
v._to_s(hide_missing=hide_missing, level=level + 1)
if hasattr(v, "_to_s")
else str(v)
- for v in ensure_collection(vs)
+ for v in ensure_list(vs)
if v is not None
)
for k, vs in self.args.items()
@@ -812,6 +829,10 @@ class Describe(Expression):
arg_types = {"this": True, "kind": False}
+class Pragma(Expression):
+ pass
+
+
class Set(Expression):
arg_types = {"expressions": False}
@@ -1170,6 +1191,7 @@ class Drop(Expression):
"temporary": False,
"materialized": False,
"cascade": False,
+ "constraints": False,
}
@@ -1232,11 +1254,11 @@ class Identifier(Expression):
def quoted(self):
return bool(self.args.get("quoted"))
- def __eq__(self, other):
- return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
-
- def __hash__(self):
- return hash((self.key, self.this.lower()))
+ @property
+ def hashable_args(self) -> t.Any:
+ if self.quoted and any(char.isupper() for char in self.this):
+ return (self.this, self.quoted)
+ return self.this.lower()
@property
def output_name(self):
@@ -1322,15 +1344,9 @@ class Limit(Expression):
class Literal(Condition):
arg_types = {"this": True, "is_string": True}
- def __eq__(self, other):
- return (
- isinstance(other, Literal)
- and self.this == other.this
- and self.args["is_string"] == other.args["is_string"]
- )
-
- def __hash__(self):
- return hash((self.key, self.this, self.args["is_string"]))
+ @property
+ def hashable_args(self) -> t.Any:
+ return (self.this, self.args.get("is_string"))
@classmethod
def number(cls, number) -> Literal:
@@ -1784,7 +1800,7 @@ class Subqueryable(Unionable):
instance = _maybe_copy(self, copy)
return Subquery(
this=instance,
- alias=TableAlias(this=to_identifier(alias)),
+ alias=TableAlias(this=to_identifier(alias)) if alias else None,
)
def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
@@ -2058,6 +2074,7 @@ class Lock(Expression):
class Select(Subqueryable):
arg_types = {
"with": False,
+ "kind": False,
"expressions": False,
"hint": False,
"distinct": False,
@@ -3595,6 +3612,21 @@ class Initcap(Func):
pass
+class JSONKeyValue(Expression):
+ arg_types = {"this": True, "expression": True}
+
+
+class JSONObject(Func):
+ arg_types = {
+ "expressions": False,
+ "null_handling": False,
+ "unique_keys": False,
+ "return_type": False,
+ "format_json": False,
+ "encoding": False,
+ }
+
+
class JSONBContains(Binary):
_sql_names = ["JSONB_CONTAINS"]
@@ -3766,8 +3798,10 @@ class RegexpILike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
+# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split.html
+# limit is the number of times a pattern is applied
class RegexpSplit(Func):
- arg_types = {"this": True, "expression": True}
+ arg_types = {"this": True, "expression": True, "limit": False}
class Repeat(Func):
@@ -3967,25 +4001,8 @@ class When(Func):
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
-def _norm_args(expression):
- args = {}
-
- for k, arg in expression.args.items():
- if isinstance(arg, list):
- arg = [_norm_arg(a) for a in arg]
- if not arg:
- arg = None
- else:
- arg = _norm_arg(arg)
-
- if arg is not None and arg is not False:
- args[k] = arg
-
- return args
-
-
def _norm_arg(arg):
- return arg.lower() if isinstance(arg, str) else arg
+ return arg.lower() if type(arg) is str else arg
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
@@ -4512,7 +4529,7 @@ def to_identifier(name, quoted=None):
elif isinstance(name, str):
identifier = Identifier(
this=name,
- quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted,
+ quoted=not SAFE_IDENTIFIER_RE.match(name) if quoted is None else quoted,
)
else:
raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}")
@@ -4586,8 +4603,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
return sql_path
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
- table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
- return Column(this=column_name, table=table_name, **kwargs)
+ return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore
def alias_(
@@ -4672,7 +4688,8 @@ def subquery(expression, alias=None, dialect=None, **opts):
def column(
col: str | Identifier,
table: t.Optional[str | Identifier] = None,
- schema: t.Optional[str | Identifier] = None,
+ db: t.Optional[str | Identifier] = None,
+ catalog: t.Optional[str | Identifier] = None,
quoted: t.Optional[bool] = None,
) -> Column:
"""
@@ -4681,7 +4698,8 @@ def column(
Args:
col: column name
table: table name
- schema: schema name
+ db: db name
+ catalog: catalog name
quoted: whether or not to force quote each part
Returns:
Column: column instance
@@ -4689,7 +4707,8 @@ def column(
return Column(
this=to_identifier(col, quoted=quoted),
table=to_identifier(table, quoted=quoted),
- schema=to_identifier(schema, quoted=quoted),
+ db=to_identifier(db, quoted=quoted),
+ catalog=to_identifier(catalog, quoted=quoted),
)
@@ -4864,7 +4883,7 @@ def replace_children(expression, fun, *args, **kwargs):
Replace children of an expression with the result of a lambda fun(child) -> exp.
"""
for k, v in expression.args.items():
- is_list_arg = isinstance(v, list)
+ is_list_arg = type(v) is list
child_nodes = v if is_list_arg else [v]
new_child_nodes = []
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index a6f4772..6871dd8 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -110,6 +110,10 @@ class Generator:
# Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
MATCHED_BY_SOURCE = True
+ # Whether or not limit and fetch are supported
+ # "ALL", "LIMIT", "FETCH"
+ LIMIT_FETCH = "ALL"
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -209,6 +213,7 @@ class Generator:
"_leading_comma",
"_max_text_width",
"_comments",
+ "_cache",
)
def __init__(
@@ -265,19 +270,28 @@ class Generator:
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._comments = comments
+ self._cache = None
- def generate(self, expression: t.Optional[exp.Expression]) -> str:
+ def generate(
+ self,
+ expression: t.Optional[exp.Expression],
+ cache: t.Optional[t.Dict[int, str]] = None,
+ ) -> str:
"""
Generates a SQL string by interpreting the given syntax tree.
Args
expression: the syntax tree.
+ cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node.
Returns
the SQL string.
"""
+ if cache is not None:
+ self._cache = cache
self.unsupported_messages = []
sql = self.sql(expression).strip()
+ self._cache = None
if self.unsupported_level == ErrorLevel.IGNORE:
return sql
@@ -387,6 +401,12 @@ class Generator:
if key:
return self.sql(expression.args.get(key))
+ if self._cache is not None:
+ expression_id = hash(expression)
+
+ if expression_id in self._cache:
+ return self._cache[expression_id]
+
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
@@ -407,7 +427,11 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
- return self.maybe_comment(sql, expression) if self._comments and comment else sql
+ sql = self.maybe_comment(sql, expression) if self._comments and comment else sql
+
+ if self._cache is not None:
+ self._cache[expression_id] = sql
+ return sql
def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
@@ -697,7 +721,8 @@ class Generator:
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
cascade = " CASCADE" if expression.args.get("cascade") else ""
- return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
+ constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
+ return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}"
def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
@@ -733,9 +758,9 @@ class Generator:
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
- text = text.lower() if self.normalize else text
+ text = text.lower() if self.normalize and not expression.quoted else text
text = text.replace(self.identifier_end, self._escaped_identifier_end)
- if expression.args.get("quoted") or should_identify(text, self.identify):
+ if expression.quoted or should_identify(text, self.identify):
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
@@ -1191,6 +1216,9 @@ class Generator:
)
return f"SET{expressions}"
+ def pragma_sql(self, expression: exp.Pragma) -> str:
+ return f"PRAGMA {self.sql(expression, 'this')}"
+
def lock_sql(self, expression: exp.Lock) -> str:
if self.LOCKING_READS_SUPPORTED:
lock_type = "UPDATE" if expression.args["update"] else "SHARE"
@@ -1299,6 +1327,15 @@ class Generator:
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
+ limit = expression.args.get("limit")
+
+ if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
+ limit = exp.Limit(expression=limit.args.get("count"))
+ elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
+ limit = exp.Fetch(direction="FIRST", count=limit.expression)
+
+ fetch = isinstance(limit, exp.Fetch)
+
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins") or []],
@@ -1315,14 +1352,16 @@ class Generator:
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
self.sql(expression, "order"),
- self.sql(expression, "limit"),
- self.sql(expression, "offset"),
+ self.sql(expression, "offset") if fetch else self.sql(limit),
+ self.sql(limit) if fetch else self.sql(expression, "offset"),
self.sql(expression, "lock"),
self.sql(expression, "sample"),
sep="",
)
def select_sql(self, expression: exp.Select) -> str:
+ kind = expression.args.get("kind")
+ kind = f" AS {kind}" if kind else ""
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
@@ -1330,7 +1369,7 @@ class Generator:
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
- f"SELECT{hint}{distinct}{expressions}",
+ f"SELECT{kind}{hint}{distinct}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
@@ -1552,6 +1591,25 @@ class Generator:
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
+ def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
+ return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
+
+ def jsonobject_sql(self, expression: exp.JSONObject) -> str:
+ expressions = self.expressions(expression)
+ null_handling = expression.args.get("null_handling")
+ null_handling = f" {null_handling}" if null_handling else ""
+ unique_keys = expression.args.get("unique_keys")
+ if unique_keys is not None:
+ unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS"
+ else:
+ unique_keys = ""
+ return_type = self.sql(expression, "return_type")
+ return_type = f" RETURNING {return_type}" if return_type else ""
+ format_json = " FORMAT JSON" if expression.args.get("format_json") else ""
+ encoding = self.sql(expression, "encoding")
+ encoding = f" ENCODING {encoding}" if encoding else ""
+ return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
+
def in_sql(self, expression: exp.In) -> str:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
@@ -1808,12 +1866,18 @@ class Generator:
def ilike_sql(self, expression: exp.ILike) -> str:
return self.binary(expression, "ILIKE")
+ def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
+ return self.binary(expression, "ILIKE ANY")
+
def is_sql(self, expression: exp.Is) -> str:
return self.binary(expression, "IS")
def like_sql(self, expression: exp.Like) -> str:
return self.binary(expression, "LIKE")
+ def likeany_sql(self, expression: exp.LikeAny) -> str:
+ return self.binary(expression, "LIKE ANY")
+
def similarto_sql(self, expression: exp.SimilarTo) -> str:
return self.binary(expression, "SIMILAR TO")
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 6eff974..d44d7dd 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -59,7 +59,7 @@ def ensure_list(value):
"""
if value is None:
return []
- elif isinstance(value, (list, tuple)):
+ if isinstance(value, (list, tuple)):
return list(value)
return [value]
@@ -162,9 +162,7 @@ def camel_to_snake_case(name: str) -> str:
return CAMEL_CASE_PATTERN.sub("_", name).upper()
-def while_changing(
- expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E]
-) -> E:
+def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
"""
Applies a transformation to a given expression until a fix point is reached.
@@ -176,8 +174,13 @@ def while_changing(
The transformed expression.
"""
while True:
+ for n, *_ in reversed(tuple(expression.walk())):
+ n._hash = hash(n)
start = hash(expression)
expression = func(expression)
+
+ for n, *_ in expression.walk():
+ n._hash = None
if start == hash(expression):
break
return expression
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index c2d6655..99888c6 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,5 +1,5 @@
from sqlglot import exp
-from sqlglot.helper import ensure_collection, ensure_list, subclasses
+from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@@ -108,6 +108,7 @@ class TypeAnnotator:
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
+ exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
@@ -116,6 +117,7 @@ class TypeAnnotator:
expr, exp.DataType.Type.VARCHAR
),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@@ -296,9 +298,6 @@ class TypeAnnotator:
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
- if not isinstance(expression, exp.Expression):
- return None
-
if expression.type:
return expression # We've already inferred the expression's type
@@ -311,9 +310,8 @@ class TypeAnnotator:
)
def _annotate_args(self, expression):
- for value in expression.args.values():
- for v in ensure_collection(value):
- self._maybe_annotate(v)
+ for _, value in expression.iter_expressions():
+ self._maybe_annotate(value)
return expression
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index c5c780d..ef929ac 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -75,7 +75,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
a.type
and a.type.this == exp.DataType.Type.DATE
and b.type
- and b.type.this != exp.DataType.Type.DATE
+ and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
):
_replace_cast(b, "date")
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 8e6a520..e0ddfa2 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -1,7 +1,6 @@
from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.optimizer.simplify import simplify
def eliminate_joins(expression):
@@ -179,6 +178,4 @@ def join_condition(join):
for condition in conditions:
extract_condition(condition)
- on = simplify(on)
- remaining_condition = None if on == exp.true() else on
- return source_key, join_key, remaining_condition
+ return source_key, join_key, on
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 6f9db82..a39fe96 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -3,7 +3,6 @@ import itertools
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
-from sqlglot.optimizer.simplify import simplify
def eliminate_subqueries(expression):
@@ -31,7 +30,6 @@ def eliminate_subqueries(expression):
eliminate_subqueries(expression.this)
return expression
- expression = simplify(expression)
root = build_scope(expression)
# Map of alias->Scope|Table
diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py
index 1cc76cf..fae1726 100644
--- a/sqlglot/optimizer/lower_identities.py
+++ b/sqlglot/optimizer/lower_identities.py
@@ -1,5 +1,4 @@
from sqlglot import exp
-from sqlglot.helper import ensure_collection
def lower_identities(expression):
@@ -40,13 +39,10 @@ def lower_identities(expression):
lower_identities(expression.right)
traversed |= {"this", "expression"}
- for k, v in expression.args.items():
+ for k, v in expression.iter_expressions():
if k in traversed:
continue
-
- for child in ensure_collection(v):
- if isinstance(child, exp.Expression):
- child.transform(_lower, copy=False)
+ v.transform(_lower, copy=False)
return expression
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 70172f4..c3467b2 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -3,7 +3,6 @@ from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.optimizer.simplify import simplify
def merge_subqueries(expression, leave_tables_isolated=False):
@@ -330,11 +329,11 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if set(exp.column_table_names(where.this)) <= sources:
from_or_join.on(where.this, copy=False)
- from_or_join.set("on", simplify(from_or_join.args.get("on")))
+ from_or_join.set("on", from_or_join.args.get("on"))
return
expression.where(where.this, copy=False)
- expression.set("where", simplify(expression.args.get("where")))
+ expression.set("where", expression.args.get("where"))
def _merge_order(outer_scope, inner_scope):
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index f16f519..f2df230 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -1,29 +1,63 @@
+from __future__ import annotations
+
+import logging
+import typing as t
+
from sqlglot import exp
+from sqlglot.errors import OptimizeError
from sqlglot.helper import while_changing
-from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
+from sqlglot.optimizer.simplify import flatten, uniq_sort
+
+logger = logging.getLogger("sqlglot")
-def normalize(expression, dnf=False, max_distance=128):
+def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
"""
- Rewrite sqlglot AST into conjunctive normal form.
+ Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
- >>> normalize(expression).sql()
+ >>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Args:
- expression (sqlglot.Expression): expression to normalize
- dnf (bool): rewrite in disjunctive normal form instead
- max_distance (int): the maximal estimated distance from cnf to attempt conversion
+ expression: expression to normalize
+ dnf: rewrite in disjunctive normal form instead.
+ max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:
sqlglot.Expression: normalized expression
"""
- expression = simplify(expression)
+ cache: t.Dict[int, str] = {}
+
+ for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
+ if isinstance(node, exp.Connector):
+ if normalized(node, dnf=dnf):
+ continue
+
+ distance = normalization_distance(node, dnf=dnf)
+
+ if distance > max_distance:
+ logger.info(
+ f"Skipping normalization because distance {distance} exceeds max {max_distance}"
+ )
+ return expression
+
+ root = node is expression
+ original = node.copy()
+ try:
+ node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ except OptimizeError as e:
+ logger.info(e)
+ node.replace(original)
+ if root:
+ return original
+ return expression
+
+ if root:
+ expression = node
- expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
- return simplify(expression)
+ return expression
def normalized(expression, dnf=False):
@@ -51,7 +85,7 @@ def normalization_distance(expression, dnf=False):
int: difference
"""
return sum(_predicate_lengths(expression, dnf)) - (
- len(list(expression.find_all(exp.Connector))) + 1
+ sum(1 for _ in expression.find_all(exp.Connector)) + 1
)
@@ -64,29 +98,32 @@ def _predicate_lengths(expression, dnf):
expression = expression.unnest()
if not isinstance(expression, exp.Connector):
- return [1]
+ return (1,)
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
- return [
+ return tuple(
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
- ]
+ )
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
-def distributive_law(expression, dnf, max_distance):
+def distributive_law(expression, dnf, max_distance, cache=None):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
"""
- if isinstance(expression.unnest(), exp.Connector):
- if normalization_distance(expression, dnf) > max_distance:
- return expression
+ if normalized(expression, dnf=dnf):
+ return expression
- to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
+ distance = normalization_distance(expression, dnf=dnf)
- exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
+ if distance > max_distance:
+ raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
+
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
+ to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
a, b = expression.unnest_operands()
@@ -96,32 +133,29 @@ def distributive_law(expression, dnf, max_distance):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
- return _distribute(a, b, from_func, to_func)
- return _distribute(b, a, from_func, to_func)
+ return _distribute(a, b, from_func, to_func, cache)
+ return _distribute(b, a, from_func, to_func, cache)
if isinstance(a, to_exp):
- return _distribute(b, a, from_func, to_func)
+ return _distribute(b, a, from_func, to_func, cache)
if isinstance(b, to_exp):
- return _distribute(a, b, from_func, to_func)
+ return _distribute(a, b, from_func, to_func, cache)
return expression
-def _distribute(a, b, from_func, to_func):
+def _distribute(a, b, from_func, to_func, cache):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
- exp.paren(from_func(c, b.left)),
- exp.paren(from_func(c, b.right)),
+ uniq_sort(flatten(from_func(c, b.left)), cache),
+ uniq_sort(flatten(from_func(c, b.right)), cache),
),
)
else:
- a = to_func(from_func(a, b.left), from_func(a, b.right))
-
- return _simplify(a)
-
+ a = to_func(
+ uniq_sort(flatten(from_func(a, b.left)), cache),
+ uniq_sort(flatten(from_func(a, b.right)), cache),
+ )
-def _simplify(node):
- node = uniq_sort(flatten(node))
- exp.replace_children(node, _simplify)
- return node
+ return a
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index dc5ce44..8589657 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -1,6 +1,5 @@
from sqlglot import exp
from sqlglot.helper import tsort
-from sqlglot.optimizer.simplify import simplify
def optimize_joins(expression):
@@ -29,7 +28,6 @@ def optimize_joins(expression):
for name, join in cross_joins:
for dep in references.get(name, []):
on = dep.args["on"]
- on = on.replace(simplify(on))
if isinstance(on, exp.Connector):
for predicate in on.flatten():
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index d9d04be..62eb11e 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -21,6 +21,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
@@ -43,6 +44,7 @@ RULES = (
eliminate_ctes,
annotate_types,
canonicalize,
+ simplify,
)
@@ -78,7 +80,7 @@ def optimize(
Returns:
sqlglot.Expression: optimized expression
"""
- schema = ensure_schema(schema or sqlglot.schema)
+ schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 66b3170..5e40cf3 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -30,11 +30,12 @@ def qualify_columns(expression, schema):
resolver = Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
- _expand_using(scope, resolver)
+ using_column_tables = _expand_using(scope, resolver)
_qualify_columns(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
- _expand_stars(scope, resolver)
+ _expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
+ _expand_alias_refs(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
@@ -69,11 +70,11 @@ def _pop_table_column_aliases(derived_tables):
def _expand_using(scope, resolver):
- joins = list(scope.expression.find_all(exp.Join))
+ joins = list(scope.find_all(exp.Join))
names = {join.this.alias for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
- # Mapping of automatically joined column names to source names
+ # Mapping of automatically joined column names to an ordered set of source names (dict).
column_tables = {}
for join in joins:
@@ -112,11 +113,12 @@ def _expand_using(scope, resolver):
)
)
- tables = column_tables.setdefault(identifier, [])
+ # Set all values in the dict to None, because we only care about the key ordering
+ tables = column_tables.setdefault(identifier, {})
if table not in tables:
- tables.append(table)
+ tables[table] = None
if join_table not in tables:
- tables.append(join_table)
+ tables[join_table] = None
join.args.pop("using")
join.set("on", exp.and_(*conditions))
@@ -134,11 +136,11 @@ def _expand_using(scope, resolver):
scope.replace(column, replacement)
+ return column_tables
-def _expand_group_by(scope, resolver):
- group = scope.expression.args.get("group")
- if not group:
- return
+
+def _expand_alias_refs(scope, resolver):
+ selects = {}
# Replace references to select aliases
def transform(node, *_):
@@ -150,9 +152,11 @@ def _expand_group_by(scope, resolver):
node.set("table", table)
return node
- selects = {s.alias_or_name: s for s in scope.selects}
-
+ if not selects:
+ for s in scope.selects:
+ selects[s.alias_or_name] = s
select = selects.get(node.name)
+
if select:
scope.clear_cache()
if isinstance(select, exp.Alias):
@@ -161,7 +165,21 @@ def _expand_group_by(scope, resolver):
return node
- group.transform(transform, copy=False)
+ for select in scope.expression.selects:
+ select.transform(transform, copy=False)
+
+ for modifier in ("where", "group"):
+ part = scope.expression.args.get(modifier)
+
+ if part:
+ part.transform(transform, copy=False)
+
+
+def _expand_group_by(scope, resolver):
+ group = scope.expression.args.get("group")
+ if not group:
+ return
+
group.set("expressions", _expand_positional_references(scope, group.expressions))
scope.expression.set("group", group)
@@ -231,18 +249,24 @@ def _qualify_columns(scope, resolver):
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
columns_missing_from_scope = []
+
# Determine whether each reference in the order by clause is to a column or an alias.
- for ordered in scope.find_all(exp.Ordered):
- for column in ordered.find_all(exp.Column):
- if (
- not column.table
- and column.parent is not ordered
- and column.name in resolver.all_columns
- ):
- columns_missing_from_scope.append(column)
+ order = scope.expression.args.get("order")
+
+ if order:
+ for ordered in order.expressions:
+ for column in ordered.find_all(exp.Column):
+ if (
+ not column.table
+ and column.parent is not ordered
+ and column.name in resolver.all_columns
+ ):
+ columns_missing_from_scope.append(column)
# Determine whether each reference in the having clause is to a column or an alias.
- for having in scope.find_all(exp.Having):
+ having = scope.expression.args.get("having")
+
+ if having:
for column in having.find_all(exp.Column):
if (
not column.table
@@ -258,12 +282,13 @@ def _qualify_columns(scope, resolver):
column.set("table", column_table)
-def _expand_stars(scope, resolver):
+def _expand_stars(scope, resolver, using_column_tables):
"""Expand stars to lists of column selections"""
new_selections = []
except_columns = {}
replace_columns = {}
+ coalesced_columns = set()
for expression in scope.selects:
if isinstance(expression, exp.Star):
@@ -286,7 +311,20 @@ def _expand_stars(scope, resolver):
if columns and "*" not in columns:
table_id = id(table)
for name in columns:
- if name not in except_columns.get(table_id, set()):
+ if name in using_column_tables and table in using_column_tables[name]:
+ if name in coalesced_columns:
+ continue
+
+ coalesced_columns.add(name)
+ tables = using_column_tables[name]
+ coalesce = [exp.column(name, table=table) for table in tables]
+
+ new_selections.append(
+ exp.alias_(
+ exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
+ )
+ )
+ elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
new_selections.append(alias(column, alias_) if alias_ != name else column)
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 9c0768c..b582eb0 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -160,7 +160,7 @@ class Scope:
Yields:
exp.Expression: nodes
"""
- for expression, _, _ in self.walk(bfs=bfs):
+ for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index f80484d..1ed3ca2 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -5,11 +5,10 @@ from collections import deque
from decimal import Decimal
from sqlglot import exp
-from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
from sqlglot.helper import first, while_changing
-GENERATOR = Generator(normalize=True, identify=True)
+GENERATOR = Generator(normalize=True, identify="safe")
def simplify(expression):
@@ -28,18 +27,20 @@ def simplify(expression):
sqlglot.Expression: simplified expression
"""
+ cache = {}
+
def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
- node = uniq_sort(node)
- node = absorb_and_eliminate(node)
+ node = uniq_sort(node, cache, root)
+ node = absorb_and_eliminate(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
node = flatten(node)
- node = simplify_connectors(node)
- node = remove_compliments(node)
+ node = simplify_connectors(node, root)
+ node = remove_compliments(node, root)
node.parent = expression.parent
- node = simplify_literals(node)
+ node = simplify_literals(node, root)
node = simplify_parens(node)
if root:
expression.replace(node)
@@ -70,7 +71,7 @@ def simplify_not(expression):
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
- if isinstance(expression.this, exp.Null):
+ if is_null(expression.this):
return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
@@ -78,11 +79,11 @@ def simplify_not(expression):
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
- if isinstance(condition, exp.Null):
+ if is_null(condition):
return exp.null()
if always_true(expression.this):
return exp.false()
- if expression.this == FALSE:
+ if is_false(expression.this):
return exp.true()
if isinstance(expression.this, exp.Not):
# double negation
@@ -104,42 +105,42 @@ def flatten(expression):
return expression
-def simplify_connectors(expression):
+def simplify_connectors(expression, root=True):
def _simplify_connectors(expression, left, right):
- if isinstance(expression, exp.Connector):
- if left == right:
+ if left == right:
+ return left
+ if isinstance(expression, exp.And):
+ if is_false(left) or is_false(right):
+ return exp.false()
+ if is_null(left) or is_null(right):
+ return exp.null()
+ if always_true(left) and always_true(right):
+ return exp.true()
+ if always_true(left):
+ return right
+ if always_true(right):
return left
- if isinstance(expression, exp.And):
- if FALSE in (left, right):
- return exp.false()
- if NULL in (left, right):
- return exp.null()
- if always_true(left) and always_true(right):
- return exp.true()
- if always_true(left):
- return right
- if always_true(right):
- return left
- return _simplify_comparison(expression, left, right)
- elif isinstance(expression, exp.Or):
- if always_true(left) or always_true(right):
- return exp.true()
- if left == FALSE and right == FALSE:
- return exp.false()
- if (
- (left == NULL and right == NULL)
- or (left == NULL and right == FALSE)
- or (left == FALSE and right == NULL)
- ):
- return exp.null()
- if left == FALSE:
- return right
- if right == FALSE:
- return left
- return _simplify_comparison(expression, left, right, or_=True)
- return None
+ return _simplify_comparison(expression, left, right)
+ elif isinstance(expression, exp.Or):
+ if always_true(left) or always_true(right):
+ return exp.true()
+ if is_false(left) and is_false(right):
+ return exp.false()
+ if (
+ (is_null(left) and is_null(right))
+ or (is_null(left) and is_false(right))
+ or (is_false(left) and is_null(right))
+ ):
+ return exp.null()
+ if is_false(left):
+ return right
+ if is_false(right):
+ return left
+ return _simplify_comparison(expression, left, right, or_=True)
- return _flat_simplify(expression, _simplify_connectors)
+ if isinstance(expression, exp.Connector):
+ return _flat_simplify(expression, _simplify_connectors, root)
+ return expression
LT_LTE = (exp.LT, exp.LTE)
@@ -220,14 +221,14 @@ def _simplify_comparison(expression, left, right, or_=False):
return None
-def remove_compliments(expression):
+def remove_compliments(expression, root=True):
"""
Removing compliments.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
@@ -236,23 +237,23 @@ def remove_compliments(expression):
return expression
-def uniq_sort(expression):
+def uniq_sort(expression, cache=None, root=True):
"""
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
- deduped = {GENERATOR.generate(e): e for e in flattened}
+ deduped = {GENERATOR.generate(e, cache): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
- expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
+ expression = result_func(*(e for _, e in sorted(arr)))
break
else:
# we didn't have to sort but maybe we need to dedup
@@ -262,7 +263,7 @@ def uniq_sort(expression):
return expression
-def absorb_and_eliminate(expression):
+def absorb_and_eliminate(expression, root=True):
"""
absorption:
A AND (A OR B) -> A
@@ -273,7 +274,7 @@ def absorb_and_eliminate(expression):
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
kind = exp.Or if isinstance(expression, exp.And) else exp.And
for a, b in itertools.permutations(expression.flatten(), 2):
@@ -302,9 +303,9 @@ def absorb_and_eliminate(expression):
return expression
-def simplify_literals(expression):
- if isinstance(expression, exp.Binary):
- return _flat_simplify(expression, _simplify_binary)
+def simplify_literals(expression, root=True):
+ if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
+ return _flat_simplify(expression, _simplify_binary, root)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@@ -325,14 +326,14 @@ def _simplify_binary(expression, a, b):
c = b
not_ = False
- if c == NULL:
+ if is_null(c):
if isinstance(a, exp.Literal):
return exp.true() if not_ else exp.false()
- if a == NULL:
+ if is_null(a):
return exp.false() if not_ else exp.true()
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
return None
- elif NULL in (a, b):
+ elif is_null(a) or is_null(b):
return exp.null()
if a.is_number and b.is_number:
@@ -355,7 +356,7 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
elif a.is_string and b.is_string:
- boolean = eval_boolean(expression, a, b)
+ boolean = eval_boolean(expression, a.this, b.this)
if boolean:
return boolean
@@ -381,7 +382,7 @@ def simplify_parens(expression):
and not isinstance(expression.this, exp.Select)
and (
not isinstance(expression.parent, (exp.Condition, exp.Binary))
- or isinstance(expression.this, (exp.Is, exp.Like))
+ or isinstance(expression.this, exp.Predicate)
or not isinstance(expression.this, exp.Binary)
)
):
@@ -400,13 +401,23 @@ def remove_where_true(expression):
def always_true(expression):
- return expression == TRUE or isinstance(expression, exp.Literal)
+ return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
+ expression, exp.Literal
+ )
def is_complement(a, b):
return isinstance(b, exp.Not) and b.this == a
+def is_false(a: exp.Expression) -> bool:
+ return type(a) is exp.Boolean and not a.this
+
+
+def is_null(a: exp.Expression) -> bool:
+ return type(a) is exp.Null
+
+
def eval_boolean(expression, a, b):
if isinstance(expression, (exp.EQ, exp.Is)):
return boolean_literal(a == b)
@@ -466,24 +477,27 @@ def boolean_literal(condition):
return exp.true() if condition else exp.false()
-def _flat_simplify(expression, simplifier):
- operands = []
- queue = deque(expression.flatten(unnest=False))
- size = len(queue)
+def _flat_simplify(expression, simplifier, root=True):
+ if root or not expression.same_parent:
+ operands = []
+ queue = deque(expression.flatten(unnest=False))
+ size = len(queue)
- while queue:
- a = queue.popleft()
+ while queue:
+ a = queue.popleft()
- for b in queue:
- result = simplifier(expression, a, b)
+ for b in queue:
+ result = simplifier(expression, a, b)
- if result:
- queue.remove(b)
- queue.append(result)
- break
- else:
- operands.append(a)
+ if result:
+ queue.remove(b)
+ queue.append(result)
+ break
+ else:
+ operands.append(a)
- if len(operands) < size:
- return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+ if len(operands) < size:
+ return functools.reduce(
+ lambda a, b: expression.__class__(this=a, expression=b), operands
+ )
return expression
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index a36251e..8269525 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -19,7 +19,7 @@ from sqlglot.trie import in_trie, new_trie
logger = logging.getLogger("sqlglot")
-def parse_var_map(args):
+def parse_var_map(args: t.Sequence) -> exp.Expression:
keys = []
values = []
for i in range(0, len(args), 2):
@@ -31,6 +31,11 @@ def parse_var_map(args):
)
+def parse_like(args):
+ like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0))
+ return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like
+
+
def binary_range_parser(
expr_type: t.Type[exp.Expression],
) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]:
@@ -77,6 +82,9 @@ class Parser(metaclass=_Parser):
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
+ "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
+ "IFNULL": exp.Coalesce.from_arg_list,
+ "LIKE": parse_like,
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
@@ -90,7 +98,6 @@ class Parser(metaclass=_Parser):
length=exp.Literal.number(10),
),
"VAR_MAP": parse_var_map,
- "IFNULL": exp.Coalesce.from_arg_list,
}
NO_PAREN_FUNCTIONS = {
@@ -211,6 +218,7 @@ class Parser(metaclass=_Parser):
TokenType.FILTER,
TokenType.FOLLOWING,
TokenType.FORMAT,
+ TokenType.FULL,
TokenType.IF,
TokenType.ISNULL,
TokenType.INTERVAL,
@@ -226,8 +234,10 @@ class Parser(metaclass=_Parser):
TokenType.ONLY,
TokenType.OPTIONS,
TokenType.ORDINALITY,
+ TokenType.PARTITION,
TokenType.PERCENT,
TokenType.PIVOT,
+ TokenType.PRAGMA,
TokenType.PRECEDING,
TokenType.RANGE,
TokenType.REFERENCES,
@@ -257,6 +267,7 @@ class Parser(metaclass=_Parser):
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
+ TokenType.FULL,
TokenType.LEFT,
TokenType.NATURAL,
TokenType.OFFSET,
@@ -277,6 +288,7 @@ class Parser(metaclass=_Parser):
TokenType.FILTER,
TokenType.FIRST,
TokenType.FORMAT,
+ TokenType.GLOB,
TokenType.IDENTIFIER,
TokenType.INDEX,
TokenType.ISNULL,
@@ -461,6 +473,7 @@ class Parser(metaclass=_Parser):
TokenType.INSERT: lambda self: self._parse_insert(),
TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
TokenType.MERGE: lambda self: self._parse_merge(),
+ TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.SET: lambda self: self._parse_set(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
@@ -662,6 +675,8 @@ class Parser(metaclass=_Parser):
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"EXTRACT": lambda self: self._parse_extract(),
+ "JSON_OBJECT": lambda self: self._parse_json_object(),
+ "LOG": lambda self: self._parse_logarithm(),
"POSITION": lambda self: self._parse_position(),
"STRING_AGG": lambda self: self._parse_string_agg(),
"SUBSTRING": lambda self: self._parse_substring(),
@@ -719,6 +734,9 @@ class Parser(metaclass=_Parser):
CONVERT_TYPE_FIRST = False
+ LOG_BASE_FIRST = True
+ LOG_DEFAULTS_TO_LN = False
+
__slots__ = (
"error_level",
"error_message_context",
@@ -1032,6 +1050,7 @@ class Parser(metaclass=_Parser):
temporary=temporary,
materialized=materialized,
cascade=self._match(TokenType.CASCADE),
+ constraints=self._match_text_seq("CONSTRAINTS"),
)
def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
@@ -1221,7 +1240,7 @@ class Parser(metaclass=_Parser):
if not identified_property:
break
- for p in ensure_collection(identified_property):
+ for p in ensure_list(identified_property):
properties.append(p)
if properties:
@@ -1704,6 +1723,11 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.SELECT):
comments = self._prev_comments
+ kind = (
+ self._match(TokenType.ALIAS)
+ and self._match_texts(("STRUCT", "VALUE"))
+ and self._prev.text
+ )
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
distinct = self._match(TokenType.DISTINCT)
@@ -1722,6 +1746,7 @@ class Parser(metaclass=_Parser):
this = self.expression(
exp.Select,
+ kind=kind,
hint=hint,
distinct=distinct,
expressions=expressions,
@@ -2785,7 +2810,6 @@ class Parser(metaclass=_Parser):
this = seq_get(expressions, 0)
self._parse_query_modifiers(this)
- self._match_r_paren()
if isinstance(this, exp.Subqueryable):
this = self._parse_set_operations(
@@ -2794,7 +2818,9 @@ class Parser(metaclass=_Parser):
elif len(expressions) > 1:
this = self.expression(exp.Tuple, expressions=expressions)
else:
- this = self.expression(exp.Paren, this=this)
+ this = self.expression(exp.Paren, this=self._parse_set_operations(this))
+
+ self._match_r_paren()
if this and comments:
this.comments = comments
@@ -3318,6 +3344,60 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ def _parse_json_key_value(self) -> t.Optional[exp.Expression]:
+ self._match_text_seq("KEY")
+ key = self._parse_field()
+ self._match(TokenType.COLON)
+ self._match_text_seq("VALUE")
+ value = self._parse_field()
+ if not key and not value:
+ return None
+ return self.expression(exp.JSONKeyValue, this=key, expression=value)
+
+ def _parse_json_object(self) -> exp.Expression:
+ expressions = self._parse_csv(self._parse_json_key_value)
+
+ null_handling = None
+ if self._match_text_seq("NULL", "ON", "NULL"):
+ null_handling = "NULL ON NULL"
+ elif self._match_text_seq("ABSENT", "ON", "NULL"):
+ null_handling = "ABSENT ON NULL"
+
+ unique_keys = None
+ if self._match_text_seq("WITH", "UNIQUE"):
+ unique_keys = True
+ elif self._match_text_seq("WITHOUT", "UNIQUE"):
+ unique_keys = False
+
+ self._match_text_seq("KEYS")
+
+ return_type = self._match_text_seq("RETURNING") and self._parse_type()
+ format_json = self._match_text_seq("FORMAT", "JSON")
+ encoding = self._match_text_seq("ENCODING") and self._parse_var()
+
+ return self.expression(
+ exp.JSONObject,
+ expressions=expressions,
+ null_handling=null_handling,
+ unique_keys=unique_keys,
+ return_type=return_type,
+ format_json=format_json,
+ encoding=encoding,
+ )
+
+ def _parse_logarithm(self) -> exp.Expression:
+ # Default argument order is base, expression
+ args = self._parse_csv(self._parse_range)
+
+ if len(args) > 1:
+ if not self.LOG_BASE_FIRST:
+ args.reverse()
+ return exp.Log.from_arg_list(args)
+
+ return self.expression(
+ exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0)
+ )
+
def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise)
@@ -3654,7 +3734,7 @@ class Parser(metaclass=_Parser):
return parse_result
def _parse_select_or_expression(self) -> t.Optional[exp.Expression]:
- return self._parse_select() or self._parse_expression()
+ return self._parse_select() or self._parse_set_operations(self._parse_expression())
def _parse_ddl_select(self) -> t.Optional[exp.Expression]:
return self._parse_set_operations(
@@ -3741,6 +3821,8 @@ class Parser(metaclass=_Parser):
expression = self._parse_foreign_key()
elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY):
expression = self._parse_primary_key()
+ else:
+ expression = None
return self.expression(exp.AddConstraint, this=this, expression=expression)
@@ -3799,12 +3881,15 @@ class Parser(metaclass=_Parser):
parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None
if parser:
- return self.expression(
- exp.AlterTable,
- this=this,
- exists=exists,
- actions=ensure_list(parser(self)),
- )
+ actions = ensure_list(parser(self))
+
+ if not self._curr:
+ return self.expression(
+ exp.AlterTable,
+ this=this,
+ exists=exists,
+ actions=actions,
+ )
return self._parse_as_command(start)
def _parse_merge(self) -> exp.Expression:
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index 40df39f..5fd96ef 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -175,7 +175,7 @@ class Step:
}
for projection in projections:
for i, e in aggregate.group.items():
- for child, _, _ in projection.walk():
+ for child, *_ in projection.walk():
if child == e:
child.replace(exp.column(i, step.name))
aggregate.add_dependency(step)
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index f5d9f2b..8e39c7f 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -306,11 +306,11 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return self._type_mapping_cache[schema_type]
-def ensure_schema(schema: t.Any) -> Schema:
+def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
if isinstance(schema, Schema):
return schema
- return MappingSchema(schema)
+ return MappingSchema(schema, dialect=dialect)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index eb3c08f..e5b44e7 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -252,6 +252,7 @@ class TokenType(AutoName):
PERCENT = auto()
PIVOT = auto()
PLACEHOLDER = auto()
+ PRAGMA = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
PROCEDURE = auto()
@@ -346,7 +347,8 @@ class Token:
self.token_type = token_type
self.text = text
self.line = line
- self.col = max(col - len(text), 1)
+ self.col = col - len(text)
+ self.col = self.col if self.col > 1 else 1
self.comments = comments
def __repr__(self) -> str:
@@ -586,6 +588,7 @@ class Tokenizer(metaclass=_Tokenizer):
"PARTITIONED_BY": TokenType.PARTITION_BY,
"PERCENT": TokenType.PERCENT,
"PIVOT": TokenType.PIVOT,
+ "PRAGMA": TokenType.PRAGMA,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"PROCEDURE": TokenType.PROCEDURE,
@@ -654,6 +657,7 @@ class Tokenizer(metaclass=_Tokenizer):
"LONG": TokenType.BIGINT,
"BIGINT": TokenType.BIGINT,
"INT8": TokenType.BIGINT,
+ "DEC": TokenType.DECIMAL,
"DECIMAL": TokenType.DECIMAL,
"MAP": TokenType.MAP,
"NULLABLE": TokenType.NULLABLE,
@@ -714,7 +718,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VACUUM": TokenType.COMMAND,
}
- WHITE_SPACE: t.Dict[str, TokenType] = {
+ WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = {
" ": TokenType.SPACE,
"\t": TokenType.SPACE,
"\n": TokenType.BREAK,
@@ -813,11 +817,8 @@ class Tokenizer(metaclass=_Tokenizer):
return self.sql[start:end]
return ""
- def _line_break(self, char: t.Optional[str]) -> bool:
- return self.WHITE_SPACE.get(char) == TokenType.BREAK # type: ignore
-
def _advance(self, i: int = 1) -> None:
- if self._line_break(self._char):
+ if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
self._set_new_line()
self._col += i
@@ -939,7 +940,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
self._advance(comment_end_size - 1)
else:
- while not self._end and not self._line_break(self._peek):
+ while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
self._advance()
self._comments.append(self._text[comment_start_size:]) # type: ignore