Diffstat
-rw-r--r--  sqlglot.svg                                 21
-rw-r--r--  sqlglot/__init__.py                          2
-rw-r--r--  sqlglot/dialects/bigquery.py                20
-rw-r--r--  sqlglot/dialects/clickhouse.py               2
-rw-r--r--  sqlglot/dialects/databricks.py               2
-rw-r--r--  sqlglot/dialects/dialect.py                  5
-rw-r--r--  sqlglot/dialects/drill.py                    5
-rw-r--r--  sqlglot/dialects/duckdb.py                   3
-rw-r--r--  sqlglot/dialects/hive.py                    26
-rw-r--r--  sqlglot/dialects/mysql.py                   16
-rw-r--r--  sqlglot/dialects/oracle.py                  51
-rw-r--r--  sqlglot/dialects/postgres.py                 2
-rw-r--r--  sqlglot/dialects/snowflake.py               11
-rw-r--r--  sqlglot/dialects/sqlite.py                   5
-rw-r--r--  sqlglot/dialects/teradata.py                 3
-rw-r--r--  sqlglot/dialects/tsql.py                    11
-rw-r--r--  sqlglot/diff.py                             24
-rw-r--r--  sqlglot/executor/__init__.py                 2
-rw-r--r--  sqlglot/executor/python.py                  42
-rw-r--r--  sqlglot/expressions.py                     163
-rw-r--r--  sqlglot/generator.py                        80
-rw-r--r--  sqlglot/helper.py                           11
-rw-r--r--  sqlglot/optimizer/annotate_types.py         12
-rw-r--r--  sqlglot/optimizer/canonicalize.py            2
-rw-r--r--  sqlglot/optimizer/eliminate_joins.py         5
-rw-r--r--  sqlglot/optimizer/eliminate_subqueries.py    2
-rw-r--r--  sqlglot/optimizer/lower_identities.py        8
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py        5
-rw-r--r--  sqlglot/optimizer/normalize.py             104
-rw-r--r--  sqlglot/optimizer/optimize_joins.py          2
-rw-r--r--  sqlglot/optimizer/optimizer.py               4
-rw-r--r--  sqlglot/optimizer/qualify_columns.py        88
-rw-r--r--  sqlglot/optimizer/scope.py                   2
-rw-r--r--  sqlglot/optimizer/simplify.py              166
-rw-r--r--  sqlglot/parser.py                          109
-rw-r--r--  sqlglot/planner.py                           2
-rw-r--r--  sqlglot/schema.py                            4
-rw-r--r--  sqlglot/tokens.py                           15
38 files changed, 671 insertions(+), 366 deletions(-)
diff --git a/sqlglot.svg b/sqlglot.svg
new file mode 100644
index 0000000..e57fe78
--- /dev/null
+++ b/sqlglot.svg
@@ -0,0 +1,21 @@
+<svg width="221" height="72" viewBox="0 0 221 72" fill="none" xmlns="http://www.w3.org/2000/svg">
+<g clip-path="url(#clip0_623_487)">
+<path d="M47.4749 42.4749C48.8417 41.108 48.8417 38.892 47.4749 37.5251C46.108 36.1583 43.892 36.1583 42.5251 37.5251L36.5251 43.5251C35.1583 44.892 35.1583 47.108 36.5251 48.4749L42.5251 54.4749C43.892 55.8417 46.108 55.8417 47.4749 54.4749C48.8417 53.108 48.8417 50.892 47.4749 49.5251L47.4497 49.5L56.6667 49.5C57.9553 49.5 59 50.5447 59 51.8333C59 56.3437 55.3437 60 50.8333 60H24.5C17.5964 60 12 54.4036 12 47.5V46C12 44.067 10.433 42.5 8.5 42.5C6.56701 42.5 5 44.067 5 46V47.5C5 58.2695 13.7304 67 24.5 67L50.8333 67C59.2096 67 66 60.2096 66 51.8333C66 46.6787 61.8213 42.5 56.6667 42.5L47.4497 42.5L47.4749 42.4749Z" fill="#004EDB"/>
+<path d="M46.5 13C53.4036 13 59 18.5964 59 25.5V27C59 28.933 60.567 30.5 62.5 30.5C64.433 30.5 66 28.933 66 27V25.5C66 14.7304 57.2696 6 46.5 6L20.1667 6C11.7903 6 5 12.7903 5 21.1667C5 26.3213 9.17868 30.5 14.3333 30.5H24.5503L24.5252 30.5251C23.1584 31.892 23.1584 34.108 24.5252 35.4749C25.892 36.8417 28.1081 36.8417 29.4749 35.4749L35.4749 29.4749C36.1643 28.7855 36.5349 27.8399 36.4975 26.8657C36.4601 25.8915 36.018 24.9771 35.2778 24.3426L28.2778 18.3426C26.8102 17.0846 24.6006 17.2546 23.3427 18.7222C22.1356 20.1304 22.2432 22.2217 23.5504 23.5L14.3333 23.5C13.0447 23.5 12 22.4553 12 21.1667C12 16.6563 15.6563 13 20.1667 13L46.5 13Z" fill="#004EDB"/>
+<circle cx="45" cy="27" r="7" fill="#0066FF"/>
+<circle cx="26" cy="46" r="7" fill="#66AFFF"/>
+<rect x="38" y="20" width="14" height="14" rx="4" fill="#0066FF"/>
+</g>
+<path d="M95.144 29.336C94.568 26.564 92.264 23 86.504 23C81.788 23 78.08 26.6 78.08 30.596C78.08 34.412 80.672 36.788 84.308 37.58L87.944 38.372C90.5 38.912 91.76 40.496 91.76 42.332C91.76 44.564 90.032 46.4 86.504 46.4C82.652 46.4 80.636 43.772 80.384 40.928L77 42.008C77.468 45.644 80.384 49.604 86.54 49.604C91.976 49.604 95.36 46.004 95.36 42.044C95.36 38.48 92.984 35.816 88.736 34.88L84.92 34.052C82.76 33.584 81.644 32.18 81.644 30.344C81.644 27.968 83.696 26.096 86.576 26.096C90.14 26.096 91.688 28.616 91.976 30.452L95.144 29.336Z" fill="black"/>
+<path d="M98.4161 36.284C98.4161 44.816 104.824 49.604 111.232 49.604C113.572 49.604 115.912 48.992 117.928 47.768L121.024 51.188L123.436 49.064L120.448 45.716C122.608 43.484 124.048 40.316 124.048 36.284C124.048 27.752 117.64 23 111.232 23C104.824 23 98.4161 27.752 98.4161 36.284ZM102.016 36.284C102.016 29.624 106.48 26.24 111.232 26.24C115.984 26.24 120.448 29.624 120.448 36.284C120.448 39.236 119.548 41.54 118.18 43.196L113.716 38.156L111.268 40.28L115.696 45.248C114.328 46.004 112.816 46.364 111.232 46.364C106.48 46.364 102.016 42.944 102.016 36.284Z" fill="black"/>
+<path d="M144.487 49.064V45.752H132.427V23.54H128.899V49.064H144.487Z" fill="black"/>
+<path d="M174.544 49.064V35.42H161.548V40.244H169.216C168.892 41.684 167.128 44.42 162.916 44.42C158.776 44.42 155.248 41.648 155.248 36.32C155.248 30.632 159.316 28.328 162.772 28.328C167.02 28.328 168.712 31.208 169.108 32.792L174.58 30.884C173.464 27.32 170.08 23 162.772 23C155.572 23 149.488 28.292 149.488 36.32C149.488 44.384 155.32 49.604 162.412 49.604C166.048 49.604 168.46 48.092 169.576 46.472L169.936 49.064H174.544Z" fill="black"/>
+<path d="M184.447 49.064V23H178.975V49.064H184.447Z" fill="black"/>
+<path d="M197.484 44.564C195.432 44.564 193.452 43.088 193.452 40.1C193.452 37.076 195.432 35.672 197.484 35.672C199.572 35.672 201.516 37.076 201.516 40.1C201.516 43.124 199.572 44.564 197.484 44.564ZM197.484 30.632C192.156 30.632 187.98 34.556 187.98 40.1C187.98 45.644 192.156 49.604 197.484 49.604C202.848 49.604 206.988 45.644 206.988 40.1C206.988 34.556 202.848 30.632 197.484 30.632Z" fill="black"/>
+<path d="M217.027 25.952H212.131V28.256C212.131 29.912 211.231 31.172 209.359 31.172H208.459V35.96H211.627V43.628C211.627 47.192 213.895 49.388 217.603 49.388C219.331 49.388 220.231 48.992 220.447 48.884V44.348C220.123 44.42 219.583 44.528 219.007 44.528C217.819 44.528 217.027 44.132 217.027 42.656V35.96H220.519V31.172H217.027V25.952Z" fill="black"/>
+<defs>
+<clipPath id="clip0_623_487">
+<rect width="72" height="72" fill="white"/>
+</clipPath>
+</defs>
+</svg>
\ No newline at end of file
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 10046d1..b53b261 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -47,7 +47,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
-__version__ = "11.4.1"
+__version__ = "11.4.5"
pretty = False
"""Whether to format generated SQL by default."""
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 6a43846..a3f9e6d 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
inline_array_sql,
+ max_or_greatest,
min_or_least,
no_ilike_sql,
rename_func,
@@ -212,6 +213,9 @@ class BigQuery(Dialect):
),
}
+ LOG_BASE_FIRST = False
+ LOG_DEFAULTS_TO_LN = True
+
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
@@ -227,6 +231,7 @@ class BigQuery(Dialect):
exp.GroupConcat: rename_func("STRING_AGG"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess(
[_unqualify_unnest], transforms.delegate("select_sql")
@@ -253,17 +258,19 @@ class BigQuery(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
- exp.DataType.Type.TINYINT: "INT64",
- exp.DataType.Type.SMALLINT: "INT64",
- exp.DataType.Type.INT: "INT64",
exp.DataType.Type.BIGINT: "INT64",
+ exp.DataType.Type.BOOLEAN: "BOOL",
+ exp.DataType.Type.CHAR: "STRING",
exp.DataType.Type.DECIMAL: "NUMERIC",
- exp.DataType.Type.FLOAT: "FLOAT64",
exp.DataType.Type.DOUBLE: "FLOAT64",
- exp.DataType.Type.BOOLEAN: "BOOL",
+ exp.DataType.Type.FLOAT: "FLOAT64",
+ exp.DataType.Type.INT: "INT64",
+ exp.DataType.Type.NCHAR: "STRING",
+ exp.DataType.Type.NVARCHAR: "STRING",
+ exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.TEXT: "STRING",
+ exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.VARCHAR: "STRING",
- exp.DataType.Type.NVARCHAR: "STRING",
}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
@@ -271,6 +278,7 @@ class BigQuery(Dialect):
}
EXPLICIT_UNION = True
+ LIMIT_FETCH = "LIMIT"
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
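
The new `LOG_BASE_FIRST` / `LOG_DEFAULTS_TO_LN` parser flags change how `LOG` calls are interpreted for BigQuery. A minimal sketch of the intended behavior (expected outputs are illustrative assumptions, not taken from the test suite):

```python
import sqlglot

# With LOG_DEFAULTS_TO_LN, a one-argument LOG parses as the natural log.
print(sqlglot.transpile("SELECT LOG(x)", read="bigquery")[0])
# expected: SELECT LN(x)

# With LOG_BASE_FIRST = False, BigQuery's LOG(value, base) is normalized so that
# dialects which put the base first (e.g. Hive) render the arguments in their order.
print(sqlglot.transpile("SELECT LOG(x, 10)", read="bigquery", write="hive")[0])
# expected: SELECT LOG(10, x)
```
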
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index b54a77d..89e2296 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -68,6 +68,8 @@ class ClickHouse(Dialect):
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
+ LOG_DEFAULTS_TO_LN = True
+
def _parse_in(
self, this: t.Optional[exp.Expression], is_global: bool = False
) -> exp.Expression:
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 4268f1b..2f93ee7 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -16,6 +16,8 @@ class Databricks(Spark):
"DATEDIFF": parse_date_delta(exp.DateDiff),
}
+ LOG_DEFAULTS_TO_LN = True
+
class Generator(Spark.Generator):
TRANSFORMS = {
**Spark.Generator.TRANSFORMS, # type: ignore
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index b267521..839589d 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -430,6 +430,11 @@ def min_or_least(self: Generator, expression: exp.Min) -> str:
return rename_func(name)(self, expression)
+def max_or_greatest(self: Generator, expression: exp.Max) -> str:
+ name = "GREATEST" if expression.expressions else "MAX"
+ return rename_func(name)(self, expression)
+
+
def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
cond = expression.this
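
`max_or_greatest` mirrors the existing `min_or_least` helper: a variadic `MAX` renders as `GREATEST` on dialects that register the transform, while the plain aggregate stays `MAX`. A quick sketch:

```python
import sqlglot

print(sqlglot.transpile("SELECT MAX(a, b) FROM t", write="postgres")[0])
# expected: SELECT GREATEST(a, b) FROM t

print(sqlglot.transpile("SELECT MAX(a) FROM t", write="postgres")[0])
# expected: SELECT MAX(a) FROM t
```
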
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index dc0e519..a33aadc 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-import re
import typing as t
from sqlglot import exp, generator, parser, tokens
@@ -102,6 +101,8 @@ class Drill(Dialect):
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
}
+ LOG_DEFAULTS_TO_LN = True
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -154,4 +155,4 @@ class Drill(Dialect):
}
def normalize_func(self, name: str) -> str:
- return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
+ return name if exp.SAFE_IDENTIFIER_RE.match(name) else f"`{name}`"
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f1d2266..c034208 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -80,6 +80,7 @@ class DuckDB(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "~": TokenType.RLIKE,
":=": TokenType.EQ,
"ATTACH": TokenType.COMMAND,
"BINARY": TokenType.VARBINARY,
@@ -212,5 +213,7 @@ class DuckDB(Dialect):
"except": "EXCLUDE",
}
+ LIMIT_FETCH = "LIMIT"
+
def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
return super().tablesample_sql(expression, seed_prefix="REPEATABLE")
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 0110eee..68137ae 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
if_sql,
locate_to_strposition,
+ max_or_greatest,
min_or_least,
no_ilike_sql,
no_recursive_cte_sql,
@@ -34,6 +35,13 @@ DATE_DELTA_INTERVAL = {
"DAY": ("DATE_ADD", 1),
}
+TIME_DIFF_FACTOR = {
+ "MILLISECOND": " * 1000",
+ "SECOND": "",
+ "MINUTE": " / 60",
+ "HOUR": " / 3600",
+}
+
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
@@ -51,6 +59,14 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
+
+ factor = TIME_DIFF_FACTOR.get(unit)
+ if factor is not None:
+ left = self.sql(expression, "this")
+ right = self.sql(expression, "expression")
+ sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})"
+ return f"({sec_diff}){factor}" if factor else sec_diff
+
sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
@@ -237,11 +253,6 @@ class Hive(Dialect):
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": locate_to_strposition,
- "LOG": (
- lambda args: exp.Log.from_arg_list(args)
- if len(args) > 1
- else exp.Ln.from_arg_list(args)
- ),
"MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
@@ -261,6 +272,8 @@ class Hive(Dialect):
),
}
+ LOG_DEFAULTS_TO_LN = True
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -293,6 +306,7 @@ class Hive(Dialect):
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.Map: var_map_sql,
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,
@@ -338,6 +352,8 @@ class Hive(Dialect):
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
}
+ LIMIT_FETCH = "LIMIT"
+
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 1e2cfa3..5dfa811 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -3,7 +3,9 @@ from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
+ arrow_json_extract_scalar_sql,
locate_to_strposition,
+ max_or_greatest,
min_or_least,
no_ilike_sql,
no_paren_current_date_sql,
@@ -288,6 +290,8 @@ class MySQL(Dialect):
"SWAPS",
}
+ LOG_DEFAULTS_TO_LN = True
+
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
@@ -303,7 +307,13 @@ class MySQL(Dialect):
db = None
else:
position = None
- db = self._parse_id_var() if self._match_text_seq("FROM") else None
+ db = None
+
+ if self._match(TokenType.FROM):
+ db = self._parse_id_var()
+ elif self._match(TokenType.DOT):
+ db = target_id
+ target_id = self._parse_id_var()
channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None
@@ -384,6 +394,8 @@ class MySQL(Dialect):
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql,
+ exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
@@ -415,6 +427,8 @@ class MySQL(Dialect):
exp.TransientProperty: exp.Properties.Location.UNSUPPORTED,
}
+ LIMIT_FETCH = "LIMIT"
+
def show_sql(self, expression):
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 7028a04..fad6c4a 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -4,7 +4,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
-from sqlglot.helper import csv, seq_get
+from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
@@ -13,10 +13,6 @@ PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
}
-def _limit_sql(self, expression):
- return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression))
-
-
def _parse_xml_table(self) -> exp.XMLTable:
this = self._parse_string()
@@ -89,6 +85,20 @@ class Oracle(Dialect):
column.set("join_mark", self._match(TokenType.JOIN_MARKER))
return column
+ def _parse_hint(self) -> t.Optional[exp.Expression]:
+ if self._match(TokenType.HINT):
+ start = self._curr
+ while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH):
+ self._advance()
+
+ if not self._curr:
+ self.raise_error("Expected */ after HINT")
+
+ end = self._tokens[self._index - 3]
+ return exp.Hint(expressions=[self._find_sql(start, end)])
+
+ return None
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
@@ -110,41 +120,20 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
+ exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
- exp.Limit: _limit_sql,
- exp.Trim: trim_sql,
exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
+ exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
- exp.Substring: rename_func("SUBSTR"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+ exp.Trim: trim_sql,
+ exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
}
- def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
- return csv(
- *sqls,
- *[self.sql(sql) for sql in expression.args.get("joins") or []],
- self.sql(expression, "match"),
- *[self.sql(sql) for sql in expression.args.get("laterals") or []],
- self.sql(expression, "where"),
- self.sql(expression, "group"),
- self.sql(expression, "having"),
- self.sql(expression, "qualify"),
- self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
- if expression.args.get("windows")
- else "",
- self.sql(expression, "distribute"),
- self.sql(expression, "sort"),
- self.sql(expression, "cluster"),
- self.sql(expression, "order"),
- self.sql(expression, "offset"), # offset before limit in oracle
- self.sql(expression, "limit"),
- self.sql(expression, "lock"),
- sep="",
- )
+ LIMIT_FETCH = "FETCH"
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 5f556a5..31b7e45 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
format_time_lambda,
+ max_or_greatest,
min_or_least,
no_paren_current_date_sql,
no_tablesample_sql,
@@ -315,6 +316,7 @@ class Postgres(Dialect):
exp.DateDiff: _date_diff_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 799e9a6..c50961c 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
+ max_or_greatest,
min_or_least,
rename_func,
timestamptrunc_sql,
@@ -275,6 +276,9 @@ class Snowflake(Dialect):
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
+ exp.DateDiff: lambda self, e: self.func(
+ "DATEDIFF", e.text("unit"), e.expression, e.this
+ ),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
@@ -296,6 +300,7 @@ class Snowflake(Dialect):
exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
exp.UnixToTime: _unix_to_time_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
}
@@ -314,12 +319,6 @@ class Snowflake(Dialect):
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
}
- def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
- return self.binary(expression, "ILIKE ANY")
-
- def likeany_sql(self, expression: exp.LikeAny) -> str:
- return self.binary(expression, "LIKE ANY")
-
def except_op(self, expression):
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT with All is not supported in Snowflake")
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index ab78b6e..4091dbb 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -82,6 +82,8 @@ class SQLite(Dialect):
exp.TryCast: no_trycast_sql,
}
+ LIMIT_FETCH = "LIMIT"
+
def cast_sql(self, expression: exp.Cast) -> str:
if expression.to.this == exp.DataType.Type.DATE:
return self.func("DATE", expression.this)
@@ -115,9 +117,6 @@ class SQLite(Dialect):
return f"CAST({sql} AS INTEGER)"
- def fetch_sql(self, expression: exp.Fetch) -> str:
- return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
-
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def groupconcat_sql(self, expression):
this = expression.this
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 8bd0a0c..3d43793 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens
-from sqlglot.dialects.dialect import Dialect, min_or_least
+from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
from sqlglot.tokens import TokenType
@@ -128,6 +128,7 @@ class Teradata(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 7b52047..8e9b6c3 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -6,6 +6,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
+ max_or_greatest,
min_or_least,
parse_date_delta,
rename_func,
@@ -269,7 +270,6 @@ class TSQL(Dialect):
        # TSQL allows # to appear as an identifier prefix; @ is now a single token
        SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
-        SINGLE_TOKENS.pop("@")
        SINGLE_TOKENS.pop("#")
class Parser(parser.Parser):
@@ -313,6 +313,9 @@ class TSQL(Dialect):
TokenType.END: lambda self: self._parse_command(),
}
+ LOG_BASE_FIRST = False
+ LOG_DEFAULTS_TO_LN = True
+
def _parse_system_time(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("FOR", "SYSTEM_TIME"):
return None
@@ -435,11 +438,17 @@ class TSQL(Dialect):
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
+ exp.Max: max_or_greatest,
exp.Min: min_or_least,
}
TRANSFORMS.pop(exp.ReturnsProperty)
+ LIMIT_FETCH = "FETCH"
+
+ def offset_sql(self, expression: exp.Offset) -> str:
+ return f"{super().offset_sql(expression)} ROWS"
+
def systemtime_sql(self, expression: exp.SystemTime) -> str:
kind = expression.args["kind"]
if kind == "ALL":
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index dddb9ad..86665e0 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -12,7 +12,7 @@ from dataclasses import dataclass
from heapq import heappop, heappush
from sqlglot import Dialect, expressions as exp
-from sqlglot.helper import ensure_collection
+from sqlglot.helper import ensure_list
@dataclass(frozen=True)
@@ -151,8 +151,8 @@ class ChangeDistiller:
self._source = source
self._target = target
- self._source_index = {id(n[0]): n[0] for n in source.bfs()}
- self._target_index = {id(n[0]): n[0] for n in target.bfs()}
+ self._source_index = {id(n): n for n, *_ in self._source.bfs()}
+ self._target_index = {id(n): n for n, *_ in self._target.bfs()}
self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes)
self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values())
self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
@@ -199,10 +199,10 @@ class ChangeDistiller:
matching_set = leaves_matching_set.copy()
ordered_unmatched_source_nodes = {
- id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes
+ id(n): None for n, *_ in self._source.bfs() if id(n) in self._unmatched_source_nodes
}
ordered_unmatched_target_nodes = {
- id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes
+ id(n): None for n, *_ in self._target.bfs() if id(n) in self._unmatched_target_nodes
}
for source_node_id in ordered_unmatched_source_nodes:
@@ -304,18 +304,18 @@ class ChangeDistiller:
def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]:
has_child_exprs = False
- for a in expression.args.values():
- for node in ensure_collection(a):
- if isinstance(node, exp.Expression):
- has_child_exprs = True
- yield from _get_leaves(node)
+ for _, node in expression.iter_expressions():
+ has_child_exprs = True
+ yield from _get_leaves(node)
if not has_child_exprs:
yield expression
def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
- if type(source) is type(target):
+ if type(source) is type(target) and (
+ not isinstance(source, exp.Identifier) or type(source.parent) is type(target.parent)
+ ):
if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side")
@@ -331,7 +331,7 @@ def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
args: t.List[t.Union[exp.Expression, t.List]] = []
if expression:
for a in expression.args.values():
- args.extend(ensure_collection(a))
+ args.extend(ensure_list(a))
return [a for a in args if isinstance(a, exp.Expression)]
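
The `ChangeDistiller` indexing now uses the `(node, parent, key)` tuples yielded by `bfs()` directly, and `_is_same_type` refuses to match `Identifier` nodes whose parents differ in type. The public entry point is unchanged; for reference:

```python
from sqlglot import parse_one
from sqlglot.diff import diff

edit_script = diff(parse_one("SELECT a FROM t"), parse_one("SELECT a, b FROM t"))
for edit in edit_script:
    print(edit)  # a sequence of Insert/Remove/Move/Update/Keep edits
```
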
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index a676e7d..a67c155 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -57,7 +57,7 @@ def execute(
for name, table in tables_.mapping.items()
}
- schema = ensure_schema(schema)
+ schema = ensure_schema(schema, dialect=read)
if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args:
raise ExecuteError("Tables must support the same table args as schema")
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index d417328..b71cc6a 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -94,13 +94,10 @@ class PythonExecutor:
if source and isinstance(source, exp.Expression):
source = source.name or source.alias
- condition = self.generate(step.condition)
- projections = self.generate_tuple(step.projections)
-
if source is None:
context, table_iter = self.static()
elif source in context:
- if not projections and not condition:
+ if not step.projections and not step.condition:
return self.context({step.name: context.tables[source]})
table_iter = context.table_iter(source)
elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
@@ -109,10 +106,12 @@ class PythonExecutor:
else:
context, table_iter = self.scan_table(step)
- if projections:
- sink = self.table(step.projections)
- else:
- sink = self.table(context.columns)
+ return self.context({step.name: self._project_and_filter(context, step, table_iter)})
+
+ def _project_and_filter(self, context, step, table_iter):
+ sink = self.table(step.projections if step.projections else context.columns)
+ condition = self.generate(step.condition)
+ projections = self.generate_tuple(step.projections)
for reader in table_iter:
if len(sink) >= step.limit:
@@ -126,7 +125,7 @@ class PythonExecutor:
else:
sink.append(reader.row)
- return self.context({step.name: sink})
+ return sink
def static(self):
return self.context({}), [RowReader(())]
@@ -185,27 +184,16 @@ class PythonExecutor:
if condition:
source_context.filter(condition)
- condition = self.generate(step.condition)
- projections = self.generate_tuple(step.projections)
-
- if not condition and not projections:
+ if not step.condition and not step.projections:
return source_context
- sink = self.table(step.projections if projections else source_context.columns)
-
- for reader, ctx in source_context:
- if condition and not ctx.eval(condition):
- continue
-
- if projections:
- sink.append(ctx.eval_tuple(projections))
- else:
- sink.append(reader.row)
-
- if len(sink) >= step.limit:
- break
+ sink = self._project_and_filter(
+ source_context,
+ step,
+ (reader for reader, _ in iter(source_context)),
+ )
- if projections:
+ if step.projections:
return self.context({step.name: sink})
else:
return self.context(
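
The scan and aggregate paths now share `_project_and_filter`, and projections/conditions are only compiled when the step actually defines them. Observable behavior is unchanged; for reference:

```python
from sqlglot.executor import execute

tables = {"t": [{"a": 1}, {"a": 2}, {"a": 3}]}
result = execute("SELECT a FROM t WHERE a > 1", tables=tables)
print(result.rows)  # expected: [(2,), (3,)]
```
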
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index b9da4cc..f4aae47 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -26,6 +26,7 @@ from sqlglot.helper import (
AutoName,
camel_to_snake_case,
ensure_collection,
+ ensure_list,
seq_get,
split_num_words,
subclasses,
@@ -84,7 +85,7 @@ class Expression(metaclass=_Expression):
key = "expression"
arg_types = {"this": True}
- __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta")
+ __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta", "_hash")
def __init__(self, **args: t.Any):
self.args: t.Dict[str, t.Any] = args
@@ -93,23 +94,31 @@ class Expression(metaclass=_Expression):
self.comments: t.Optional[t.List[str]] = None
self._type: t.Optional[DataType] = None
self._meta: t.Optional[t.Dict[str, t.Any]] = None
+ self._hash: t.Optional[int] = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
def __eq__(self, other) -> bool:
- return type(self) is type(other) and _norm_args(self) == _norm_args(other)
+ return type(self) is type(other) and hash(self) == hash(other)
- def __hash__(self) -> int:
- return hash(
- (
- self.key,
- tuple(
- (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
- ),
- )
+ @property
+ def hashable_args(self) -> t.Any:
+ args = (self.args.get(k) for k in self.arg_types)
+
+ return tuple(
+ (tuple(_norm_arg(a) for a in arg) if arg else None)
+ if type(arg) is list
+ else (_norm_arg(arg) if arg is not None and arg is not False else None)
+ for arg in args
)
+ def __hash__(self) -> int:
+ if self._hash is not None:
+ return self._hash
+
+ return hash((self.__class__, self.hashable_args))
+
@property
def this(self):
"""
@@ -247,9 +256,6 @@ class Expression(metaclass=_Expression):
"""
new = deepcopy(self)
new.parent = self.parent
- for item, parent, _ in new.bfs():
- if isinstance(item, Expression) and parent:
- item.parent = parent
return new
def append(self, arg_key, value):
@@ -277,12 +283,12 @@ class Expression(metaclass=_Expression):
self._set_parent(arg_key, value)
def _set_parent(self, arg_key, value):
- if isinstance(value, Expression):
+ if hasattr(value, "parent"):
value.parent = self
value.arg_key = arg_key
- elif isinstance(value, list):
+ elif type(value) is list:
for v in value:
- if isinstance(v, Expression):
+ if hasattr(v, "parent"):
v.parent = self
v.arg_key = arg_key
@@ -295,6 +301,17 @@ class Expression(metaclass=_Expression):
return self.parent.depth + 1
return 0
+ def iter_expressions(self) -> t.Iterator[t.Tuple[str, Expression]]:
+ """Yields the key and expression for all arguments, exploding list args."""
+ for k, vs in self.args.items():
+ if type(vs) is list:
+ for v in vs:
+ if hasattr(v, "parent"):
+ yield k, v
+ else:
+ if hasattr(vs, "parent"):
+ yield k, vs
+
def find(self, *expression_types: t.Type[E], bfs=True) -> E | None:
"""
Returns the first node in this tree which matches at least one of
@@ -319,7 +336,7 @@ class Expression(metaclass=_Expression):
Returns:
The generator object.
"""
- for expression, _, _ in self.walk(bfs=bfs):
+ for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
@@ -345,6 +362,11 @@ class Expression(metaclass=_Expression):
"""
return self.find_ancestor(Select)
+ @property
+ def same_parent(self):
+ """Returns if the parent is the same class as itself."""
+ return type(self.parent) is self.__class__
+
def root(self) -> Expression:
"""
Returns the root expression of this tree.
@@ -385,10 +407,8 @@ class Expression(metaclass=_Expression):
if prune and prune(self, parent, key):
return
- for k, v in self.args.items():
- for node in ensure_collection(v):
- if isinstance(node, Expression):
- yield from node.dfs(self, k, prune)
+ for k, v in self.iter_expressions():
+ yield from v.dfs(self, k, prune)
def bfs(self, prune=None):
"""
@@ -407,18 +427,15 @@ class Expression(metaclass=_Expression):
if prune and prune(item, parent, key):
continue
- if isinstance(item, Expression):
- for k, v in item.args.items():
- for node in ensure_collection(v):
- if isinstance(node, Expression):
- queue.append((node, item, k))
+ for k, v in item.iter_expressions():
+ queue.append((v, item, k))
def unnest(self):
"""
Returns the first non-parenthesis child or self.
"""
expression = self
- while isinstance(expression, Paren):
+ while type(expression) is Paren:
expression = expression.this
return expression
@@ -434,7 +451,7 @@ class Expression(metaclass=_Expression):
"""
Returns unnested operands as a tuple.
"""
- return tuple(arg.unnest() for arg in self.args.values() if arg)
+ return tuple(arg.unnest() for _, arg in self.iter_expressions())
def flatten(self, unnest=True):
"""
@@ -442,8 +459,8 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
- for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
- if not isinstance(node, self.__class__):
+        for node, _, _ in self.dfs(prune=lambda n, p, *_: p and type(n) is not self.__class__):
+            if type(node) is not self.__class__:
yield node.unnest() if unnest else node
def __str__(self):
@@ -477,7 +494,7 @@ class Expression(metaclass=_Expression):
v._to_s(hide_missing=hide_missing, level=level + 1)
if hasattr(v, "_to_s")
else str(v)
- for v in ensure_collection(vs)
+ for v in ensure_list(vs)
if v is not None
)
for k, vs in self.args.items()
@@ -812,6 +829,10 @@ class Describe(Expression):
arg_types = {"this": True, "kind": False}
+class Pragma(Expression):
+ pass
+
+
class Set(Expression):
arg_types = {"expressions": False}
@@ -1170,6 +1191,7 @@ class Drop(Expression):
"temporary": False,
"materialized": False,
"cascade": False,
+ "constraints": False,
}
@@ -1232,11 +1254,11 @@ class Identifier(Expression):
def quoted(self):
return bool(self.args.get("quoted"))
- def __eq__(self, other):
- return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
-
- def __hash__(self):
- return hash((self.key, self.this.lower()))
+ @property
+ def hashable_args(self) -> t.Any:
+ if self.quoted and any(char.isupper() for char in self.this):
+ return (self.this, self.quoted)
+ return self.this.lower()
@property
def output_name(self):
@@ -1322,15 +1344,9 @@ class Limit(Expression):
class Literal(Condition):
arg_types = {"this": True, "is_string": True}
- def __eq__(self, other):
- return (
- isinstance(other, Literal)
- and self.this == other.this
- and self.args["is_string"] == other.args["is_string"]
- )
-
- def __hash__(self):
- return hash((self.key, self.this, self.args["is_string"]))
+ @property
+ def hashable_args(self) -> t.Any:
+ return (self.this, self.args.get("is_string"))
@classmethod
def number(cls, number) -> Literal:
@@ -1784,7 +1800,7 @@ class Subqueryable(Unionable):
instance = _maybe_copy(self, copy)
return Subquery(
this=instance,
- alias=TableAlias(this=to_identifier(alias)),
+ alias=TableAlias(this=to_identifier(alias)) if alias else None,
)
def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
@@ -2058,6 +2074,7 @@ class Lock(Expression):
class Select(Subqueryable):
arg_types = {
"with": False,
+ "kind": False,
"expressions": False,
"hint": False,
"distinct": False,
@@ -3595,6 +3612,21 @@ class Initcap(Func):
pass
+class JSONKeyValue(Expression):
+ arg_types = {"this": True, "expression": True}
+
+
+class JSONObject(Func):
+ arg_types = {
+ "expressions": False,
+ "null_handling": False,
+ "unique_keys": False,
+ "return_type": False,
+ "format_json": False,
+ "encoding": False,
+ }
+
+
class JSONBContains(Binary):
_sql_names = ["JSONB_CONTAINS"]
@@ -3766,8 +3798,10 @@ class RegexpILike(Func):
arg_types = {"this": True, "expression": True, "flag": False}
+# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split.html
+# limit is the number of times a pattern is applied
class RegexpSplit(Func):
- arg_types = {"this": True, "expression": True}
+ arg_types = {"this": True, "expression": True, "limit": False}
class Repeat(Func):
@@ -3967,25 +4001,8 @@ class When(Func):
arg_types = {"matched": True, "source": False, "condition": False, "then": True}
-def _norm_args(expression):
- args = {}
-
- for k, arg in expression.args.items():
- if isinstance(arg, list):
- arg = [_norm_arg(a) for a in arg]
- if not arg:
- arg = None
- else:
- arg = _norm_arg(arg)
-
- if arg is not None and arg is not False:
- args[k] = arg
-
- return args
-
-
def _norm_arg(arg):
- return arg.lower() if isinstance(arg, str) else arg
+ return arg.lower() if type(arg) is str else arg
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
@@ -4512,7 +4529,7 @@ def to_identifier(name, quoted=None):
elif isinstance(name, str):
identifier = Identifier(
this=name,
- quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted,
+ quoted=not SAFE_IDENTIFIER_RE.match(name) if quoted is None else quoted,
)
else:
raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}")
@@ -4586,8 +4603,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
return sql_path
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
- table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
- return Column(this=column_name, table=table_name, **kwargs)
+ return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore
def alias_(
@@ -4672,7 +4688,8 @@ def subquery(expression, alias=None, dialect=None, **opts):
def column(
col: str | Identifier,
table: t.Optional[str | Identifier] = None,
- schema: t.Optional[str | Identifier] = None,
+ db: t.Optional[str | Identifier] = None,
+ catalog: t.Optional[str | Identifier] = None,
quoted: t.Optional[bool] = None,
) -> Column:
"""
@@ -4681,7 +4698,8 @@ def column(
Args:
col: column name
table: table name
- schema: schema name
+ db: db name
+ catalog: catalog name
quoted: whether or not to force quote each part
Returns:
Column: column instance
@@ -4689,7 +4707,8 @@ def column(
return Column(
this=to_identifier(col, quoted=quoted),
table=to_identifier(table, quoted=quoted),
- schema=to_identifier(schema, quoted=quoted),
+ db=to_identifier(db, quoted=quoted),
+ catalog=to_identifier(catalog, quoted=quoted),
)
@@ -4864,7 +4883,7 @@ def replace_children(expression, fun, *args, **kwargs):
Replace children of an expression with the result of a lambda fun(child) -> exp.
"""
for k, v in expression.args.items():
- is_list_arg = isinstance(v, list)
+ is_list_arg = type(v) is list
child_nodes = v if is_list_arg else [v]
new_child_nodes = []
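
The `__eq__` / `__hash__` rewrite keeps structural, case-insensitive equality but routes it through `hashable_args` (with an optional precomputed `_hash`), `iter_expressions` replaces the old `ensure_collection` loops, and `column()` now takes `db` / `catalog` instead of `schema`. A sketch of the observable behavior:

```python
import sqlglot
from sqlglot import exp

# Equality stays structural and case-insensitive for unquoted identifiers.
assert sqlglot.parse_one("select A from T") == sqlglot.parse_one("SELECT a FROM t")

# iter_expressions yields (arg_key, child) pairs, flattening list-valued args.
for key, child in sqlglot.parse_one("SELECT a, b FROM t").iter_expressions():
    print(key, type(child).__name__)

print(exp.column("c", table="t", db="d", catalog="cat").sql())
# expected: cat.d.t.c
```
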
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index a6f4772..6871dd8 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -110,6 +110,10 @@ class Generator:
# Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
MATCHED_BY_SOURCE = True
+    # Whether LIMIT and FETCH are both supported ("ALL"), or only one of them ("LIMIT" or "FETCH")
+ LIMIT_FETCH = "ALL"
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -209,6 +213,7 @@ class Generator:
"_leading_comma",
"_max_text_width",
"_comments",
+ "_cache",
)
def __init__(
@@ -265,19 +270,28 @@ class Generator:
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._comments = comments
+ self._cache = None
- def generate(self, expression: t.Optional[exp.Expression]) -> str:
+ def generate(
+ self,
+ expression: t.Optional[exp.Expression],
+ cache: t.Optional[t.Dict[int, str]] = None,
+ ) -> str:
"""
Generates a SQL string by interpreting the given syntax tree.
Args
expression: the syntax tree.
+            cache: an optional SQL string cache. This leverages the hash of an expression, which can be slow to compute, so only use it if you set `_hash` on each node.
Returns
the SQL string.
"""
+ if cache is not None:
+ self._cache = cache
self.unsupported_messages = []
sql = self.sql(expression).strip()
+ self._cache = None
if self.unsupported_level == ErrorLevel.IGNORE:
return sql
@@ -387,6 +401,12 @@ class Generator:
if key:
return self.sql(expression.args.get(key))
+ if self._cache is not None:
+ expression_id = hash(expression)
+
+ if expression_id in self._cache:
+ return self._cache[expression_id]
+
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
@@ -407,7 +427,11 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
- return self.maybe_comment(sql, expression) if self._comments and comment else sql
+ sql = self.maybe_comment(sql, expression) if self._comments and comment else sql
+
+ if self._cache is not None:
+ self._cache[expression_id] = sql
+ return sql
def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
@@ -697,7 +721,8 @@ class Generator:
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
cascade = " CASCADE" if expression.args.get("cascade") else ""
- return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
+ constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
+ return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}"
def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
@@ -733,9 +758,9 @@ class Generator:
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
- text = text.lower() if self.normalize else text
+ text = text.lower() if self.normalize and not expression.quoted else text
text = text.replace(self.identifier_end, self._escaped_identifier_end)
- if expression.args.get("quoted") or should_identify(text, self.identify):
+ if expression.quoted or should_identify(text, self.identify):
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
@@ -1191,6 +1216,9 @@ class Generator:
)
return f"SET{expressions}"
+ def pragma_sql(self, expression: exp.Pragma) -> str:
+ return f"PRAGMA {self.sql(expression, 'this')}"
+
def lock_sql(self, expression: exp.Lock) -> str:
if self.LOCKING_READS_SUPPORTED:
lock_type = "UPDATE" if expression.args["update"] else "SHARE"
@@ -1299,6 +1327,15 @@ class Generator:
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
+ limit = expression.args.get("limit")
+
+ if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
+ limit = exp.Limit(expression=limit.args.get("count"))
+ elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
+ limit = exp.Fetch(direction="FIRST", count=limit.expression)
+
+ fetch = isinstance(limit, exp.Fetch)
+
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins") or []],
@@ -1315,14 +1352,16 @@ class Generator:
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
self.sql(expression, "order"),
- self.sql(expression, "limit"),
- self.sql(expression, "offset"),
+ self.sql(expression, "offset") if fetch else self.sql(limit),
+ self.sql(limit) if fetch else self.sql(expression, "offset"),
self.sql(expression, "lock"),
self.sql(expression, "sample"),
sep="",
)
def select_sql(self, expression: exp.Select) -> str:
+ kind = expression.args.get("kind")
+ kind = f" AS {kind}" if kind else ""
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
@@ -1330,7 +1369,7 @@ class Generator:
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
- f"SELECT{hint}{distinct}{expressions}",
+ f"SELECT{kind}{hint}{distinct}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
@@ -1552,6 +1591,25 @@ class Generator:
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
+ def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
+ return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
+
+ def jsonobject_sql(self, expression: exp.JSONObject) -> str:
+ expressions = self.expressions(expression)
+ null_handling = expression.args.get("null_handling")
+ null_handling = f" {null_handling}" if null_handling else ""
+ unique_keys = expression.args.get("unique_keys")
+ if unique_keys is not None:
+ unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS"
+ else:
+ unique_keys = ""
+ return_type = self.sql(expression, "return_type")
+ return_type = f" RETURNING {return_type}" if return_type else ""
+ format_json = " FORMAT JSON" if expression.args.get("format_json") else ""
+ encoding = self.sql(expression, "encoding")
+ encoding = f" ENCODING {encoding}" if encoding else ""
+ return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
+
def in_sql(self, expression: exp.In) -> str:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
@@ -1808,12 +1866,18 @@ class Generator:
def ilike_sql(self, expression: exp.ILike) -> str:
return self.binary(expression, "ILIKE")
+ def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
+ return self.binary(expression, "ILIKE ANY")
+
def is_sql(self, expression: exp.Is) -> str:
return self.binary(expression, "IS")
def like_sql(self, expression: exp.Like) -> str:
return self.binary(expression, "LIKE")
+ def likeany_sql(self, expression: exp.LikeAny) -> str:
+ return self.binary(expression, "LIKE ANY")
+
def similarto_sql(self, expression: exp.SimilarTo) -> str:
return self.binary(expression, "SIMILAR TO")
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 6eff974..d44d7dd 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -59,7 +59,7 @@ def ensure_list(value):
"""
if value is None:
return []
- elif isinstance(value, (list, tuple)):
+ if isinstance(value, (list, tuple)):
return list(value)
return [value]
@@ -162,9 +162,7 @@ def camel_to_snake_case(name: str) -> str:
return CAMEL_CASE_PATTERN.sub("_", name).upper()
-def while_changing(
- expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E]
-) -> E:
+def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E:
"""
Applies a transformation to a given expression until a fix point is reached.
@@ -176,8 +174,13 @@ def while_changing(
The transformed expression.
"""
while True:
+ for n, *_ in reversed(tuple(expression.walk())):
+ n._hash = hash(n)
start = hash(expression)
expression = func(expression)
+
+ for n, *_ in expression.walk():
+ n._hash = None
if start == hash(expression):
break
return expression
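
`while_changing` now pins `_hash` on every node before each pass and clears it afterwards, so the repeated `hash(expression)` fix-point checks stay cheap. A toy usage (`collapse_not` is a hypothetical rewrite rule, just to drive the loop):

```python
from sqlglot import exp, parse_one
from sqlglot.helper import while_changing

def collapse_not(node):
    # hypothetical rule: NOT NOT x -> x
    for n in node.find_all(exp.Not):
        if isinstance(n.this, exp.Not):
            n.replace(n.this.this)
    return node

print(while_changing(parse_one("NOT NOT NOT x"), collapse_not).sql())
# expected: NOT x
```
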
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index c2d6655..99888c6 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,5 +1,5 @@
from sqlglot import exp
-from sqlglot.helper import ensure_collection, ensure_list, subclasses
+from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@@ -108,6 +108,7 @@ class TypeAnnotator:
exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
+ exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.GroupConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
@@ -116,6 +117,7 @@ class TypeAnnotator:
expr, exp.DataType.Type.VARCHAR
),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@@ -296,9 +298,6 @@ class TypeAnnotator:
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
- if not isinstance(expression, exp.Expression):
- return None
-
if expression.type:
return expression # We've already inferred the expression's type
@@ -311,9 +310,8 @@ class TypeAnnotator:
)
def _annotate_args(self, expression):
- for value in expression.args.values():
- for v in ensure_collection(value):
- self._maybe_annotate(v)
+ for _, value in expression.iter_expressions():
+ self._maybe_annotate(value)
return expression
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index c5c780d..ef929ac 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -75,7 +75,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
a.type
and a.type.this == exp.DataType.Type.DATE
and b.type
- and b.type.this != exp.DataType.Type.DATE
+ and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
):
_replace_cast(b, "date")
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 8e6a520..e0ddfa2 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -1,7 +1,6 @@
from sqlglot import expressions as exp
from sqlglot.optimizer.normalize import normalized
from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.optimizer.simplify import simplify
def eliminate_joins(expression):
@@ -179,6 +178,4 @@ def join_condition(join):
for condition in conditions:
extract_condition(condition)
- on = simplify(on)
- remaining_condition = None if on == exp.true() else on
- return source_key, join_key, remaining_condition
+ return source_key, join_key, on
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 6f9db82..a39fe96 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -3,7 +3,6 @@ import itertools
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import build_scope
-from sqlglot.optimizer.simplify import simplify
def eliminate_subqueries(expression):
@@ -31,7 +30,6 @@ def eliminate_subqueries(expression):
eliminate_subqueries(expression.this)
return expression
- expression = simplify(expression)
root = build_scope(expression)
# Map of alias->Scope|Table
diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py
index 1cc76cf..fae1726 100644
--- a/sqlglot/optimizer/lower_identities.py
+++ b/sqlglot/optimizer/lower_identities.py
@@ -1,5 +1,4 @@
from sqlglot import exp
-from sqlglot.helper import ensure_collection
def lower_identities(expression):
@@ -40,13 +39,10 @@ def lower_identities(expression):
lower_identities(expression.right)
traversed |= {"this", "expression"}
- for k, v in expression.args.items():
+ for k, v in expression.iter_expressions():
if k in traversed:
continue
-
- for child in ensure_collection(v):
- if isinstance(child, exp.Expression):
- child.transform(_lower, copy=False)
+ v.transform(_lower, copy=False)
return expression
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 70172f4..c3467b2 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -3,7 +3,6 @@ from collections import defaultdict
from sqlglot import expressions as exp
from sqlglot.helper import find_new_name
from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.optimizer.simplify import simplify
def merge_subqueries(expression, leave_tables_isolated=False):
@@ -330,11 +329,11 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if set(exp.column_table_names(where.this)) <= sources:
from_or_join.on(where.this, copy=False)
- from_or_join.set("on", simplify(from_or_join.args.get("on")))
+ from_or_join.set("on", from_or_join.args.get("on"))
return
expression.where(where.this, copy=False)
- expression.set("where", simplify(expression.args.get("where")))
+ expression.set("where", expression.args.get("where"))
def _merge_order(outer_scope, inner_scope):
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index f16f519..f2df230 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -1,29 +1,63 @@
+from __future__ import annotations
+
+import logging
+import typing as t
+
from sqlglot import exp
+from sqlglot.errors import OptimizeError
from sqlglot.helper import while_changing
-from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
+from sqlglot.optimizer.simplify import flatten, uniq_sort
+
+logger = logging.getLogger("sqlglot")
-def normalize(expression, dnf=False, max_distance=128):
+def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
"""
- Rewrite sqlglot AST into conjunctive normal form.
+ Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
- >>> normalize(expression).sql()
+ >>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Args:
- expression (sqlglot.Expression): expression to normalize
- dnf (bool): rewrite in disjunctive normal form instead
- max_distance (int): the maximal estimated distance from cnf to attempt conversion
+ expression: expression to normalize
+ dnf: rewrite in disjunctive normal form instead.
+        max_distance: the maximal estimated distance from cnf/dnf to attempt conversion
Returns:
sqlglot.Expression: normalized expression
"""
- expression = simplify(expression)
+ cache: t.Dict[int, str] = {}
+
+ for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
+ if isinstance(node, exp.Connector):
+ if normalized(node, dnf=dnf):
+ continue
+
+ distance = normalization_distance(node, dnf=dnf)
+
+ if distance > max_distance:
+ logger.info(
+ f"Skipping normalization because distance {distance} exceeds max {max_distance}"
+ )
+ return expression
+
+ root = node is expression
+ original = node.copy()
+ try:
+ node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ except OptimizeError as e:
+ logger.info(e)
+ node.replace(original)
+ if root:
+ return original
+ return expression
+
+ if root:
+ expression = node
- expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
- return simplify(expression)
+ return expression
def normalized(expression, dnf=False):
@@ -51,7 +85,7 @@ def normalization_distance(expression, dnf=False):
int: difference
"""
return sum(_predicate_lengths(expression, dnf)) - (
- len(list(expression.find_all(exp.Connector))) + 1
+ sum(1 for _ in expression.find_all(exp.Connector)) + 1
)
@@ -64,29 +98,32 @@ def _predicate_lengths(expression, dnf):
expression = expression.unnest()
if not isinstance(expression, exp.Connector):
- return [1]
+ return (1,)
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
- return [
+ return tuple(
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
- ]
+ )
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
-def distributive_law(expression, dnf, max_distance):
+def distributive_law(expression, dnf, max_distance, cache=None):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
"""
- if isinstance(expression.unnest(), exp.Connector):
- if normalization_distance(expression, dnf) > max_distance:
- return expression
+ if normalized(expression, dnf=dnf):
+ return expression
- to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
+ distance = normalization_distance(expression, dnf=dnf)
- exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
+ if distance > max_distance:
+ raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
+
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
+ to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
a, b = expression.unnest_operands()
@@ -96,32 +133,29 @@ def distributive_law(expression, dnf, max_distance):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
- return _distribute(a, b, from_func, to_func)
- return _distribute(b, a, from_func, to_func)
+ return _distribute(a, b, from_func, to_func, cache)
+ return _distribute(b, a, from_func, to_func, cache)
if isinstance(a, to_exp):
- return _distribute(b, a, from_func, to_func)
+ return _distribute(b, a, from_func, to_func, cache)
if isinstance(b, to_exp):
- return _distribute(a, b, from_func, to_func)
+ return _distribute(a, b, from_func, to_func, cache)
return expression
-def _distribute(a, b, from_func, to_func):
+def _distribute(a, b, from_func, to_func, cache):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
- exp.paren(from_func(c, b.left)),
- exp.paren(from_func(c, b.right)),
+ uniq_sort(flatten(from_func(c, b.left)), cache),
+ uniq_sort(flatten(from_func(c, b.right)), cache),
),
)
else:
- a = to_func(from_func(a, b.left), from_func(a, b.right))
-
- return _simplify(a)
-
+ a = to_func(
+ uniq_sort(flatten(from_func(a, b.left)), cache),
+ uniq_sort(flatten(from_func(a, b.right)), cache),
+ )
-def _simplify(node):
- node = uniq_sort(flatten(node))
- exp.replace_children(node, _simplify)
- return node
+ return a
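Taken together, `normalize` now walks connectors, bails out early via the distance estimate, and restores the original node if distribution blows past `max_distance`. A minimal usage sketch of both target forms (the CNF output matches the doctest above; the DNF output is the expected symmetric result):

    import sqlglot
    from sqlglot.optimizer.normalize import normalize

    # CNF (default): ORs are distributed below ANDs.
    print(normalize(sqlglot.parse_one("(x AND y) OR z")).sql())
    # -> (x OR z) AND (y OR z)

    # DNF: ANDs are distributed below ORs.
    print(normalize(sqlglot.parse_one("(x OR y) AND z"), dnf=True).sql())
    # -> (x AND z) OR (y AND z)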
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index dc5ce44..8589657 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -1,6 +1,5 @@
from sqlglot import exp
from sqlglot.helper import tsort
-from sqlglot.optimizer.simplify import simplify
def optimize_joins(expression):
@@ -29,7 +28,6 @@ def optimize_joins(expression):
for name, join in cross_joins:
for dep in references.get(name, []):
on = dep.args["on"]
- on = on.replace(simplify(on))
if isinstance(on, exp.Connector):
for predicate in on.flatten():
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index d9d04be..62eb11e 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -21,6 +21,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema
@@ -43,6 +44,7 @@ RULES = (
eliminate_ctes,
annotate_types,
canonicalize,
+ simplify,
)
@@ -78,7 +80,7 @@ def optimize(
Returns:
sqlglot.Expression: optimized expression
"""
- schema = ensure_schema(schema or sqlglot.schema)
+ schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)
for rule in rules:
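Since the dialect now reaches `ensure_schema`, a schema passed as a plain mapping can be interpreted with dialect-specific identifier rules. A minimal sketch (table `t` and column `a` are hypothetical):

    import sqlglot
    from sqlglot.optimizer.optimizer import optimize

    # The dialect is forwarded to ensure_schema, so the MappingSchema built
    # from this dict can normalize identifiers for the given dialect.
    optimized = optimize(
        "SELECT a FROM t WHERE a = 1 AND a = 1",
        schema={"t": {"a": "INT"}},
        dialect="spark",
    )
    print(optimized.sql("spark"))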
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 66b3170..5e40cf3 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -30,11 +30,12 @@ def qualify_columns(expression, schema):
resolver = Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
_pop_table_column_aliases(scope.derived_tables)
- _expand_using(scope, resolver)
+ using_column_tables = _expand_using(scope, resolver)
_qualify_columns(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
- _expand_stars(scope, resolver)
+ _expand_stars(scope, resolver, using_column_tables)
_qualify_outputs(scope)
+ _expand_alias_refs(scope, resolver)
_expand_group_by(scope, resolver)
_expand_order_by(scope)
@@ -69,11 +70,11 @@ def _pop_table_column_aliases(derived_tables):
def _expand_using(scope, resolver):
- joins = list(scope.expression.find_all(exp.Join))
+ joins = list(scope.find_all(exp.Join))
names = {join.this.alias for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
- # Mapping of automatically joined column names to source names
+    # Mapping of automatically joined column names to an ordered "set" of source names (a dict, since its keys preserve insertion order).
column_tables = {}
for join in joins:
@@ -112,11 +113,12 @@ def _expand_using(scope, resolver):
)
)
- tables = column_tables.setdefault(identifier, [])
+ # Set all values in the dict to None, because we only care about the key ordering
+ tables = column_tables.setdefault(identifier, {})
if table not in tables:
- tables.append(table)
+ tables[table] = None
if join_table not in tables:
- tables.append(join_table)
+ tables[join_table] = None
join.args.pop("using")
join.set("on", exp.and_(*conditions))
@@ -134,11 +136,11 @@ def _expand_using(scope, resolver):
scope.replace(column, replacement)
+ return column_tables
-def _expand_group_by(scope, resolver):
- group = scope.expression.args.get("group")
- if not group:
- return
+
+def _expand_alias_refs(scope, resolver):
+ selects = {}
# Replace references to select aliases
def transform(node, *_):
@@ -150,9 +152,11 @@ def _expand_group_by(scope, resolver):
node.set("table", table)
return node
- selects = {s.alias_or_name: s for s in scope.selects}
-
+ if not selects:
+ for s in scope.selects:
+ selects[s.alias_or_name] = s
select = selects.get(node.name)
+
if select:
scope.clear_cache()
if isinstance(select, exp.Alias):
@@ -161,7 +165,21 @@ def _expand_group_by(scope, resolver):
return node
- group.transform(transform, copy=False)
+ for select in scope.expression.selects:
+ select.transform(transform, copy=False)
+
+ for modifier in ("where", "group"):
+ part = scope.expression.args.get(modifier)
+
+ if part:
+ part.transform(transform, copy=False)
+
+
+def _expand_group_by(scope, resolver):
+ group = scope.expression.args.get("group")
+ if not group:
+ return
+
group.set("expressions", _expand_positional_references(scope, group.expressions))
scope.expression.set("group", group)
@@ -231,18 +249,24 @@ def _qualify_columns(scope, resolver):
column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
columns_missing_from_scope = []
+
# Determine whether each reference in the order by clause is to a column or an alias.
- for ordered in scope.find_all(exp.Ordered):
- for column in ordered.find_all(exp.Column):
- if (
- not column.table
- and column.parent is not ordered
- and column.name in resolver.all_columns
- ):
- columns_missing_from_scope.append(column)
+ order = scope.expression.args.get("order")
+
+ if order:
+ for ordered in order.expressions:
+ for column in ordered.find_all(exp.Column):
+ if (
+ not column.table
+ and column.parent is not ordered
+ and column.name in resolver.all_columns
+ ):
+ columns_missing_from_scope.append(column)
# Determine whether each reference in the having clause is to a column or an alias.
- for having in scope.find_all(exp.Having):
+ having = scope.expression.args.get("having")
+
+ if having:
for column in having.find_all(exp.Column):
if (
not column.table
@@ -258,12 +282,13 @@ def _qualify_columns(scope, resolver):
column.set("table", column_table)
-def _expand_stars(scope, resolver):
+def _expand_stars(scope, resolver, using_column_tables):
"""Expand stars to lists of column selections"""
new_selections = []
except_columns = {}
replace_columns = {}
+ coalesced_columns = set()
for expression in scope.selects:
if isinstance(expression, exp.Star):
@@ -286,7 +311,20 @@ def _expand_stars(scope, resolver):
if columns and "*" not in columns:
table_id = id(table)
for name in columns:
- if name not in except_columns.get(table_id, set()):
+ if name in using_column_tables and table in using_column_tables[name]:
+ if name in coalesced_columns:
+ continue
+
+ coalesced_columns.add(name)
+ tables = using_column_tables[name]
+ coalesce = [exp.column(name, table=table) for table in tables]
+
+ new_selections.append(
+ exp.alias_(
+ exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name
+ )
+ )
+ elif name not in except_columns.get(table_id, set()):
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table)
new_selections.append(alias(column, alias_) if alias_ != name else column)
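With `using_column_tables` threaded from `_expand_using` into `_expand_stars`, a star over a `USING` join now emits one `COALESCE` per shared column instead of selecting it from every side. A sketch of the observable effect (schema and table names are illustrative, and the exact output shape may differ slightly):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns

    sql = "SELECT * FROM a JOIN b USING (id)"
    schema = {"a": {"id": "INT", "x": "INT"}, "b": {"id": "INT", "y": "INT"}}

    print(qualify_columns(sqlglot.parse_one(sql), schema).sql())
    # Roughly: SELECT COALESCE(a.id, b.id) AS id, a.x AS x, b.y AS y
    #          FROM a JOIN b ON a.id = b.id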
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 9c0768c..b582eb0 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -160,7 +160,7 @@ class Scope:
Yields:
exp.Expression: nodes
"""
- for expression, _, _ in self.walk(bfs=bfs):
+ for expression, *_ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
yield expression
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index f80484d..1ed3ca2 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -5,11 +5,10 @@ from collections import deque
from decimal import Decimal
from sqlglot import exp
-from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
from sqlglot.helper import first, while_changing
-GENERATOR = Generator(normalize=True, identify=True)
+GENERATOR = Generator(normalize=True, identify="safe")
def simplify(expression):
@@ -28,18 +27,20 @@ def simplify(expression):
sqlglot.Expression: simplified expression
"""
+ cache = {}
+
def _simplify(expression, root=True):
node = expression
node = rewrite_between(node)
- node = uniq_sort(node)
- node = absorb_and_eliminate(node)
+ node = uniq_sort(node, cache, root)
+ node = absorb_and_eliminate(node, root)
exp.replace_children(node, lambda e: _simplify(e, False))
node = simplify_not(node)
node = flatten(node)
- node = simplify_connectors(node)
- node = remove_compliments(node)
+ node = simplify_connectors(node, root)
+ node = remove_compliments(node, root)
node.parent = expression.parent
- node = simplify_literals(node)
+ node = simplify_literals(node, root)
node = simplify_parens(node)
if root:
expression.replace(node)
@@ -70,7 +71,7 @@ def simplify_not(expression):
NOT (x AND y) -> NOT x OR NOT y
"""
if isinstance(expression, exp.Not):
- if isinstance(expression.this, exp.Null):
+ if is_null(expression.this):
return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
@@ -78,11 +79,11 @@ def simplify_not(expression):
return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
- if isinstance(condition, exp.Null):
+ if is_null(condition):
return exp.null()
if always_true(expression.this):
return exp.false()
- if expression.this == FALSE:
+ if is_false(expression.this):
return exp.true()
if isinstance(expression.this, exp.Not):
# double negation
@@ -104,42 +105,42 @@ def flatten(expression):
return expression
-def simplify_connectors(expression):
+def simplify_connectors(expression, root=True):
def _simplify_connectors(expression, left, right):
- if isinstance(expression, exp.Connector):
- if left == right:
+ if left == right:
+ return left
+ if isinstance(expression, exp.And):
+ if is_false(left) or is_false(right):
+ return exp.false()
+ if is_null(left) or is_null(right):
+ return exp.null()
+ if always_true(left) and always_true(right):
+ return exp.true()
+ if always_true(left):
+ return right
+ if always_true(right):
return left
- if isinstance(expression, exp.And):
- if FALSE in (left, right):
- return exp.false()
- if NULL in (left, right):
- return exp.null()
- if always_true(left) and always_true(right):
- return exp.true()
- if always_true(left):
- return right
- if always_true(right):
- return left
- return _simplify_comparison(expression, left, right)
- elif isinstance(expression, exp.Or):
- if always_true(left) or always_true(right):
- return exp.true()
- if left == FALSE and right == FALSE:
- return exp.false()
- if (
- (left == NULL and right == NULL)
- or (left == NULL and right == FALSE)
- or (left == FALSE and right == NULL)
- ):
- return exp.null()
- if left == FALSE:
- return right
- if right == FALSE:
- return left
- return _simplify_comparison(expression, left, right, or_=True)
- return None
+ return _simplify_comparison(expression, left, right)
+ elif isinstance(expression, exp.Or):
+ if always_true(left) or always_true(right):
+ return exp.true()
+ if is_false(left) and is_false(right):
+ return exp.false()
+ if (
+ (is_null(left) and is_null(right))
+ or (is_null(left) and is_false(right))
+ or (is_false(left) and is_null(right))
+ ):
+ return exp.null()
+ if is_false(left):
+ return right
+ if is_false(right):
+ return left
+ return _simplify_comparison(expression, left, right, or_=True)
- return _flat_simplify(expression, _simplify_connectors)
+ if isinstance(expression, exp.Connector):
+ return _flat_simplify(expression, _simplify_connectors, root)
+ return expression
LT_LTE = (exp.LT, exp.LTE)
@@ -220,14 +221,14 @@ def _simplify_comparison(expression, left, right, or_=False):
return None
-def remove_compliments(expression):
+def remove_compliments(expression, root=True):
"""
     Removing complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
@@ -236,23 +237,23 @@ def remove_compliments(expression):
return expression
-def uniq_sort(expression):
+def uniq_sort(expression, cache=None, root=True):
"""
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
flattened = tuple(expression.flatten())
- deduped = {GENERATOR.generate(e): e for e in flattened}
+ deduped = {GENERATOR.generate(e, cache): e for e in flattened}
arr = tuple(deduped.items())
# check if the operands are already sorted, if not sort them
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
- expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
+ expression = result_func(*(e for _, e in sorted(arr)))
break
else:
# we didn't have to sort but maybe we need to dedup
@@ -262,7 +263,7 @@ def uniq_sort(expression):
return expression
-def absorb_and_eliminate(expression):
+def absorb_and_eliminate(expression, root=True):
"""
absorption:
A AND (A OR B) -> A
@@ -273,7 +274,7 @@ def absorb_and_eliminate(expression):
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
"""
- if isinstance(expression, exp.Connector):
+ if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
kind = exp.Or if isinstance(expression, exp.And) else exp.And
for a, b in itertools.permutations(expression.flatten(), 2):
@@ -302,9 +303,9 @@ def absorb_and_eliminate(expression):
return expression
-def simplify_literals(expression):
- if isinstance(expression, exp.Binary):
- return _flat_simplify(expression, _simplify_binary)
+def simplify_literals(expression, root=True):
+ if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
+ return _flat_simplify(expression, _simplify_binary, root)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@@ -325,14 +326,14 @@ def _simplify_binary(expression, a, b):
c = b
not_ = False
- if c == NULL:
+ if is_null(c):
if isinstance(a, exp.Literal):
return exp.true() if not_ else exp.false()
- if a == NULL:
+ if is_null(a):
return exp.false() if not_ else exp.true()
elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
return None
- elif NULL in (a, b):
+ elif is_null(a) or is_null(b):
return exp.null()
if a.is_number and b.is_number:
@@ -355,7 +356,7 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
elif a.is_string and b.is_string:
- boolean = eval_boolean(expression, a, b)
+ boolean = eval_boolean(expression, a.this, b.this)
if boolean:
return boolean
@@ -381,7 +382,7 @@ def simplify_parens(expression):
and not isinstance(expression.this, exp.Select)
and (
not isinstance(expression.parent, (exp.Condition, exp.Binary))
- or isinstance(expression.this, (exp.Is, exp.Like))
+ or isinstance(expression.this, exp.Predicate)
or not isinstance(expression.this, exp.Binary)
)
):
@@ -400,13 +401,23 @@ def remove_where_true(expression):
def always_true(expression):
- return expression == TRUE or isinstance(expression, exp.Literal)
+ return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
+ expression, exp.Literal
+ )
def is_complement(a, b):
return isinstance(b, exp.Not) and b.this == a
+def is_false(a: exp.Expression) -> bool:
+ return type(a) is exp.Boolean and not a.this
+
+
+def is_null(a: exp.Expression) -> bool:
+ return type(a) is exp.Null
+
+
def eval_boolean(expression, a, b):
if isinstance(expression, (exp.EQ, exp.Is)):
return boolean_literal(a == b)
@@ -466,24 +477,27 @@ def boolean_literal(condition):
return exp.true() if condition else exp.false()
-def _flat_simplify(expression, simplifier):
- operands = []
- queue = deque(expression.flatten(unnest=False))
- size = len(queue)
+def _flat_simplify(expression, simplifier, root=True):
+ if root or not expression.same_parent:
+ operands = []
+ queue = deque(expression.flatten(unnest=False))
+ size = len(queue)
- while queue:
- a = queue.popleft()
+ while queue:
+ a = queue.popleft()
- for b in queue:
- result = simplifier(expression, a, b)
+ for b in queue:
+ result = simplifier(expression, a, b)
- if result:
- queue.remove(b)
- queue.append(result)
- break
- else:
- operands.append(a)
+ if result:
+ queue.remove(b)
+ queue.append(result)
+ break
+ else:
+ operands.append(a)
- if len(operands) < size:
- return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+ if len(operands) < size:
+ return functools.reduce(
+ lambda a, b: expression.__class__(this=a, expression=b), operands
+ )
return expression
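The `root` gating and the generator `cache` only skip redundant re-simplification of subtrees whose parent was already handled; the observable behavior is unchanged, e.g. in this sketch:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    # Complements still fold (via remove_compliments)...
    print(simplify(sqlglot.parse_one("x = 1 AND NOT (x = 1)")).sql())
    # -> FALSE

    # ...and literal connectors still collapse (via simplify_connectors).
    print(simplify(sqlglot.parse_one("TRUE AND x > 0")).sql())
    # -> x > 0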
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index a36251e..8269525 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -19,7 +19,7 @@ from sqlglot.trie import in_trie, new_trie
logger = logging.getLogger("sqlglot")
-def parse_var_map(args):
+def parse_var_map(args: t.Sequence) -> exp.Expression:
keys = []
values = []
for i in range(0, len(args), 2):
@@ -31,6 +31,11 @@ def parse_var_map(args):
)
+def parse_like(args):
+ like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0))
+ return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like
+
+
def binary_range_parser(
expr_type: t.Type[exp.Expression],
) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]:
@@ -77,6 +82,9 @@ class Parser(metaclass=_Parser):
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
+ "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
+ "IFNULL": exp.Coalesce.from_arg_list,
+ "LIKE": parse_like,
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
@@ -90,7 +98,6 @@ class Parser(metaclass=_Parser):
length=exp.Literal.number(10),
),
"VAR_MAP": parse_var_map,
- "IFNULL": exp.Coalesce.from_arg_list,
}
NO_PAREN_FUNCTIONS = {
@@ -211,6 +218,7 @@ class Parser(metaclass=_Parser):
TokenType.FILTER,
TokenType.FOLLOWING,
TokenType.FORMAT,
+ TokenType.FULL,
TokenType.IF,
TokenType.ISNULL,
TokenType.INTERVAL,
@@ -226,8 +234,10 @@ class Parser(metaclass=_Parser):
TokenType.ONLY,
TokenType.OPTIONS,
TokenType.ORDINALITY,
+ TokenType.PARTITION,
TokenType.PERCENT,
TokenType.PIVOT,
+ TokenType.PRAGMA,
TokenType.PRECEDING,
TokenType.RANGE,
TokenType.REFERENCES,
@@ -257,6 +267,7 @@ class Parser(metaclass=_Parser):
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
+ TokenType.FULL,
TokenType.LEFT,
TokenType.NATURAL,
TokenType.OFFSET,
@@ -277,6 +288,7 @@ class Parser(metaclass=_Parser):
TokenType.FILTER,
TokenType.FIRST,
TokenType.FORMAT,
+ TokenType.GLOB,
TokenType.IDENTIFIER,
TokenType.INDEX,
TokenType.ISNULL,
@@ -461,6 +473,7 @@ class Parser(metaclass=_Parser):
TokenType.INSERT: lambda self: self._parse_insert(),
TokenType.LOAD_DATA: lambda self: self._parse_load_data(),
TokenType.MERGE: lambda self: self._parse_merge(),
+ TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
TokenType.SET: lambda self: self._parse_set(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
@@ -662,6 +675,8 @@ class Parser(metaclass=_Parser):
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"EXTRACT": lambda self: self._parse_extract(),
+ "JSON_OBJECT": lambda self: self._parse_json_object(),
+ "LOG": lambda self: self._parse_logarithm(),
"POSITION": lambda self: self._parse_position(),
"STRING_AGG": lambda self: self._parse_string_agg(),
"SUBSTRING": lambda self: self._parse_substring(),
@@ -719,6 +734,9 @@ class Parser(metaclass=_Parser):
CONVERT_TYPE_FIRST = False
+ LOG_BASE_FIRST = True
+ LOG_DEFAULTS_TO_LN = False
+
__slots__ = (
"error_level",
"error_message_context",
@@ -1032,6 +1050,7 @@ class Parser(metaclass=_Parser):
temporary=temporary,
materialized=materialized,
cascade=self._match(TokenType.CASCADE),
+ constraints=self._match_text_seq("CONSTRAINTS"),
)
def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
@@ -1221,7 +1240,7 @@ class Parser(metaclass=_Parser):
if not identified_property:
break
- for p in ensure_collection(identified_property):
+ for p in ensure_list(identified_property):
properties.append(p)
if properties:
@@ -1704,6 +1723,11 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.SELECT):
comments = self._prev_comments
+ kind = (
+ self._match(TokenType.ALIAS)
+ and self._match_texts(("STRUCT", "VALUE"))
+ and self._prev.text
+ )
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
distinct = self._match(TokenType.DISTINCT)
@@ -1722,6 +1746,7 @@ class Parser(metaclass=_Parser):
this = self.expression(
exp.Select,
+ kind=kind,
hint=hint,
distinct=distinct,
expressions=expressions,
@@ -2785,7 +2810,6 @@ class Parser(metaclass=_Parser):
this = seq_get(expressions, 0)
self._parse_query_modifiers(this)
- self._match_r_paren()
if isinstance(this, exp.Subqueryable):
this = self._parse_set_operations(
@@ -2794,7 +2818,9 @@ class Parser(metaclass=_Parser):
elif len(expressions) > 1:
this = self.expression(exp.Tuple, expressions=expressions)
else:
- this = self.expression(exp.Paren, this=this)
+ this = self.expression(exp.Paren, this=self._parse_set_operations(this))
+
+ self._match_r_paren()
if this and comments:
this.comments = comments
@@ -3318,6 +3344,60 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ def _parse_json_key_value(self) -> t.Optional[exp.Expression]:
+ self._match_text_seq("KEY")
+ key = self._parse_field()
+ self._match(TokenType.COLON)
+ self._match_text_seq("VALUE")
+ value = self._parse_field()
+ if not key and not value:
+ return None
+ return self.expression(exp.JSONKeyValue, this=key, expression=value)
+
+ def _parse_json_object(self) -> exp.Expression:
+ expressions = self._parse_csv(self._parse_json_key_value)
+
+ null_handling = None
+ if self._match_text_seq("NULL", "ON", "NULL"):
+ null_handling = "NULL ON NULL"
+ elif self._match_text_seq("ABSENT", "ON", "NULL"):
+ null_handling = "ABSENT ON NULL"
+
+ unique_keys = None
+ if self._match_text_seq("WITH", "UNIQUE"):
+ unique_keys = True
+ elif self._match_text_seq("WITHOUT", "UNIQUE"):
+ unique_keys = False
+
+ self._match_text_seq("KEYS")
+
+ return_type = self._match_text_seq("RETURNING") and self._parse_type()
+ format_json = self._match_text_seq("FORMAT", "JSON")
+ encoding = self._match_text_seq("ENCODING") and self._parse_var()
+
+ return self.expression(
+ exp.JSONObject,
+ expressions=expressions,
+ null_handling=null_handling,
+ unique_keys=unique_keys,
+ return_type=return_type,
+ format_json=format_json,
+ encoding=encoding,
+ )
+
+ def _parse_logarithm(self) -> exp.Expression:
+ # Default argument order is base, expression
+ args = self._parse_csv(self._parse_range)
+
+ if len(args) > 1:
+ if not self.LOG_BASE_FIRST:
+ args.reverse()
+ return exp.Log.from_arg_list(args)
+
+ return self.expression(
+ exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0)
+ )
+
def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise)
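A quick sketch of the two new function parsers (the SQL strings are illustrative):

    import sqlglot
    from sqlglot import exp

    # LOG: base comes first by default (LOG_BASE_FIRST = True); a single
    # argument stays exp.Log unless a dialect sets LOG_DEFAULTS_TO_LN.
    log = sqlglot.parse_one("SELECT LOG(10, x)").find(exp.Log)
    print(log.this.sql(), log.expression.sql())  # 10 x

    # JSON_OBJECT: KEY/VALUE pairs plus the optional trailing clauses.
    obj = sqlglot.parse_one("SELECT JSON_OBJECT('a' VALUE 1 ABSENT ON NULL)")
    print(obj.find(exp.JSONObject).args["null_handling"])  # ABSENT ON NULL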
@@ -3654,7 +3734,7 @@ class Parser(metaclass=_Parser):
return parse_result
def _parse_select_or_expression(self) -> t.Optional[exp.Expression]:
- return self._parse_select() or self._parse_expression()
+ return self._parse_select() or self._parse_set_operations(self._parse_expression())
def _parse_ddl_select(self) -> t.Optional[exp.Expression]:
return self._parse_set_operations(
@@ -3741,6 +3821,8 @@ class Parser(metaclass=_Parser):
expression = self._parse_foreign_key()
elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY):
expression = self._parse_primary_key()
+ else:
+ expression = None
return self.expression(exp.AddConstraint, this=this, expression=expression)
@@ -3799,12 +3881,15 @@ class Parser(metaclass=_Parser):
parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None
if parser:
- return self.expression(
- exp.AlterTable,
- this=this,
- exists=exists,
- actions=ensure_list(parser(self)),
- )
+ actions = ensure_list(parser(self))
+
+ if not self._curr:
+ return self.expression(
+ exp.AlterTable,
+ this=this,
+ exists=exists,
+ actions=actions,
+ )
return self._parse_as_command(start)
def _parse_merge(self) -> exp.Expression:
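The new `kind` capture makes BigQuery-style `SELECT AS STRUCT` / `SELECT AS VALUE` available as `Select(kind=...)`; a minimal sketch:

    import sqlglot

    expr = sqlglot.parse_one("SELECT AS STRUCT 1 AS a", read="bigquery")
    print(expr.args["kind"])  # STRUCT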
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index 40df39f..5fd96ef 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -175,7 +175,7 @@ class Step:
}
for projection in projections:
for i, e in aggregate.group.items():
- for child, _, _ in projection.walk():
+ for child, *_ in projection.walk():
if child == e:
child.replace(exp.column(i, step.name))
aggregate.add_dependency(step)
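`walk` yields `(node, parent, key)` triples; unpacking with `node, *_` keeps call sites working even if the tuple ever grows. Typical usage:

    import sqlglot
    from sqlglot import exp

    for node, *_ in sqlglot.parse_one("SELECT a + 1 FROM t").walk():
        if isinstance(node, exp.Column):
            print(node.sql())  # a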
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index f5d9f2b..8e39c7f 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -306,11 +306,11 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return self._type_mapping_cache[schema_type]
-def ensure_schema(schema: t.Any) -> Schema:
+def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema:
if isinstance(schema, Schema):
return schema
- return MappingSchema(schema)
+ return MappingSchema(schema, dialect=dialect)
def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
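With the new parameter, callers can build a dialect-aware schema directly; a minimal sketch (`t`/`x` are hypothetical):

    from sqlglot import exp
    from sqlglot.schema import ensure_schema

    schema = ensure_schema({"t": {"x": "INT"}}, dialect="snowflake")
    print(schema.column_names(exp.to_table("t")))  # ['x']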
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index eb3c08f..e5b44e7 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -252,6 +252,7 @@ class TokenType(AutoName):
PERCENT = auto()
PIVOT = auto()
PLACEHOLDER = auto()
+ PRAGMA = auto()
PRECEDING = auto()
PRIMARY_KEY = auto()
PROCEDURE = auto()
@@ -346,7 +347,8 @@ class Token:
self.token_type = token_type
self.text = text
self.line = line
- self.col = max(col - len(text), 1)
+ self.col = col - len(text)
+ self.col = self.col if self.col > 1 else 1
self.comments = comments
def __repr__(self) -> str:
@@ -586,6 +588,7 @@ class Tokenizer(metaclass=_Tokenizer):
"PARTITIONED_BY": TokenType.PARTITION_BY,
"PERCENT": TokenType.PERCENT,
"PIVOT": TokenType.PIVOT,
+ "PRAGMA": TokenType.PRAGMA,
"PRECEDING": TokenType.PRECEDING,
"PRIMARY KEY": TokenType.PRIMARY_KEY,
"PROCEDURE": TokenType.PROCEDURE,
@@ -654,6 +657,7 @@ class Tokenizer(metaclass=_Tokenizer):
"LONG": TokenType.BIGINT,
"BIGINT": TokenType.BIGINT,
"INT8": TokenType.BIGINT,
+ "DEC": TokenType.DECIMAL,
"DECIMAL": TokenType.DECIMAL,
"MAP": TokenType.MAP,
"NULLABLE": TokenType.NULLABLE,
@@ -714,7 +718,7 @@ class Tokenizer(metaclass=_Tokenizer):
"VACUUM": TokenType.COMMAND,
}
- WHITE_SPACE: t.Dict[str, TokenType] = {
+ WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = {
" ": TokenType.SPACE,
"\t": TokenType.SPACE,
"\n": TokenType.BREAK,
@@ -813,11 +817,8 @@ class Tokenizer(metaclass=_Tokenizer):
return self.sql[start:end]
return ""
- def _line_break(self, char: t.Optional[str]) -> bool:
- return self.WHITE_SPACE.get(char) == TokenType.BREAK # type: ignore
-
def _advance(self, i: int = 1) -> None:
- if self._line_break(self._char):
+ if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
self._set_new_line()
self._col += i
@@ -939,7 +940,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore
self._advance(comment_end_size - 1)
else:
- while not self._end and not self._line_break(self._peek):
+            while not self._end and self.WHITE_SPACE.get(self._peek) is not TokenType.BREAK:
self._advance()
self._comments.append(self._text[comment_start_size:]) # type: ignore
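End to end, the new token plus the statement parser added above turn `PRAGMA` into a first-class expression (`exp.Pragma`) instead of an identifier; a tokenizer-level sketch:

    from sqlglot.tokens import Tokenizer, TokenType

    tokens = Tokenizer().tokenize("PRAGMA foreign_keys = ON")
    print(tokens[0].token_type)  # TokenType.PRAGMA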