author    Daniel Baumann <daniel.baumann@progress-linux.org>  2023-08-06 07:48:11 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org>  2023-08-06 07:48:11 +0000
commit    379c6d1f52e1d311867c4f789dc389da1d9af898 (patch)
tree      c9ca62eb7b8b7e861cc67248850db220ad0881c9 /sqlglot
parent    Releasing debian version 17.7.0-1. (diff)
Merging upstream version 17.9.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py                           7
-rw-r--r--  sqlglot/dataframe/sql/functions.py            2
-rw-r--r--  sqlglot/dialects/bigquery.py                  4
-rw-r--r--  sqlglot/dialects/clickhouse.py                2
-rw-r--r--  sqlglot/dialects/dialect.py                  28
-rw-r--r--  sqlglot/dialects/hive.py                      9
-rw-r--r--  sqlglot/dialects/mysql.py                    79
-rw-r--r--  sqlglot/dialects/oracle.py                    3
-rw-r--r--  sqlglot/dialects/postgres.py                  6
-rw-r--r--  sqlglot/dialects/presto.py                    2
-rw-r--r--  sqlglot/dialects/spark.py                     7
-rw-r--r--  sqlglot/dialects/spark2.py                   10
-rw-r--r--  sqlglot/dialects/teradata.py                  2
-rw-r--r--  sqlglot/dialects/tsql.py                    101
-rw-r--r--  sqlglot/expressions.py                      100
-rw-r--r--  sqlglot/generator.py                         72
-rw-r--r--  sqlglot/optimizer/eliminate_subqueries.py     6
-rw-r--r--  sqlglot/optimizer/normalize_identifiers.py   18
-rw-r--r--  sqlglot/optimizer/qualify_columns.py         18
-rw-r--r--  sqlglot/optimizer/qualify_tables.py           4
-rw-r--r--  sqlglot/optimizer/scope.py                   72
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py       24
-rw-r--r--  sqlglot/parser.py                            63
-rw-r--r--  sqlglot/schema.py                            79
-rw-r--r--  sqlglot/tokens.py                             2
-rw-r--r--  sqlglot/transforms.py                         3
26 files changed, 599 insertions, 124 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 42801ac..be10f3d 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -67,19 +67,22 @@ schema = MappingSchema()
"""The default schema used by SQLGlot (e.g. in the optimizer)."""
-def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]:
+def parse(
+ sql: str, read: DialectType = None, dialect: DialectType = None, **opts
+) -> t.List[t.Optional[Expression]]:
"""
Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.
Args:
sql: the SQL code string to parse.
read: the SQL dialect to apply during parsing (e.g. "spark", "hive", "presto", "mysql").
+ dialect: the SQL dialect (alias for read).
**opts: other `sqlglot.parser.Parser` options.
Returns:
The resulting syntax tree collection.
"""
- dialect = Dialect.get_or_raise(read)()
+ dialect = Dialect.get_or_raise(read or dialect)()
return dialect.parse(sql, **opts)
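The new `dialect` keyword is a plain alias for `read`; a minimal usage sketch (the query text is illustrative):

    import sqlglot

    # Equivalent calls: "dialect" is an alias for "read"; "read" wins if both are given.
    sqlglot.parse("SELECT 1; SELECT 2", read="mysql")
    sqlglot.parse("SELECT 1; SELECT 2", dialect="mysql")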
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 1549a07..4002cfe 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -386,7 +386,7 @@ def input_file_name() -> Column:
def isnan(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ISNAN")
+ return Column.invoke_expression_over_column(col, expression.IsNan)
def isnull(col: ColumnOrName) -> Column:
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index fd9965c..df9065f 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -211,6 +211,10 @@ class BigQuery(Dialect):
"TZH": "%z",
}
+ # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
+ # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
+ PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"}
+
@classmethod
def normalize_identifier(cls, expression: E) -> E:
# In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
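With `PSEUDOCOLUMNS` declared on the dialect, the optimizer's star expansion (see the `qualify_columns` hunk further down) can skip these columns generically. A sketch, assuming a hypothetical table `tbl`:

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_columns import qualify_columns
    from sqlglot.schema import MappingSchema

    schema = MappingSchema(
        {"tbl": {"col": "int", "_PARTITIONTIME": "timestamp"}}, dialect="bigquery"
    )
    # The pseudo-column is dropped when the star is expanded, roughly:
    # 'SELECT tbl.col AS col FROM tbl'
    print(qualify_columns(parse_one("SELECT * FROM tbl"), schema).sql())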
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 8f60df2..ce1a486 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -380,7 +380,7 @@ class ClickHouse(Dialect):
]
def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
- params = self.expressions(expression, "params", flat=True)
+ params = self.expressions(expression, key="params", flat=True)
return self.func(expression.name, *expression.expressions) + f"({params})"
def placeholder_sql(self, expression: exp.Placeholder) -> str:
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 8c84639..05e81ce 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -5,6 +5,7 @@ from enum import Enum
from sqlglot import exp
from sqlglot._typing import E
+from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
@@ -168,6 +169,10 @@ class Dialect(metaclass=_Dialect):
# special syntax cast(x as date format 'yyyy') defaults to time_mapping
FORMAT_MAPPING: t.Dict[str, str] = {}
+ # Columns that are auto-generated by the engine corresponding to this dialect
+ # Such columns may be excluded from SELECT * queries, for example
+ PSEUDOCOLUMNS: t.Set[str] = set()
+
# Autofilled
tokenizer_class = Tokenizer
parser_class = Parser
@@ -497,6 +502,10 @@ def parse_date_delta_with_interval(
return None
interval = args[1]
+
+ if not isinstance(interval, exp.Interval):
+ raise ParseError(f"INTERVAL expression expected but got '{interval}'")
+
expression = interval.this
if expression and expression.is_string:
expression = exp.Literal.number(expression.this)
@@ -555,11 +564,11 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
- return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
+ return self.sql(exp.cast(expression.this, "timestamp"))
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
- return f"CAST({self.sql(expression, 'this')} AS DATE)"
+ return self.sql(exp.cast(expression.this, "date"))
def min_or_least(self: Generator, expression: exp.Min) -> str:
@@ -608,8 +617,9 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
_dialect = Dialect.get_or_raise(dialect)
time_format = self.format_time(expression)
if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
- return f"CAST({str_to_time_sql(self, expression)} AS DATE)"
- return f"CAST({self.sql(expression, 'this')} AS DATE)"
+ return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))
+
+ return self.sql(exp.cast(self.sql(expression, "this"), "date"))
return _ts_or_ds_to_date_sql
@@ -664,5 +674,15 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
return names
+def simplify_literal(expression: E, copy: bool = True) -> E:
+ if not isinstance(expression.expression, exp.Literal):
+ from sqlglot.optimizer.simplify import simplify
+
+ expression = exp.maybe_copy(expression, copy)
+ simplify(expression.expression)
+
+ return expression
+
+
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
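`simplify_literal` gives dialects a cheap way to fold a constant sub-expression into a literal without running the whole optimizer; a minimal sketch with a hand-built node (the names are illustrative):

    from sqlglot import exp, parse_one
    from sqlglot.dialects.dialect import simplify_literal

    node = exp.DateAdd(this=exp.column("d"), expression=parse_one("1 + 1"), unit=exp.var("DAY"))
    # The constant arithmetic in the "expression" arg is folded to the literal 2.
    print(simplify_literal(node).expression)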
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index e131434..4e84085 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -359,14 +359,16 @@ class Hive(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
INDEX_ON = "ON TABLE"
+ EXTRACT_ALLOWS_QUOTES = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
- exp.DataType.Type.TEXT: "STRING",
+ exp.DataType.Type.BIT: "BOOLEAN",
exp.DataType.Type.DATETIME: "TIMESTAMP",
- exp.DataType.Type.VARBINARY: "BINARY",
+ exp.DataType.Type.TEXT: "STRING",
+ exp.DataType.Type.TIME: "TIMESTAMP",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
- exp.DataType.Type.BIT: "BOOLEAN",
+ exp.DataType.Type.VARBINARY: "BINARY",
}
TRANSFORMS = {
@@ -396,6 +398,7 @@ class Hive(Dialect):
exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
+ exp.IsNan: rename_func("ISNAN"),
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.JSONFormat: _json_format_sql,
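With `exp.IsNan` mapped (the expression itself is added in expressions.py below), NaN checks now transpile into Hive's spelling; a quick sketch:

    import sqlglot

    # IS_NAN and ISNAN both parse into exp.IsNan; Hive renders it as ISNAN.
    print(sqlglot.transpile("SELECT IS_NAN(x)", write="hive")[0])
    # Expected along the lines of: SELECT ISNAN(x)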
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 5d65f77..a54f076 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
parse_date_delta_with_interval,
rename_func,
+ simplify_literal,
strposition_to_locate_sql,
)
from sqlglot.helper import seq_get
@@ -303,6 +304,22 @@ class MySQL(Dialect):
"NAMES": lambda self: self._parse_set_item_names(),
}
+ CONSTRAINT_PARSERS = {
+ **parser.Parser.CONSTRAINT_PARSERS,
+ "FULLTEXT": lambda self: self._parse_index_constraint(kind="FULLTEXT"),
+ "INDEX": lambda self: self._parse_index_constraint(),
+ "KEY": lambda self: self._parse_index_constraint(),
+ "SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"),
+ }
+
+ SCHEMA_UNNAMED_CONSTRAINTS = {
+ *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
+ "FULLTEXT",
+ "INDEX",
+ "KEY",
+ "SPATIAL",
+ }
+
PROFILE_TYPES = {
"ALL",
"BLOCK IO",
@@ -327,6 +344,57 @@ class MySQL(Dialect):
LOG_DEFAULTS_TO_LN = True
+ def _parse_index_constraint(
+ self, kind: t.Optional[str] = None
+ ) -> exp.IndexColumnConstraint:
+ if kind:
+ self._match_texts({"INDEX", "KEY"})
+
+ this = self._parse_id_var(any_token=False)
+ type_ = self._match(TokenType.USING) and self._advance_any() and self._prev.text
+ schema = self._parse_schema()
+
+ options = []
+ while True:
+ if self._match_text_seq("KEY_BLOCK_SIZE"):
+ self._match(TokenType.EQ)
+ opt = exp.IndexConstraintOption(key_block_size=self._parse_number())
+ elif self._match(TokenType.USING):
+ opt = exp.IndexConstraintOption(using=self._advance_any() and self._prev.text)
+ elif self._match_text_seq("WITH", "PARSER"):
+ opt = exp.IndexConstraintOption(parser=self._parse_var(any_token=True))
+ elif self._match(TokenType.COMMENT):
+ opt = exp.IndexConstraintOption(comment=self._parse_string())
+ elif self._match_text_seq("VISIBLE"):
+ opt = exp.IndexConstraintOption(visible=True)
+ elif self._match_text_seq("INVISIBLE"):
+ opt = exp.IndexConstraintOption(visible=False)
+ elif self._match_text_seq("ENGINE_ATTRIBUTE"):
+ self._match(TokenType.EQ)
+ opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
+ elif self._match_text_seq("ENGINE_ATTRIBUTE"):
+ self._match(TokenType.EQ)
+ opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
+ elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"):
+ self._match(TokenType.EQ)
+ opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string())
+ else:
+ opt = None
+
+ if not opt:
+ break
+
+ options.append(opt)
+
+ return self.expression(
+ exp.IndexColumnConstraint,
+ this=this,
+ schema=schema,
+ kind=kind,
+ type=type_,
+ options=options,
+ )
+
def _parse_show_mysql(
self,
this: str,
@@ -454,6 +522,7 @@ class MySQL(Dialect):
exp.StrToTime: _str_to_date_sql,
exp.TableSample: no_tablesample_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime")),
exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
@@ -485,6 +554,16 @@ class MySQL(Dialect):
exp.DataType.Type.VARCHAR: "CHAR",
}
+ def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
+ # MySQL requires simple literal values for its LIMIT clause.
+ expression = simplify_literal(expression)
+ return super().limit_sql(expression, top=top)
+
+ def offset_sql(self, expression: exp.Offset) -> str:
+ # MySQL requires simple literal values for its OFFSET clause.
+ expression = simplify_literal(expression)
+ return super().offset_sql(expression)
+
def xor_sql(self, expression: exp.Xor) -> str:
if expression.expressions:
return self.expressions(expression, sep=" XOR ")
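Together with the new `exp.IndexColumnConstraint` and its generator method (further down in this patch), the constraint parsers above let MySQL inline index definitions round-trip; a sketch with illustrative DDL:

    import sqlglot

    sql = "CREATE TABLE t (a TEXT, FULLTEXT INDEX idx (a) WITH PARSER ngram)"
    # FULLTEXT/SPATIAL/INDEX/KEY clauses now parse into exp.IndexColumnConstraint.
    print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])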
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 69da133..1f63e9f 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -30,6 +30,9 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
+ # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
+ RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
+
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
TIME_MAPPING = {
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index d11cbd7..ef100b1 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -17,6 +17,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
no_trycast_sql,
rename_func,
+ simplify_literal,
str_position_sql,
timestamptrunc_sql,
timestrtotime_sql,
@@ -39,16 +40,13 @@ DATE_DIFF_FACTOR = {
def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
- from sqlglot.optimizer.simplify import simplify
-
this = self.sql(expression, "this")
unit = expression.args.get("unit")
- expression = simplify(expression.args["expression"])
+ expression = simplify_literal(expression.copy(), copy=False).expression
if not isinstance(expression, exp.Literal):
self.unsupported("Cannot add non literal")
- expression = expression.copy()
expression.args["is_string"] = True
return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}"
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 265c6e5..14ec3dd 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -192,6 +192,8 @@ class Presto(Dialect):
"START": TokenType.BEGIN,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"ROW": TokenType.STRUCT,
+ "IPADDRESS": TokenType.IPADDRESS,
+ "IPPREFIX": TokenType.IPPREFIX,
}
class Parser(parser.Parser):
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 73f4370..b9aaa66 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -3,6 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
+from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
@@ -47,7 +48,11 @@ class Spark(Spark2):
exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)",
exp.DataType.Type.UNIQUEIDENTIFIER: "STRING",
}
- TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
+
+ TRANSFORMS = {
+ **Spark2.Generator.TRANSFORMS,
+ exp.StartsWith: rename_func("STARTSWITH"),
+ }
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)
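Both spellings parse into the new `exp.StartsWith` (added in expressions.py below), and Spark now emits its own name; a brief sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT STARTS_WITH(a, 'x')", write="spark")[0])
    # Expected along the lines of: SELECT STARTSWITH(a, 'x')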
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index dcaa524..ceb48f8 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -19,9 +19,13 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")
- if kind.upper() == "TABLE" and any(
- isinstance(prop, exp.TemporaryProperty)
- for prop in (properties.expressions if properties else [])
+ if (
+ kind.upper() == "TABLE"
+ and e.expression
+ and any(
+ isinstance(prop, exp.TemporaryProperty)
+ for prop in (properties.expressions if properties else [])
+ )
):
return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
return create_with_partitions_sql(self, e)
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 4e8ffb4..3fac4f5 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -33,8 +33,10 @@ class Teradata(Dialect):
**tokens.Tokenizer.KEYWORDS,
"^=": TokenType.NEQ,
"BYTEINT": TokenType.SMALLINT,
+ "COLLECT": TokenType.COMMAND,
"GE": TokenType.GTE,
"GT": TokenType.GT,
+ "HELP": TokenType.COMMAND,
"INS": TokenType.INSERT,
"LE": TokenType.LTE,
"LT": TokenType.LT,
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 01d5001..0eb0906 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import datetime
import re
import typing as t
@@ -10,6 +11,7 @@ from sqlglot.dialects.dialect import (
min_or_least,
parse_date_delta,
rename_func,
+ timestrtotime_sql,
)
from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
@@ -52,6 +54,8 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{
# N = Numeric, C = Currency
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
+DEFAULT_START_DATE = datetime.date(1900, 1, 1)
+
def _format_time_lambda(
exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
@@ -166,6 +170,34 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
return f"STRING_AGG({self.format_args(this, separator)}){order}"
+def _parse_date_delta(
+ exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
+) -> t.Callable[[t.List], E]:
+ def inner_func(args: t.List) -> E:
+ unit = seq_get(args, 0)
+ if unit and unit_mapping:
+ unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
+
+ start_date = seq_get(args, 1)
+ if start_date and start_date.is_number:
+ # Numeric types are valid DATETIME values
+ if start_date.is_int:
+ adds = DEFAULT_START_DATE + datetime.timedelta(days=int(start_date.this))
+ start_date = exp.Literal.string(adds.strftime("%F"))
+ else:
+ # We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs.
+ # This is not a problem when generating T-SQL code, but it is when transpiling to other dialects.
+ return exp_class(this=seq_get(args, 2), expression=start_date, unit=unit)
+
+ return exp_class(
+ this=exp.TimeStrToTime(this=seq_get(args, 2)),
+ expression=exp.TimeStrToTime(this=start_date),
+ unit=unit,
+ )
+
+ return inner_func
+
+
class TSQL(Dialect):
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
NULL_ORDERING = "nulls_are_small"
@@ -298,7 +330,6 @@ class TSQL(Dialect):
"SMALLDATETIME": TokenType.DATETIME,
"SMALLMONEY": TokenType.SMALLMONEY,
"SQL_VARIANT": TokenType.VARIANT,
- "TIME": TokenType.TIMESTAMP,
"TOP": TokenType.TOP,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"VARCHAR(MAX)": TokenType.TEXT,
@@ -307,10 +338,6 @@ class TSQL(Dialect):
"SYSTEM_USER": TokenType.CURRENT_USER,
}
- # TSQL allows @, # to appear as a variable/identifier prefix
- SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
- SINGLE_TOKENS.pop("#")
-
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -320,7 +347,7 @@ class TSQL(Dialect):
position=seq_get(args, 2),
),
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
- "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
+ "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
"EOMONTH": _parse_eomonth,
@@ -518,6 +545,36 @@ class TSQL(Dialect):
expressions = self._parse_csv(self._parse_function_parameter)
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
+ def _parse_id_var(
+ self,
+ any_token: bool = True,
+ tokens: t.Optional[t.Collection[TokenType]] = None,
+ ) -> t.Optional[exp.Expression]:
+ is_temporary = self._match(TokenType.HASH)
+ is_global = is_temporary and self._match(TokenType.HASH)
+
+ this = super()._parse_id_var(any_token=any_token, tokens=tokens)
+ if this:
+ if is_global:
+ this.set("global", True)
+ elif is_temporary:
+ this.set("temporary", True)
+
+ return this
+
+ def _parse_create(self) -> exp.Create | exp.Command:
+ create = super()._parse_create()
+
+ if isinstance(create, exp.Create):
+ table = create.this.this if isinstance(create.this, exp.Schema) else create.this
+ if isinstance(table, exp.Table) and table.this.args.get("temporary"):
+ if not create.args.get("properties"):
+ create.set("properties", exp.Properties(expressions=[]))
+
+ create.args["properties"].append("expressions", exp.TemporaryProperty())
+
+ return create
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
LIMIT_IS_TOP = True
@@ -526,9 +583,11 @@ class TSQL(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
- exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
+ exp.DataType.Type.INT: "INTEGER",
+ exp.DataType.Type.TIMESTAMP: "DATETIME2",
+ exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET",
exp.DataType.Type.VARIANT: "SQL_VARIANT",
}
@@ -552,6 +611,8 @@ class TSQL(Dialect):
exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
e.this,
),
+ exp.TemporaryProperty: lambda self, e: "",
+ exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: _format_sql,
}
@@ -564,6 +625,22 @@ class TSQL(Dialect):
LIMIT_FETCH = "FETCH"
+ def createable_sql(
+ self,
+ expression: exp.Create,
+ locations: dict[exp.Properties.Location, list[exp.Property]],
+ ) -> str:
+ sql = self.sql(expression, "this")
+ properties = expression.args.get("properties")
+
+ if sql[:1] != "#" and any(
+ isinstance(prop, exp.TemporaryProperty)
+ for prop in (properties.expressions if properties else [])
+ ):
+ sql = f"#{sql}"
+
+ return sql
+
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
@@ -616,3 +693,13 @@ class TSQL(Dialect):
this = self.sql(expression, "this")
this = f" {this}" if this else ""
return f"ROLLBACK TRANSACTION{this}"
+
+ def identifier_sql(self, expression: exp.Identifier) -> str:
+ identifier = super().identifier_sql(expression)
+
+ if expression.args.get("global"):
+ identifier = f"##{identifier}"
+ elif expression.args.get("temporary"):
+ identifier = f"#{identifier}"
+
+ return identifier
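Taken together, the tokenizer, parser, and generator changes let local (#) and global (##) temp tables survive a round trip; a sketch with an illustrative table name:

    import sqlglot

    # "#" marks the identifier temporary, _parse_create attaches a TemporaryProperty,
    # and identifier_sql restores the prefix on generation.
    print(sqlglot.parse_one("CREATE TABLE #t (a INT)", read="tsql").sql(dialect="tsql"))
    # Expected along the lines of: CREATE TABLE #t (a INT)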
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 9a6b440..f8e9fee 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -67,8 +67,9 @@ class Expression(metaclass=_Expression):
uses to refer to it.
comments: a list of comments that are associated with a given expression. This is used in
order to preserve comments when transpiling SQL code.
- _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
+ type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
optimizer, in order to enable some transformations that require type information.
+ meta: a dictionary that can be used to store useful metadata for a given expression.
Example:
>>> class Foo(Expression):
@@ -767,7 +768,7 @@ class Condition(Expression):
**opts,
) -> In:
return In(
- this=_maybe_copy(self, copy),
+ this=maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
unnest=Unnest(
@@ -781,7 +782,7 @@ class Condition(Expression):
def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between:
return Between(
- this=_maybe_copy(self, copy),
+ this=maybe_copy(self, copy),
low=convert(low, copy=copy, **opts),
high=convert(high, copy=copy, **opts),
)
@@ -990,7 +991,28 @@ class Uncache(Expression):
arg_types = {"this": True, "exists": False}
-class Create(Expression):
+class DDL(Expression):
+ @property
+ def ctes(self):
+ with_ = self.args.get("with")
+ if not with_:
+ return []
+ return with_.expressions
+
+ @property
+ def named_selects(self) -> t.List[str]:
+ if isinstance(self.expression, Subqueryable):
+ return self.expression.named_selects
+ return []
+
+ @property
+ def selects(self) -> t.List[Expression]:
+ if isinstance(self.expression, Subqueryable):
+ return self.expression.selects
+ return []
+
+
+class Create(DDL):
arg_types = {
"with": False,
"this": True,
@@ -1206,6 +1228,19 @@ class MergeTreeTTL(Expression):
}
+# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
+class IndexConstraintOption(Expression):
+ arg_types = {
+ "key_block_size": False,
+ "using": False,
+ "parser": False,
+ "comment": False,
+ "visible": False,
+ "engine_attr": False,
+ "secondary_engine_attr": False,
+ }
+
+
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
@@ -1272,6 +1307,11 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
}
+# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
+class IndexColumnConstraint(ColumnConstraintKind):
+ arg_types = {"this": False, "schema": True, "kind": False, "type": False, "options": False}
+
+
class InlineLengthColumnConstraint(ColumnConstraintKind):
pass
@@ -1496,7 +1536,7 @@ class JoinHint(Expression):
class Identifier(Expression):
- arg_types = {"this": True, "quoted": False}
+ arg_types = {"this": True, "quoted": False, "global": False, "temporary": False}
@property
def quoted(self) -> bool:
@@ -1525,7 +1565,7 @@ class Index(Expression):
}
-class Insert(Expression):
+class Insert(DDL):
arg_types = {
"with": False,
"this": True,
@@ -1892,6 +1932,10 @@ class EngineProperty(Property):
arg_types = {"this": True}
+class HeapProperty(Property):
+ arg_types = {}
+
+
class ToTableProperty(Property):
arg_types = {"this": True}
@@ -2182,7 +2226,7 @@ class Tuple(Expression):
**opts,
) -> In:
return In(
- this=_maybe_copy(self, copy),
+ this=maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
unnest=Unnest(
@@ -2212,7 +2256,7 @@ class Subqueryable(Unionable):
Returns:
Alias: the subquery
"""
- instance = _maybe_copy(self, copy)
+ instance = maybe_copy(self, copy)
if not isinstance(alias, Expression):
alias = TableAlias(this=to_identifier(alias)) if alias else None
@@ -2865,7 +2909,7 @@ class Select(Subqueryable):
self,
expression: ExpOrStr,
on: t.Optional[ExpOrStr] = None,
- using: t.Optional[ExpOrStr | t.List[ExpOrStr]] = None,
+ using: t.Optional[ExpOrStr | t.Collection[ExpOrStr]] = None,
append: bool = True,
join_type: t.Optional[str] = None,
join_alias: t.Optional[Identifier | str] = None,
@@ -2943,6 +2987,7 @@ class Select(Subqueryable):
arg="using",
append=append,
copy=copy,
+ into=Identifier,
**opts,
)
@@ -3092,7 +3137,7 @@ class Select(Subqueryable):
Returns:
Select: the modified expression.
"""
- instance = _maybe_copy(self, copy)
+ instance = maybe_copy(self, copy)
on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons if on]) if ons else None
instance.set("distinct", Distinct(on=on) if distinct else None)
return instance
@@ -3123,7 +3168,7 @@ class Select(Subqueryable):
Returns:
The new Create expression.
"""
- instance = _maybe_copy(self, copy)
+ instance = maybe_copy(self, copy)
table_expression = maybe_parse(
table,
into=Table,
@@ -3159,7 +3204,7 @@ class Select(Subqueryable):
Returns:
The modified expression.
"""
- inst = _maybe_copy(self, copy)
+ inst = maybe_copy(self, copy)
inst.set("locks", [Lock(update=update)])
return inst
@@ -3181,7 +3226,7 @@ class Select(Subqueryable):
Returns:
The modified expression.
"""
- inst = _maybe_copy(self, copy)
+ inst = maybe_copy(self, copy)
inst.set(
"hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints])
)
@@ -3376,6 +3421,8 @@ class DataType(Expression):
HSTORE = auto()
IMAGE = auto()
INET = auto()
+ IPADDRESS = auto()
+ IPPREFIX = auto()
INT = auto()
INT128 = auto()
INT256 = auto()
@@ -3987,7 +4034,7 @@ class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}
def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
- instance = _maybe_copy(self, copy)
+ instance = maybe_copy(self, copy)
instance.append(
"ifs",
If(
@@ -3998,7 +4045,7 @@ class Case(Func):
return instance
def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
- instance = _maybe_copy(self, copy)
+ instance = maybe_copy(self, copy)
instance.set("default", maybe_parse(condition, copy=copy, **opts))
return instance
@@ -4263,6 +4310,10 @@ class Initcap(Func):
arg_types = {"this": True, "expression": False}
+class IsNan(Func):
+ _sql_names = ["IS_NAN", "ISNAN"]
+
+
class JSONKeyValue(Expression):
arg_types = {"this": True, "expression": True}
@@ -4549,6 +4600,11 @@ class StandardHash(Func):
arg_types = {"this": True, "expression": False}
+class StartsWith(Func):
+ _sql_names = ["STARTS_WITH", "STARTSWITH"]
+ arg_types = {"this": True, "expression": True}
+
+
class StrPosition(Func):
arg_types = {
"this": True,
@@ -4804,7 +4860,7 @@ def maybe_parse(
return sqlglot.parse_one(sql, read=dialect, into=into, **opts)
-def _maybe_copy(instance: E, copy: bool = True) -> E:
+def maybe_copy(instance: E, copy: bool = True) -> E:
return instance.copy() if copy else instance
@@ -4824,7 +4880,7 @@ def _apply_builder(
):
if _is_wrong_expression(expression, into):
expression = into(this=expression)
- instance = _maybe_copy(instance, copy)
+ instance = maybe_copy(instance, copy)
expression = maybe_parse(
sql_or_expression=expression,
prefix=prefix,
@@ -4848,7 +4904,7 @@ def _apply_child_list_builder(
properties=None,
**opts,
):
- instance = _maybe_copy(instance, copy)
+ instance = maybe_copy(instance, copy)
parsed = []
for expression in expressions:
if expression is not None:
@@ -4887,7 +4943,7 @@ def _apply_list_builder(
dialect=None,
**opts,
):
- inst = _maybe_copy(instance, copy)
+ inst = maybe_copy(instance, copy)
expressions = [
maybe_parse(
@@ -4923,7 +4979,7 @@ def _apply_conjunction_builder(
if not expressions:
return instance
- inst = _maybe_copy(instance, copy)
+ inst = maybe_copy(instance, copy)
existing = inst.args.get(arg)
if append and existing is not None:
@@ -5398,7 +5454,7 @@ def to_identifier(name, quoted=None, copy=True):
return None
if isinstance(name, Identifier):
- identifier = _maybe_copy(name, copy)
+ identifier = maybe_copy(name, copy)
elif isinstance(name, str):
identifier = Identifier(
this=name,
@@ -5735,7 +5791,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
Expression: the equivalent expression object.
"""
if isinstance(value, Expression):
- return _maybe_copy(value, copy)
+ return maybe_copy(value, copy)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, bool):
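The new `DDL` base class lets CTAS and INSERT nodes expose their query's metadata directly; a minimal sketch:

    from sqlglot import parse_one

    create = parse_one("CREATE TABLE t AS SELECT a, b FROM src")
    # Create and Insert now inherit from DDL, which proxies ctes/selects/named_selects
    # to the underlying query expression.
    print(create.named_selects)  # ['a', 'b']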
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 40ba88e..ed0a681 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -68,6 +68,7 @@ class Generator:
exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}",
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
exp.ExternalProperty: lambda self, e: "EXTERNAL",
+ exp.HeapProperty: lambda self, e: "HEAP",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
@@ -161,6 +162,9 @@ class Generator:
# Whether or not to generate the (+) suffix for columns used in old-style join conditions
COLUMN_JOIN_MARKS_SUPPORTED = False
+ # Whether or not the date part argument of EXTRACT may be generated as a quoted value
+ EXTRACT_ALLOWS_QUOTES = True
+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
@@ -224,6 +228,7 @@ class Generator:
exp.FallbackProperty: exp.Properties.Location.POST_NAME,
exp.FileFormatProperty: exp.Properties.Location.POST_WITH,
exp.FreespaceProperty: exp.Properties.Location.POST_NAME,
+ exp.HeapProperty: exp.Properties.Location.POST_WITH,
exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME,
exp.JournalProperty: exp.Properties.Location.POST_NAME,
exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA,
@@ -265,9 +270,12 @@ class Generator:
# Expressions whose comments are separated from them for better formatting
WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
+ exp.Delete,
exp.Drop,
exp.From,
+ exp.Insert,
exp.Select,
+ exp.Update,
exp.Where,
exp.With,
)
@@ -985,8 +993,9 @@ class Generator:
) -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
- expressions = self.wrap(expressions) if wrapped else expressions
- return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
+ if expressions:
+ expressions = self.wrap(expressions) if wrapped else expressions
+ return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}"
return ""
def with_properties(self, properties: exp.Properties) -> str:
@@ -1905,7 +1914,7 @@ class Generator:
return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}"
def extract_sql(self, expression: exp.Extract) -> str:
- this = self.sql(expression, "this")
+ this = self.sql(expression, "this") if self.EXTRACT_ALLOWS_QUOTES else expression.this.name
expression_sql = self.sql(expression, "expression")
return f"EXTRACT({this} FROM {expression_sql})"
@@ -2370,7 +2379,12 @@ class Generator:
elif arg_value is not None:
args.append(arg_value)
- return self.func(expression.sql_name(), *args)
+ if self.normalize_functions:
+ name = expression.sql_name()
+ else:
+ name = (expression._meta and expression.meta.get("name")) or expression.sql_name()
+
+ return self.func(name, *args)
def func(
self,
@@ -2412,7 +2426,7 @@ class Generator:
return ""
if flat:
- return sep.join(self.sql(e) for e in expressions)
+ return sep.join(sql for sql in (self.sql(e) for e in expressions) if sql)
num_sqls = len(expressions)
@@ -2423,6 +2437,9 @@ class Generator:
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
+ if not sql:
+ continue
+
comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
if self.pretty:
@@ -2562,6 +2579,51 @@ class Generator:
record_reader = f" RECORDREADER {record_reader}" if record_reader else ""
return f"{transform}{row_format_before}{record_writer}{using}{schema}{row_format_after}{record_reader}"
+ def indexconstraintoption_sql(self, expression: exp.IndexConstraintOption) -> str:
+ key_block_size = self.sql(expression, "key_block_size")
+ if key_block_size:
+ return f"KEY_BLOCK_SIZE = {key_block_size}"
+
+ using = self.sql(expression, "using")
+ if using:
+ return f"USING {using}"
+
+ parser = self.sql(expression, "parser")
+ if parser:
+ return f"WITH PARSER {parser}"
+
+ comment = self.sql(expression, "comment")
+ if comment:
+ return f"COMMENT {comment}"
+
+ visible = expression.args.get("visible")
+ if visible is not None:
+ return "VISIBLE" if visible else "INVISIBLE"
+
+ engine_attr = self.sql(expression, "engine_attr")
+ if engine_attr:
+ return f"ENGINE_ATTRIBUTE = {engine_attr}"
+
+ secondary_engine_attr = self.sql(expression, "secondary_engine_attr")
+ if secondary_engine_attr:
+ return f"SECONDARY_ENGINE_ATTRIBUTE = {secondary_engine_attr}"
+
+ self.unsupported("Unsupported index constraint option.")
+ return ""
+
+ def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
+ kind = self.sql(expression, "kind")
+ kind = f"{kind} INDEX" if kind else "INDEX"
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+ type_ = self.sql(expression, "type")
+ type_ = f" USING {type_}" if type_ else ""
+ schema = self.sql(expression, "schema")
+ schema = f" {schema}" if schema else ""
+ options = self.expressions(expression, key="options", sep=" ")
+ options = f" {options}" if options else ""
+ return f"{kind}{this}{type_}{schema}{options}"
+
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index 728493d..af42f25 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -136,8 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken):
def _eliminate_derived_table(scope, existing_ctes, taken):
- # This ensures we don't drop the "pivot" arg from a pivoted subquery
- if scope.parent.pivots:
+ # This makes sure that we don't:
+ # - drop the "pivot" arg from a pivoted subquery
+ # - eliminate a lateral correlated subquery
+ if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral):
return None
parent = scope.expression.parent
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index 99e605d..9d4860e 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -1,8 +1,23 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
+@t.overload
def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
+ ...
+
+
+@t.overload
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression:
+ ...
+
+
+def normalize_identifiers(expression, dialect=None):
"""
Normalize all unquoted identifiers to either lower or upper case, depending
on the dialect. This essentially makes those identifiers case-insensitive.
@@ -16,6 +31,8 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
>>> normalize_identifiers(expression).sql()
'SELECT bar.a AS a FROM "Foo".bar'
+ >>> normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake")
+ 'FOO'
Args:
expression: The expression to transform.
@@ -24,4 +41,5 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
Returns:
The transformed expression.
"""
+ expression = exp.maybe_parse(expression, dialect=dialect)
return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 2657188..9c34cef 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -39,6 +39,7 @@ def qualify_columns(
"""
schema = ensure_schema(schema)
infer_schema = schema.empty if infer_schema is None else infer_schema
+ pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
for scope in traverse_scope(expression):
resolver = Resolver(scope, schema, infer_schema=infer_schema)
@@ -55,7 +56,7 @@ def qualify_columns(
_expand_alias_refs(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
- _expand_stars(scope, resolver, using_column_tables)
+ _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
_qualify_outputs(scope)
_expand_group_by(scope)
_expand_order_by(scope, resolver)
@@ -326,7 +327,10 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
def _expand_stars(
- scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
+ scope: Scope,
+ resolver: Resolver,
+ using_column_tables: t.Dict[str, t.Any],
+ pseudocolumns: t.Set[str],
) -> None:
"""Expand stars to lists of column selections"""
@@ -367,14 +371,8 @@ def _expand_stars(
columns = resolver.get_source_columns(table, only_visible=True)
- # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
- # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
- if resolver.schema.dialect == "bigquery":
- columns = [
- name
- for name in columns
- if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
- ]
+ if pseudocolumns:
+ columns = [name for name in columns if name.upper() not in pseudocolumns]
if columns and "*" not in columns:
if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 31c9cc0..68aebdb 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -80,7 +80,9 @@ def qualify_tables(
header = next(reader)
columns = next(reader)
schema.add_table(
- source, {k: type(v).__name__ for k, v in zip(header, columns)}
+ source,
+ {k: type(v).__name__ for k, v in zip(header, columns)},
+ match_depth=False,
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index a7dab35..fb12384 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -435,7 +435,10 @@ class Scope:
@property
def is_correlated_subquery(self):
"""Determine if this scope is a correlated subquery"""
- return bool(self.is_subquery and self.external_columns)
+ return bool(
+ (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
+ and self.external_columns
+ )
def rename_source(self, old_name, new_name):
"""Rename a source in this scope"""
@@ -486,7 +489,7 @@ class Scope:
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
"""
- Traverse an expression by it's "scopes".
+ Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
@@ -509,9 +512,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
Returns:
list[Scope]: scope instances
"""
- if not isinstance(expression, exp.Unionable):
- return []
- return list(_traverse_scope(Scope(expression)))
+ if isinstance(expression, exp.Unionable) or (
+ isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
+ ):
+ return list(_traverse_scope(Scope(expression)))
+
+ return []
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
@@ -539,7 +545,9 @@ def _traverse_scope(scope):
elif isinstance(scope.expression, exp.Table):
yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
- pass
+ yield from _traverse_udtfs(scope)
+ elif isinstance(scope.expression, exp.DDL):
+ yield from _traverse_ddl(scope)
else:
logger.warning(
"Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
@@ -576,10 +584,10 @@ def _traverse_ctes(scope):
for cte in scope.ctes:
recursive_scope = None
- # if the scope is a recursive cte, it must be in the form of
- # base_case UNION recursive. thus the recursive scope is the first
- # section of the union.
- if scope.expression.args["with"].recursive:
+ # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
+ # thus the recursive scope is the first section of the union.
+ with_ = scope.expression.args.get("with")
+ if with_ and with_.recursive:
union = cte.this
if isinstance(union, exp.Union):
@@ -692,8 +700,7 @@ def _traverse_tables(scope):
# This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
# Until then, this means that only a single, unaliased derived table is allowed (rather,
# the latest one wins).
- alias = expression.alias
- sources[alias] = child_scope
+ sources[expression.alias] = child_scope
# append the final child_scope yielded
scopes.append(child_scope)
@@ -711,6 +718,47 @@ def _traverse_subqueries(scope):
scope.subquery_scopes.append(top)
+def _traverse_udtfs(scope):
+ if isinstance(scope.expression, exp.Unnest):
+ expressions = scope.expression.expressions
+ elif isinstance(scope.expression, exp.Lateral):
+ expressions = [scope.expression.this]
+ else:
+ expressions = []
+
+ sources = {}
+ for expression in expressions:
+ if isinstance(expression, exp.Subquery) and _is_derived_table(expression):
+ top = None
+ for child_scope in _traverse_scope(
+ scope.branch(
+ expression,
+ scope_type=ScopeType.DERIVED_TABLE,
+ outer_column_list=expression.alias_column_names,
+ )
+ ):
+ yield child_scope
+ top = child_scope
+ sources[expression.alias] = child_scope
+
+ scope.derived_table_scopes.append(top)
+ scope.table_scopes.append(top)
+
+ scope.sources.update(sources)
+
+
+def _traverse_ddl(scope):
+ yield from _traverse_ctes(scope)
+
+ query_scope = scope.branch(
+ scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources
+ )
+ query_scope._collect()
+ query_scope._ctes = scope.ctes + query_scope._ctes
+
+ yield from _traverse_scope(query_scope)
+
+
def walk_in_scope(expression, bfs=True):
"""
Returns a generator object which visits all nodes in the syntax tree, stopping at
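Since `traverse_scope` now descends into DDL statements that carry a query (e.g. CTAS), the inner select gets its own scope; a sketch:

    from sqlglot import parse_one
    from sqlglot.optimizer.scope import traverse_scope

    scopes = traverse_scope(parse_one("CREATE TABLE t AS SELECT x FROM y"))
    # The CTAS query is traversed instead of being skipped entirely.
    print([type(s.expression).__name__ for s in scopes])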
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 09e3f2a..816f5fb 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -46,20 +46,24 @@ def unnest(select, parent_select, next_alias_name):
if not predicate or parent_select is not predicate.parent_select:
return
- # this subquery returns a scalar and can just be converted to a cross join
+ # This subquery returns a scalar and can just be converted to a cross join
if not isinstance(predicate, (exp.In, exp.Any)):
- having = predicate.find_ancestor(exp.Having)
column = exp.column(select.selects[0].alias_or_name, alias)
- if having and having.parent_select is parent_select:
+
+ clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join)
+ clause_parent_select = clause.parent_select if clause else None
+
+ if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or (
+ (not clause or clause_parent_select is not parent_select)
+ and (
+ parent_select.args.get("group")
+ or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
+ )
+ ):
column = exp.Max(this=column)
- _replace(select.parent, column)
- parent_select.join(
- select,
- join_type="CROSS",
- join_alias=alias,
- copy=False,
- )
+ _replace(select.parent, column)
+ parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False)
return
if select.find(exp.Limit, exp.Offset):
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 5adec77..f714c8d 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -185,6 +185,8 @@ class Parser(metaclass=_Parser):
TokenType.VARIANT,
TokenType.OBJECT,
TokenType.INET,
+ TokenType.IPADDRESS,
+ TokenType.IPPREFIX,
TokenType.ENUM,
*NESTED_TYPE_TOKENS,
}
@@ -603,6 +605,7 @@ class Parser(metaclass=_Parser):
"FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs),
"FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty),
"FREESPACE": lambda self: self._parse_freespace(),
+ "HEAP": lambda self: self.expression(exp.HeapProperty),
"IMMUTABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE")
),
@@ -832,6 +835,7 @@ class Parser(metaclass=_Parser):
UNNEST_COLUMN_ONLY: bool = False
ALIAS_POST_TABLESAMPLE: bool = False
STRICT_STRING_CONCAT = False
+ NORMALIZE_FUNCTIONS = "upper"
NULL_ORDERING: str = "nulls_are_small"
SHOW_TRIE: t.Dict = {}
SET_TRIE: t.Dict = {}
@@ -1187,7 +1191,7 @@ class Parser(metaclass=_Parser):
exists = self._parse_exists(not_=True)
this = None
- expression = None
+ expression: t.Optional[exp.Expression] = None
indexes = None
no_schema_binding = None
begin = None
@@ -1207,12 +1211,16 @@ class Parser(metaclass=_Parser):
extend_props(self._parse_properties())
self._match(TokenType.ALIAS)
- begin = self._match(TokenType.BEGIN)
- return_ = self._match_text_seq("RETURN")
- expression = self._parse_statement()
- if return_:
- expression = self.expression(exp.Return, this=expression)
+ if self._match(TokenType.COMMAND):
+ expression = self._parse_as_command(self._prev)
+ else:
+ begin = self._match(TokenType.BEGIN)
+ return_ = self._match_text_seq("RETURN")
+ expression = self._parse_statement()
+
+ if return_:
+ expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index(index=self._parse_id_var())
elif create_token.token_type in self.DB_CREATABLES:
@@ -1692,6 +1700,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Describe, this=this, kind=kind)
def _parse_insert(self) -> exp.Insert:
+ comments = ensure_list(self._prev_comments)
overwrite = self._match(TokenType.OVERWRITE)
ignore = self._match(TokenType.IGNORE)
local = self._match_text_seq("LOCAL")
@@ -1709,6 +1718,7 @@ class Parser(metaclass=_Parser):
alternative = self._match_texts(self.INSERT_ALTERNATIVES) and self._prev.text
self._match(TokenType.INTO)
+ comments += ensure_list(self._prev_comments)
self._match(TokenType.TABLE)
this = self._parse_table(schema=True)
@@ -1716,6 +1726,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Insert,
+ comments=comments,
this=this,
exists=self._parse_exists(),
partition=self._parse_partition(),
@@ -1840,6 +1851,7 @@ class Parser(metaclass=_Parser):
# This handles MySQL's "Multiple-Table Syntax"
# https://dev.mysql.com/doc/refman/8.0/en/delete.html
tables = None
+ comments = self._prev_comments
if not self._match(TokenType.FROM, advance=False):
tables = self._parse_csv(self._parse_table) or None
@@ -1847,6 +1859,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Delete,
+ comments=comments,
tables=tables,
this=self._match(TokenType.FROM) and self._parse_table(joins=True),
using=self._match(TokenType.USING) and self._parse_table(joins=True),
@@ -1856,11 +1869,13 @@ class Parser(metaclass=_Parser):
)
def _parse_update(self) -> exp.Update:
+ comments = self._prev_comments
this = self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS)
expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality)
returning = self._parse_returning()
return self.expression(
exp.Update,
+ comments=comments,
**{ # type: ignore
"this": this,
"expressions": expressions,
@@ -2235,7 +2250,12 @@ class Parser(metaclass=_Parser):
return None
if not this:
- this = self._parse_function() or self._parse_id_var(any_token=False)
+ this = (
+ self._parse_unnest()
+ or self._parse_function()
+ or self._parse_id_var(any_token=False)
+ )
+
while self._match(TokenType.DOT):
this = exp.Dot(
this=this,
@@ -3341,7 +3361,10 @@ class Parser(metaclass=_Parser):
args = self._parse_csv(lambda: self._parse_lambda(alias=alias))
if function and not anonymous:
- this = self.validate_expression(function(args), args)
+ func = self.validate_expression(function(args), args)
+ if not self.NORMALIZE_FUNCTIONS:
+ func.meta["name"] = this
+ this = func
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
@@ -3842,13 +3865,11 @@ class Parser(metaclass=_Parser):
args = self._parse_csv(self._parse_conjunction)
index = self._index
- if not self._match(TokenType.R_PAREN):
+ if not self._match(TokenType.R_PAREN) and args:
# postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]])
- return self.expression(
- exp.GroupConcat,
- this=seq_get(args, 0),
- separator=self._parse_order(this=seq_get(args, 1)),
- )
+ # bigquery: STRING_AGG([DISTINCT] expression [, separator] [ORDER BY key [{ASC | DESC}] [, ... ]] [LIMIT n])
+ args[-1] = self._parse_limit(this=self._parse_order(this=args[-1]))
+ return self.expression(exp.GroupConcat, this=args[0], separator=seq_get(args, 1))
# Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]).
# This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that
@@ -4172,7 +4193,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
- return self.expression(
+ window = self.expression(
exp.Window,
this=this,
partition_by=partition,
@@ -4183,6 +4204,12 @@ class Parser(metaclass=_Parser):
first=first,
)
+ # This covers Oracle's FIRST/LAST syntax: aggregate KEEP (...) OVER (...)
+ if self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS, advance=False):
+ return self._parse_window(window, alias=alias)
+
+ return window
+
def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
self._match(TokenType.BETWEEN)
@@ -4276,19 +4303,19 @@ class Parser(metaclass=_Parser):
def _parse_null(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.NULL):
return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
- return None
+ return self._parse_placeholder()
def _parse_boolean(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.TRUE):
return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev)
if self._match(TokenType.FALSE):
return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev)
- return None
+ return self._parse_placeholder()
def _parse_star(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.STAR):
return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
- return None
+ return self._parse_placeholder()
def _parse_parameter(self) -> exp.Parameter:
wrapped = self._match(TokenType.L_BRACE)
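The tail call at the end of `_parse_window` is what enables Oracle's FIRST/LAST chaining; a sketch (identifiers illustrative):

    import sqlglot

    sql = (
        "SELECT MIN(c) KEEP (DENSE_RANK FIRST ORDER BY c DESC) "
        "OVER (PARTITION BY p) FROM t"
    )
    # The KEEP window is parsed first, then reused as the base of the OVER window.
    print(sqlglot.parse_one(sql, read="oracle").sql(dialect="oracle"))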
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 12cf0b1..7a3c88b 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -31,14 +31,19 @@ class Schema(abc.ABC):
table: exp.Table | str,
column_mapping: t.Optional[ColumnMapping] = None,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
+ match_depth: bool = True,
) -> None:
"""
Register or update a table. Some implementing classes may require column information to also be provided.
+ The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
+ normalize: whether to normalize identifiers according to the dialect of interest.
+ match_depth: whether to enforce that the table matches the schema's depth.
"""
@abc.abstractmethod
@@ -47,6 +52,7 @@ class Schema(abc.ABC):
table: exp.Table | str,
only_visible: bool = False,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
) -> t.List[str]:
"""
Get the column names for a table.
@@ -55,6 +61,7 @@ class Schema(abc.ABC):
table: the `Table` expression instance.
only_visible: whether to include invisible columns.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
+ normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The list of column names.
@@ -66,6 +73,7 @@ class Schema(abc.ABC):
table: exp.Table | str,
column: exp.Column,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
) -> exp.DataType:
"""
Get the `sqlglot.exp.DataType` type of a column in the schema.
@@ -74,6 +82,7 @@ class Schema(abc.ABC):
table: the source table.
column: the target column.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
+ normalize: whether to normalize identifiers according to the dialect of interest.
Returns:
The resulting column type.
@@ -99,7 +108,7 @@ class AbstractMappingSchema(t.Generic[T]):
) -> None:
self.mapping = mapping or {}
self.mapping_trie = new_trie(
- tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth())
+ tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth())
)
self._supported_table_args: t.Tuple[str, ...] = tuple()
@@ -107,13 +116,13 @@ class AbstractMappingSchema(t.Generic[T]):
def empty(self) -> bool:
return not self.mapping
- def _depth(self) -> int:
+ def depth(self) -> int:
return dict_depth(self.mapping)
@property
def supported_table_args(self) -> t.Tuple[str, ...]:
if not self._supported_table_args and self.mapping:
- depth = self._depth()
+ depth = self.depth()
if not depth: # None
self._supported_table_args = tuple()
@@ -191,6 +200,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
self.visible = visible or {}
self.normalize = normalize
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
+ self._depth = 0
super().__init__(self._normalize(schema or {}))
@@ -200,6 +210,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
schema=mapping_schema.mapping,
visible=mapping_schema.visible,
dialect=mapping_schema.dialect,
+ normalize=mapping_schema.normalize,
)
def copy(self, **kwargs) -> MappingSchema:
@@ -208,6 +219,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
"schema": self.mapping.copy(),
"visible": self.visible.copy(),
"dialect": self.dialect,
+ "normalize": self.normalize,
**kwargs,
}
)
@@ -217,19 +229,30 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
table: exp.Table | str,
column_mapping: t.Optional[ColumnMapping] = None,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
+ match_depth: bool = True,
) -> None:
"""
Register or update a table. Updates are only performed if a new column mapping is provided.
+ The added table must have the necessary number of qualifiers in its path to match the schema's nesting level.
Args:
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
dialect: the SQL dialect that will be used to parse `table` if it's a string.
+ normalize: whether to normalize identifiers according to the dialect of interest.
+ match_depth: whether to enforce that the table matches the schema's depth.
"""
- normalized_table = self._normalize_table(table, dialect=dialect)
+ normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
+
+ if match_depth and not self.empty and len(normalized_table.parts) != self.depth():
+ raise SchemaError(
+ f"Table {normalized_table.sql(dialect=self.dialect)} must match the "
+ f"schema's nesting level: {self.depth()}."
+ )
normalized_column_mapping = {
- self._normalize_name(key, dialect=dialect): value
+ self._normalize_name(key, dialect=dialect, normalize=normalize): value
for key, value in ensure_column_mapping(column_mapping).items()
}
@@ -247,8 +270,9 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
table: exp.Table | str,
only_visible: bool = False,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
) -> t.List[str]:
- normalized_table = self._normalize_table(table, dialect=dialect)
+ normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
schema = self.find(normalized_table)
if schema is None:
@@ -265,11 +289,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
table: exp.Table | str,
column: exp.Column,
dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
) -> exp.DataType:
- normalized_table = self._normalize_table(table, dialect=dialect)
+ normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize)
normalized_column_name = self._normalize_name(
- column if isinstance(column, str) else column.this, dialect=dialect
+ column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize
)
table_schema = self.find(normalized_table, raise_on_missing=False)
@@ -293,12 +318,16 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
Returns:
The normalized schema mapping.
"""
+ normalized_mapping: t.Dict = {}
flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
- normalized_mapping: t.Dict = {}
for keys in flattened_schema:
columns = nested_get(schema, *zip(keys, keys))
- assert columns is not None
+
+ if not isinstance(columns, dict):
+ raise SchemaError(
+ f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}."
+ )
normalized_keys = [
self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys
@@ -312,7 +341,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return normalized_mapping
- def _normalize_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
+ def _normalize_table(
+ self,
+ table: exp.Table | str,
+ dialect: DialectType = None,
+ normalize: t.Optional[bool] = None,
+ ) -> exp.Table:
normalized_table = exp.maybe_parse(
table, into=exp.Table, dialect=dialect or self.dialect, copy=True
)
@@ -322,15 +356,24 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
if isinstance(value, (str, exp.Identifier)):
normalized_table.set(
arg,
- exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)),
+ exp.to_identifier(
+ self._normalize_name(
+ value, dialect=dialect, is_table=True, normalize=normalize
+ )
+ ),
)
return normalized_table
def _normalize_name(
- self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False
+ self,
+ name: str | exp.Identifier,
+ dialect: DialectType = None,
+ is_table: bool = False,
+ normalize: t.Optional[bool] = None,
) -> str:
dialect = dialect or self.dialect
+ normalize = self.normalize if normalize is None else normalize
try:
identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
@@ -338,16 +381,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
return name if isinstance(name, str) else name.name
name = identifier.name
- if not self.normalize:
+ if not normalize:
return name
# This can be useful for normalize_identifier
identifier.meta["is_table"] = is_table
return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name
- def _depth(self) -> int:
- # The columns themselves are a mapping, but we don't want to include those
- return super()._depth() - 1
+ def depth(self) -> int:
+ if not self.empty and not self._depth:
+ # The columns themselves are a mapping, but we don't want to include those
+ self._depth = super().depth() - 1
+ return self._depth
def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
"""
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index a19ebaa..729e47f 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -147,6 +147,8 @@ class TokenType(AutoName):
VARIANT = auto()
OBJECT = auto()
INET = auto()
+ IPADDRESS = auto()
+ IPPREFIX = auto()
ENUM = auto()
# keywords
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 1e6cfc8..7c7c2a7 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -100,7 +100,8 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
qualify_filters = expression.args["qualify"].pop().this
- for expr in qualify_filters.find_all((exp.Window, exp.Column)):
+ select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
+ for expr in qualify_filters.find_all(select_candidates):
if isinstance(expr, exp.Window):
alias = find_new_name(expression.named_selects, "_w")
expression.select(exp.alias_(expr, alias), copy=False)
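A sketch of the star-aware behavior (names illustrative): with a star projection, only the window expressions need to be pushed into the inner query, since `SELECT *` already carries every column:

    from sqlglot import parse_one
    from sqlglot.transforms import eliminate_qualify

    sql = "SELECT * FROM t QUALIFY ROW_NUMBER() OVER (ORDER BY a) = 1"
    # Roughly:
    # SELECT * FROM (SELECT *, ROW_NUMBER() OVER (ORDER BY a) AS _w FROM t) AS _t WHERE _w = 1
    print(eliminate_qualify(parse_one(sql)).sql())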