Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py                      4
-rw-r--r--  sqlglot/dataframe/sql/functions.py       4
-rw-r--r--  sqlglot/dialects/__init__.py             1
-rw-r--r--  sqlglot/dialects/bigquery.py            51
-rw-r--r--  sqlglot/dialects/clickhouse.py          20
-rw-r--r--  sqlglot/dialects/databricks.py          11
-rw-r--r--  sqlglot/dialects/dialect.py             42
-rw-r--r--  sqlglot/dialects/drill.py                3
-rw-r--r--  sqlglot/dialects/duckdb.py              66
-rw-r--r--  sqlglot/dialects/hive.py                45
-rw-r--r--  sqlglot/dialects/mysql.py                3
-rw-r--r--  sqlglot/dialects/oracle.py              13
-rw-r--r--  sqlglot/dialects/postgres.py            18
-rw-r--r--  sqlglot/dialects/presto.py              30
-rw-r--r--  sqlglot/dialects/redshift.py             4
-rw-r--r--  sqlglot/dialects/snowflake.py            4
-rw-r--r--  sqlglot/dialects/spark.py              240
-rw-r--r--  sqlglot/dialects/spark2.py             238
-rw-r--r--  sqlglot/dialects/sqlite.py              43
-rw-r--r--  sqlglot/dialects/starrocks.py            1
-rw-r--r--  sqlglot/dialects/tableau.py              3
-rw-r--r--  sqlglot/dialects/teradata.py             3
-rw-r--r--  sqlglot/dialects/tsql.py                11
-rw-r--r--  sqlglot/expressions.py                 225
-rw-r--r--  sqlglot/generator.py                    64
-rw-r--r--  sqlglot/optimizer/eliminate_joins.py     2
-rw-r--r--  sqlglot/optimizer/expand_laterals.py     4
-rw-r--r--  sqlglot/optimizer/normalize.py           2
-rw-r--r--  sqlglot/optimizer/optimizer.py           2
-rw-r--r--  sqlglot/optimizer/qualify_columns.py    26
-rw-r--r--  sqlglot/optimizer/qualify_tables.py     18
-rw-r--r--  sqlglot/optimizer/scope.py               6
-rw-r--r--  sqlglot/optimizer/simplify.py           17
-rw-r--r--  sqlglot/parser.py                      121
-rw-r--r--  sqlglot/tokens.py                       95
-rw-r--r--  sqlglot/transforms.py                   70
36 files changed, 947 insertions, 563 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 42d89d1..f7440e0 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -50,7 +50,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
-__version__ = "11.7.1"
+__version__ = "12.2.0"
pretty = False
"""Whether to format generated SQL by default."""
@@ -181,7 +181,7 @@ def transpile(
Returns:
The list of transpiled SQL statements.
"""
- write = write or read if identity else write
+ write = (read if write is None else write) if identity else write
return [
Dialect.get_or_raise(write)().generate(expression, **opts)
for expression in parse(sql, read, error_level=error_level)
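Note: the rewrite fixes an operator-precedence footgun — the old `write = write or read if identity else write` fell back to `read` whenever `write` was falsy, while the new form only does so when `write` is None. A minimal sketch of the resulting behavior (illustrative, not part of the commit):

```python
import sqlglot

# identity=True with no write dialect: generate back in the read dialect.
print(sqlglot.transpile("SELECT 1", read="duckdb", identity=True))

# An explicit write dialect is respected even when identity=True.
print(sqlglot.transpile("SELECT 1", read="duckdb", write="hive", identity=True))
```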
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 993d869..71385aa 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -747,11 +747,11 @@ def ascii(col: ColumnOrLiteral) -> Column:
def base64(col: ColumnOrLiteral) -> Column:
- return Column.invoke_anonymous_function(col, "BASE64")
+ return Column.invoke_expression_over_column(col, expression.ToBase64)
def unbase64(col: ColumnOrLiteral) -> Column:
- return Column.invoke_anonymous_function(col, "UNBASE64")
+ return Column.invoke_expression_over_column(col, expression.FromBase64)
def ltrim(col: ColumnOrName) -> Column:
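Note: switching from `invoke_anonymous_function` to the typed `exp.ToBase64`/`exp.FromBase64` nodes (added elsewhere in this commit) lets these calls participate in dialect-aware generation instead of being emitted verbatim. A hedged sketch of the parser-side effect, assuming the Hive mappings added below:

```python
import sqlglot
from sqlglot import exp

# BASE64 now parses to a dedicated node rather than an Anonymous function.
ast = sqlglot.parse_one("SELECT BASE64(col)", read="hive")
print(ast.find(exp.ToBase64))
```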
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 191e703..fc34262 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -70,6 +70,7 @@ from sqlglot.dialects.presto import Presto
from sqlglot.dialects.redshift import Redshift
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark
+from sqlglot.dialects.spark2 import Spark2
from sqlglot.dialects.sqlite import SQLite
from sqlglot.dialects.starrocks import StarRocks
from sqlglot.dialects.tableau import Tableau
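Note: with this commit, Spark 2 behavior lives in a dedicated `spark2` dialect while `spark` tracks Spark 3. A quick illustrative check that the new key is registered:

```python
from sqlglot.dialects.dialect import Dialect

# Raises if the dialect key is unknown; "spark2" is registered by this commit.
spark2 = Dialect.get_or_raise("spark2")
print(spark2)
```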
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 1a88654..9705b35 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -39,18 +39,26 @@ def _date_add_sql(
def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
if not isinstance(expression.unnest().parent, exp.From):
- expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
return self.values_sql(expression)
- rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
- structs = []
- for row in rows:
- aliases = [
- exp.alias_(value, column_name)
- for value, column_name in zip(row, expression.args["alias"].args["columns"])
- ]
- structs.append(exp.Struct(expressions=aliases))
- unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
- return self.unnest_sql(unnest_exp)
+
+ alias = expression.args.get("alias")
+
+ structs = [
+ exp.Struct(
+ expressions=[
+ exp.alias_(value, column_name)
+ for value, column_name in zip(
+ t.expressions,
+ alias.columns
+ if alias and alias.columns
+ else (f"_c{i}" for i in range(len(t.expressions))),
+ )
+ ]
+ )
+ for t in expression.find_all(exp.Tuple)
+ ]
+
+ return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
@@ -128,6 +136,7 @@ class BigQuery(Dialect):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
+ BYTE_STRINGS = [("b'", "'"), ("B'", "'")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -139,6 +148,7 @@ class BigQuery(Dialect):
"GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
+ "BYTES": TokenType.BINARY,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
}
@@ -153,7 +163,7 @@ class BigQuery(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": lambda args: exp.DateTrunc(
- unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore
+ unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
),
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
@@ -206,6 +216,12 @@ class BigQuery(Dialect):
"NOT DETERMINISTIC": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
),
+ "OPTIONS": lambda self: self._parse_with_property(),
+ }
+
+ CONSTRAINT_PARSERS = {
+ **parser.Parser.CONSTRAINT_PARSERS, # type: ignore
+ "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
}
class Generator(generator.Generator):
@@ -217,11 +233,11 @@ class BigQuery(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
- **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.AtTimeZone: lambda self, e: self.func(
"TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
),
+ exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
@@ -234,7 +250,9 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
- exp.Select: transforms.preprocess([_unqualify_unnest]),
+ exp.Select: transforms.preprocess(
+ [_unqualify_unnest, transforms.eliminate_distinct_on]
+ ),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -259,6 +277,7 @@ class BigQuery(Dialect):
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
exp.DataType.Type.BIGINT: "INT64",
+ exp.DataType.Type.BINARY: "BYTES",
exp.DataType.Type.BOOLEAN: "BOOL",
exp.DataType.Type.CHAR: "STRING",
exp.DataType.Type.DECIMAL: "NUMERIC",
@@ -272,6 +291,7 @@ class BigQuery(Dialect):
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.TINYINT: "INT64",
+ exp.DataType.Type.VARBINARY: "BYTES",
exp.DataType.Type.VARCHAR: "STRING",
exp.DataType.Type.VARIANT: "ANY TYPE",
}
@@ -310,3 +330,6 @@ class BigQuery(Dialect):
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
+
+ def with_properties(self, properties: exp.Properties) -> str:
+ return self.properties(properties, prefix=self.seg("OPTIONS"))
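Note: the rewritten `_derived_table_values_to_unnest` builds one STRUCT per VALUES row and, when the alias carries no column list, falls back to positional names `_c0`, `_c1`, and so on. An illustrative sketch (exact output may differ slightly):

```python
import sqlglot

# VALUES in a FROM clause is rendered as UNNEST over an array of structs.
sql = "SELECT * FROM (VALUES (1, 'a')) AS t(x, y)"
print(sqlglot.transpile(sql, write="bigquery")[0])
# roughly: SELECT * FROM UNNEST([STRUCT(1 AS x, 'a' AS y)]) AS t
```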
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index e91b0bf..2a49066 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -22,6 +22,8 @@ class ClickHouse(Dialect):
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
+ BIT_STRINGS = [("0b", "")]
+ HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -31,10 +33,18 @@ class ClickHouse(Dialect):
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
+ "INT8": TokenType.TINYINT,
+ "UINT8": TokenType.UTINYINT,
"INT16": TokenType.SMALLINT,
+ "UINT16": TokenType.USMALLINT,
"INT32": TokenType.INT,
+ "UINT32": TokenType.UINT,
"INT64": TokenType.BIGINT,
- "INT8": TokenType.TINYINT,
+ "UINT64": TokenType.UBIGINT,
+ "INT128": TokenType.INT128,
+ "UINT128": TokenType.UINT128,
+ "INT256": TokenType.INT256,
+ "UINT256": TokenType.UINT256,
"TUPLE": TokenType.STRUCT,
}
@@ -121,9 +131,17 @@ class ClickHouse(Dialect):
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.STRUCT: "Tuple",
exp.DataType.Type.TINYINT: "Int8",
+ exp.DataType.Type.UTINYINT: "UInt8",
exp.DataType.Type.SMALLINT: "Int16",
+ exp.DataType.Type.USMALLINT: "UInt16",
exp.DataType.Type.INT: "Int32",
+ exp.DataType.Type.UINT: "UInt32",
exp.DataType.Type.BIGINT: "Int64",
+ exp.DataType.Type.UBIGINT: "UInt64",
+ exp.DataType.Type.INT128: "Int128",
+ exp.DataType.Type.UINT128: "UInt128",
+ exp.DataType.Type.INT256: "Int256",
+ exp.DataType.Type.UINT256: "UInt256",
exp.DataType.Type.FLOAT: "Float32",
exp.DataType.Type.DOUBLE: "Float64",
}
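Note: the paired tokenizer keywords and TYPE_MAPPING entries make ClickHouse's unsigned and wide integer types round-trip. Illustrative sketch:

```python
import sqlglot

sql = "CREATE TABLE t (a UInt8, b Int128, c UInt256)"
# The types survive a parse/generate round trip (casing per TYPE_MAPPING).
print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
```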
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 138f26c..51112a0 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from sqlglot import exp
+from sqlglot import exp, transforms
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
@@ -29,13 +29,20 @@ class Databricks(Spark):
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
+ exp.Select: transforms.preprocess(
+ [
+ transforms.eliminate_distinct_on,
+ transforms.unnest_to_explode,
+ ]
+ ),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
- TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation
PARAMETER_TOKEN = "$"
class Tokenizer(Spark.Tokenizer):
+ HEX_STRINGS = []
+
SINGLE_TOKENS = {
**Spark.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 19c6f73..71269f2 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -28,6 +28,7 @@ class Dialects(str, Enum):
REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
+ SPARK2 = "spark2"
SQLITE = "sqlite"
STARROCKS = "starrocks"
TABLEAU = "tableau"
@@ -69,30 +70,17 @@ class _Dialect(type):
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
- if (
- klass.tokenizer_class._BIT_STRINGS
- and exp.BitString not in klass.generator_class.TRANSFORMS
- ):
- bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
- klass.generator_class.TRANSFORMS[
- exp.BitString
- ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
- if (
- klass.tokenizer_class._HEX_STRINGS
- and exp.HexString not in klass.generator_class.TRANSFORMS
- ):
- hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
- klass.generator_class.TRANSFORMS[
- exp.HexString
- ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
- if (
- klass.tokenizer_class._BYTE_STRINGS
- and exp.ByteString not in klass.generator_class.TRANSFORMS
- ):
- be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
- klass.generator_class.TRANSFORMS[
- exp.ByteString
- ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"
+ klass.bit_start, klass.bit_end = seq_get(
+ list(klass.tokenizer_class._BIT_STRINGS.items()), 0
+ ) or (None, None)
+
+ klass.hex_start, klass.hex_end = seq_get(
+ list(klass.tokenizer_class._HEX_STRINGS.items()), 0
+ ) or (None, None)
+
+ klass.byte_start, klass.byte_end = seq_get(
+ list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
+ ) or (None, None)
return klass
@@ -198,6 +186,12 @@ class Dialect(metaclass=_Dialect):
**{
"quote_start": self.quote_start,
"quote_end": self.quote_end,
+ "bit_start": self.bit_start,
+ "bit_end": self.bit_end,
+ "hex_start": self.hex_start,
+ "hex_end": self.hex_end,
+ "byte_start": self.byte_start,
+ "byte_end": self.byte_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index d7e2d88..7ad555e 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
@@ -145,6 +145,7 @@ class Drill(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToDate: _str_to_date,
exp.Pow: rename_func("POW"),
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 9454db6..bce956e 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
@@ -23,52 +25,61 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _ts_or_ds_add(self, expression):
+def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
-def _date_add(self, expression):
+def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
-def _array_sort_sql(self, expression):
+def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
return f"ARRAY_SORT({self.sql(expression, 'this')})"
-def _sort_array_sql(self, expression):
+def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str:
this = self.sql(expression, "this")
if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
return f"ARRAY_SORT({this})"
-def _sort_array_reverse(args):
+def _sort_array_reverse(args: t.Sequence) -> exp.Expression:
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
-def _struct_sql(self, expression):
+def _parse_date_diff(args: t.Sequence) -> exp.Expression:
+ return exp.DateDiff(
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=seq_get(args, 0),
+ )
+
+
+def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
args = [
f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
]
return f"{{{', '.join(args)}}}"
-def _datatype_sql(self, expression):
+def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
-def _regexp_extract_sql(self, expression):
+def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str:
bad_args = list(filter(expression.args.get, ("position", "occurrence")))
if bad_args:
self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")
+
return self.func(
"REGEXP_EXTRACT",
expression.args.get("this"),
@@ -108,6 +119,8 @@ class DuckDB(Dialect):
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _sort_array_reverse,
+ "DATEDIFF": _parse_date_diff,
+ "DATE_DIFF": _parse_date_diff,
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(
@@ -115,18 +128,18 @@ class DuckDB(Dialect):
expression=exp.Literal.number(1000),
)
),
- "LIST_SORT": exp.SortArray.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
+ "LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
"REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
"STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
- "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
- "STR_SPLIT": exp.Split.from_arg_list,
"STRING_SPLIT": exp.Split.from_arg_list,
- "STRING_TO_ARRAY": exp.Split.from_arg_list,
- "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
+ "STRING_TO_ARRAY": exp.Split.from_arg_list,
+ "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
"STRUCT_PACK": exp.Struct.from_arg_list,
+ "STR_SPLIT": exp.Split.from_arg_list,
+ "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
}
@@ -142,10 +155,11 @@ class DuckDB(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
+ LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if isinstance(seq_get(e.expressions, 0), exp.Select)
@@ -154,13 +168,16 @@ class DuckDB(Dialect):
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
+ exp.CurrentDate: lambda self, e: "CURRENT_DATE",
+ exp.CurrentTime: lambda self, e: "CURRENT_TIME",
+ exp.CurrentTimestamp: lambda self, e: "CURRENT_TIMESTAMP",
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.DataType: _datatype_sql,
- exp.DateAdd: _date_add,
+ exp.DateAdd: _date_add_sql,
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this
+ "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
@@ -192,7 +209,7 @@ class DuckDB(Dialect):
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
- exp.TsOrDsAdd: _ts_or_ds_add,
+ exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
@@ -201,7 +218,7 @@ class DuckDB(Dialect):
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.CHAR: "TEXT",
exp.DataType.Type.FLOAT: "REAL",
@@ -212,17 +229,14 @@ class DuckDB(Dialect):
exp.DataType.Type.VARCHAR: "TEXT",
}
- STAR_MAPPING = {
- **generator.Generator.STAR_MAPPING,
- "except": "EXCLUDE",
- }
+ STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- LIMIT_FETCH = "LIMIT"
-
- def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
- return super().tablesample_sql(expression, seed_prefix="REPEATABLE")
+ def tablesample_sql(
+ self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
+ ) -> str:
+ return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)
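Note: `_parse_date_diff` maps DuckDB's `DATE_DIFF(part, start, end)` argument order onto `exp.DateDiff(this=end, expression=start, unit=part)`. Illustrative check:

```python
import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT DATE_DIFF('day', d1, d2)", read="duckdb")
diff = ast.find(exp.DateDiff)
print(diff.this, diff.expression, diff.args.get("unit"))  # d2, d1, 'day'
```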
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 6746fcf..871a180 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -81,7 +81,20 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
return f"{diff_sql}{multiplier_sql}"
-def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str:
+def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
+ this = expression.this
+
+ if not this.type:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ annotate_types(this)
+
+ if this.type.is_type(exp.DataType.Type.JSON):
+ return self.sql(this)
+ return self.func("TO_JSON", this, expression.args.get("options"))
+
+
+def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return f"SORT_ARRAY({self.sql(expression, 'this')})"
@@ -91,11 +104,11 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
return f"'{expression.name}'={self.sql(expression, 'value')}"
-def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str:
+def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
-def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -103,7 +116,7 @@ def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
return f"CAST({this} AS DATE)"
-def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str:
+def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -214,6 +227,7 @@ class Hive(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ "BASE64": exp.ToBase64.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0),
@@ -251,6 +265,7 @@ class Hive(Dialect):
"SPLIT": exp.RegexpSplit.from_arg_list,
"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
+ "UNBASE64": exp.FromBase64.from_arg_list,
"UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
@@ -280,16 +295,20 @@ class Hive(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
- **transforms.UNALIAS_GROUP, # type: ignore
- **transforms.ELIMINATE_QUALIFY, # type: ignore
+ exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Select: transforms.preprocess(
- [transforms.eliminate_qualify, transforms.unnest_to_explode]
+ [
+ transforms.eliminate_qualify,
+ transforms.eliminate_distinct_on,
+ transforms.unnest_to_explode,
+ ]
),
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayConcat: rename_func("CONCAT"),
+ exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArraySize: rename_func("SIZE"),
- exp.ArraySort: _array_sort,
+ exp.ArraySort: _array_sort_sql,
exp.With: no_recursive_cte_sql,
exp.DateAdd: _add_date_sql,
exp.DateDiff: _date_diff_sql,
@@ -298,12 +317,13 @@ class Hive(Dialect):
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
+ exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
- exp.JSONFormat: rename_func("TO_JSON"),
+ exp.JSONFormat: _json_format_sql,
exp.Map: var_map_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
@@ -318,9 +338,9 @@ class Hive(Dialect):
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: strposition_to_locate_sql,
- exp.StrToDate: _str_to_date,
- exp.StrToTime: _str_to_time,
- exp.StrToUnix: _str_to_unix,
+ exp.StrToDate: _str_to_date_sql,
+ exp.StrToTime: _str_to_time_sql,
+ exp.StrToUnix: _str_to_unix_sql,
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
@@ -328,6 +348,7 @@ class Hive(Dialect):
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: _time_to_str,
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.ToBase64: rename_func("BASE64"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.TsOrDsToDate: _to_date_sql,
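Note: `_json_format_sql` avoids wrapping an already-JSON-typed argument in `TO_JSON`, running `annotate_types` on demand when type information is missing. A hedged sketch:

```python
from sqlglot import parse_one

# TO_JSON maps to exp.JSONFormat; generation drops the redundant wrapper
# when the argument is typed as JSON (roughly).
ast = parse_one("SELECT TO_JSON(CAST(x AS JSON))", read="hive")
print(ast.sql(dialect="hive"))
```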
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 666e740..5342624 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -403,6 +403,7 @@ class MySQL(Dialect):
exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 9ccd02e..c8af1c6 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -34,6 +34,8 @@ def _parse_xml_table(self) -> exp.XMLTable:
class Oracle(Dialect):
+ alias_post_tablesample = True
+
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
time_mapping = {
@@ -121,21 +123,23 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
- **transforms.UNALIAS_GROUP, # type: ignore
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
+ exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
+ exp.IfNull: rename_func("NVL"),
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
+ exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
- exp.IfNull: rename_func("NVL"),
}
PROPERTIES_LOCATION = {
@@ -164,14 +168,19 @@ class Oracle(Dialect):
return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
class Tokenizer(tokens.Tokenizer):
+ VAR_SINGLE_TOKENS = {"@"}
+
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"(+)": TokenType.JOIN_MARKER,
+ "BINARY_DOUBLE": TokenType.DOUBLE,
+ "BINARY_FLOAT": TokenType.FLOAT,
"COLUMNS": TokenType.COLUMN,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NVARCHAR2": TokenType.NVARCHAR,
"RETURNING": TokenType.RETURNING,
+ "SAMPLE": TokenType.TABLE_SAMPLE,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index c47ff51..2132778 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -1,6 +1,8 @@
from __future__ import annotations
-from sqlglot import exp, generator, parser, tokens
+import typing as t
+
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -20,7 +22,6 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
-from sqlglot.transforms import preprocess, remove_target_from_merge
DATE_DIFF_FACTOR = {
"MICROSECOND": " * 1000000",
@@ -274,8 +275,7 @@ class Postgres(Dialect):
TokenType.HASH: exp.BitwiseXor,
}
- FACTOR = {
- **parser.Parser.FACTOR,
+ EXPONENT = {
TokenType.CARET: exp.Pow,
}
@@ -286,6 +286,12 @@ class Postgres(Dialect):
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
+ def _parse_factor(self) -> t.Optional[exp.Expression]:
+ return self._parse_tokens(self._parse_exponent, self.FACTOR)
+
+ def _parse_exponent(self) -> t.Optional[exp.Expression]:
+ return self._parse_tokens(self._parse_unary, self.EXPONENT)
+
def _parse_date_part(self) -> exp.Expression:
part = self._parse_type()
self._match(TokenType.COMMA)
@@ -316,7 +322,7 @@ class Postgres(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
- exp.ColumnDef: preprocess(
+ exp.ColumnDef: transforms.preprocess(
[
_auto_increment_to_serial,
_serial_to_generated,
@@ -341,7 +347,7 @@ class Postgres(Dialect):
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
- exp.Merge: preprocess([remove_target_from_merge]),
+ exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
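Note: splitting `^` out of FACTOR into a dedicated EXPONENT level parsed beneath it makes exponentiation bind tighter than `*` and `/`, matching Postgres operator precedence. Illustrative check:

```python
import sqlglot

# ^ now groups tighter than *, i.e. this parses as 2 * (3 ^ 2).
print(repr(sqlglot.parse_one("SELECT 2 * 3 ^ 2", read="postgres")))
```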
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 489d439..6133a27 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -130,7 +130,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
- step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
+ step = expression.args.get("step")
target_type = None
@@ -147,7 +147,11 @@ def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) ->
else:
start = exp.Cast(this=start, to=to)
- return self.func("SEQUENCE", start, end, step)
+ sql = self.func("SEQUENCE", start, end, step)
+ if isinstance(expression.parent, exp.Table):
+ sql = f"UNNEST({sql})"
+
+ return sql
def _ensure_utf8(charset: exp.Literal) -> None:
@@ -204,6 +208,7 @@ class Presto(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ "APPROX_PERCENTILE": _approx_percentile,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
@@ -219,23 +224,23 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
+ "FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
+ "FROM_UTF8": lambda args: exp.Decode(
+ this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
+ ),
"NOW": exp.CurrentTimestamp.from_arg_list,
+ "SEQUENCE": exp.GenerateSeries.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
instance=seq_get(args, 2),
),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
- "APPROX_PERCENTILE": _approx_percentile,
- "FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
- "FROM_UTF8": lambda args: exp.Decode(
- this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
- ),
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("TRIM")
@@ -264,7 +269,6 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
- **transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@@ -290,6 +294,7 @@ class Presto(Dialect):
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
exp.GenerateSeries: _sequence_sql,
+ exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
@@ -303,7 +308,11 @@ class Presto(Dialect):
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
- [transforms.eliminate_qualify, transforms.explode_to_unnest]
+ [
+ transforms.eliminate_qualify,
+ transforms.eliminate_distinct_on,
+ transforms.explode_to_unnest,
+ ]
),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
@@ -327,6 +336,9 @@ class Presto(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
+ exp.WithinGroup: transforms.preprocess(
+ [transforms.remove_within_group_for_percentiles]
+ ),
}
def interval_sql(self, expression: exp.Interval) -> str:
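Note: since Presto's `SEQUENCE` returns an array, `_sequence_sql` now wraps it in `UNNEST` when it appears in table position, and dropping the default step avoids emitting an argument the source never supplied. A hedged sketch:

```python
import sqlglot

# A set-returning generate_series in FROM becomes UNNEST(SEQUENCE(...)).
sql = "SELECT * FROM GENERATE_SERIES(1, 5)"
print(sqlglot.transpile(sql, read="postgres", write="presto")[0])
```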
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index a9c4f62..1b7cf31 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -52,6 +52,8 @@ class Redshift(Postgres):
return this
class Tokenizer(Postgres.Tokenizer):
+ BIT_STRINGS = []
+ HEX_STRINGS = []
STRING_ESCAPES = ["\\"]
KEYWORDS = {
@@ -90,7 +92,6 @@ class Redshift(Postgres):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
- **transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@@ -102,6 +103,7 @@ class Redshift(Postgres):
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 0829669..70dcaa9 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
date_trunc_to_time,
@@ -252,6 +252,7 @@ class Snowflake(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
STRING_ESCAPES = ["\\", "'"]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -305,6 +306,7 @@ class Snowflake(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index a3e4cce..939f2fd 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -2,222 +2,54 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, parser
-from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
-from sqlglot.dialects.hive import Hive
+from sqlglot import exp
+from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
-def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
- kind = e.args["kind"]
- properties = e.args.get("properties")
+def _parse_datediff(args: t.Sequence) -> exp.Expression:
+ """
+ Although Spark docs don't mention the "unit" argument, Spark3 added support for
+ it at some point. Databricks also supports this variation (see below).
- if kind.upper() == "TABLE" and any(
- isinstance(prop, exp.TemporaryProperty)
- for prop in (properties.expressions if properties else [])
- ):
- return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
- return create_with_partitions_sql(self, e)
+ For example, in spark-sql (v3.3.1):
+ - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
+ - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4
+ See also:
+ - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
+ - https://docs.databricks.com/sql/language-manual/functions/datediff.html
+ """
+ unit = None
+ this = seq_get(args, 0)
+ expression = seq_get(args, 1)
-def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
- keys = self.sql(expression.args["keys"])
- values = self.sql(expression.args["values"])
- return f"MAP_FROM_ARRAYS({keys}, {values})"
+ if len(args) == 3:
+ unit = this
+ this = args[2]
+ return exp.DateDiff(
+ this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit
+ )
-def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
- this = self.sql(expression, "this")
- time_format = self.format_time(expression)
- if time_format == Hive.date_format:
- return f"TO_DATE({this})"
- return f"TO_DATE({this}, {time_format})"
-
-def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
- scale = expression.args.get("scale")
- timestamp = self.sql(expression, "this")
- if scale is None:
- return f"FROM_UNIXTIME({timestamp})"
- if scale == exp.UnixToTime.SECONDS:
- return f"TIMESTAMP_SECONDS({timestamp})"
- if scale == exp.UnixToTime.MILLIS:
- return f"TIMESTAMP_MILLIS({timestamp})"
- if scale == exp.UnixToTime.MICROS:
- return f"TIMESTAMP_MICROS({timestamp})"
-
- raise ValueError("Improper scale for timestamp")
-
-
-class Spark(Hive):
- class Parser(Hive.Parser):
+class Spark(Spark2):
+ class Parser(Spark2.Parser):
FUNCTIONS = {
- **Hive.Parser.FUNCTIONS, # type: ignore
- "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
- "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
- "LEFT": lambda args: exp.Substring(
- this=seq_get(args, 0),
- start=exp.Literal.number(1),
- length=seq_get(args, 1),
- ),
- "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- ),
- "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- ),
- "RIGHT": lambda args: exp.Substring(
- this=seq_get(args, 0),
- start=exp.Sub(
- this=exp.Length(this=seq_get(args, 0)),
- expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
- ),
- length=seq_get(args, 1),
- ),
- "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
- "BOOLEAN": lambda args: exp.Cast(
- this=seq_get(args, 0), to=exp.DataType.build("boolean")
- ),
- "IIF": exp.If.from_arg_list,
- "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
- "AGGREGATE": exp.Reduce.from_arg_list,
- "DAYOFWEEK": lambda args: exp.DayOfWeek(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DAYOFMONTH": lambda args: exp.DayOfMonth(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DAYOFYEAR": lambda args: exp.DayOfYear(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "WEEKOFYEAR": lambda args: exp.WeekOfYear(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
- "DATE_TRUNC": lambda args: exp.TimestampTrunc(
- this=seq_get(args, 1),
- unit=exp.var(seq_get(args, 0)),
- ),
- "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
- "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
- "TIMESTAMP": lambda args: exp.Cast(
- this=seq_get(args, 0), to=exp.DataType.build("timestamp")
- ),
- }
-
- FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS, # type: ignore
- "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
- "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
- "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
- "MERGE": lambda self: self._parse_join_hint("MERGE"),
- "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
- "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
- "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
- "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
- }
-
- def _parse_add_column(self) -> t.Optional[exp.Expression]:
- return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
-
- def _parse_drop_column(self) -> t.Optional[exp.Expression]:
- return self._match_text_seq("DROP", "COLUMNS") and self.expression(
- exp.Drop,
- this=self._parse_schema(),
- kind="COLUMNS",
- )
-
- def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
- # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
- if len(pivot_columns) == 1:
- return [""]
-
- names = []
- for agg in pivot_columns:
- if isinstance(agg, exp.Alias):
- names.append(agg.alias)
- else:
- """
- This case corresponds to aggregations without aliases being used as suffixes
- (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
- be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
- Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
-
- Moreover, function names are lowercased in order to mimic Spark's naming scheme.
- """
- agg_all_unquoted = agg.transform(
- lambda node: exp.Identifier(this=node.name, quoted=False)
- if isinstance(node, exp.Identifier)
- else node
- )
- names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
-
- return names
-
- class Generator(Hive.Generator):
- TYPE_MAPPING = {
- **Hive.Generator.TYPE_MAPPING, # type: ignore
- exp.DataType.Type.TINYINT: "BYTE",
- exp.DataType.Type.SMALLINT: "SHORT",
- exp.DataType.Type.BIGINT: "LONG",
- }
-
- PROPERTIES_LOCATION = {
- **Hive.Generator.PROPERTIES_LOCATION, # type: ignore
- exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
- exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
- exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
- exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
- }
-
- TRANSFORMS = {
- **Hive.Generator.TRANSFORMS, # type: ignore
- exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
- exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
- exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
- exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
- exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
- exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
- exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
- exp.StrToDate: _str_to_date,
- exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.UnixToTime: _unix_to_time_sql,
- exp.Create: _create_sql,
- exp.Map: _map_sql,
- exp.Reduce: rename_func("AGGREGATE"),
- exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
- exp.TimestampTrunc: lambda self, e: self.func(
- "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
- ),
- exp.Trim: trim_sql,
- exp.VariancePop: rename_func("VAR_POP"),
- exp.DateFromParts: rename_func("MAKE_DATE"),
- exp.LogicalOr: rename_func("BOOL_OR"),
- exp.LogicalAnd: rename_func("BOOL_AND"),
- exp.DayOfWeek: rename_func("DAYOFWEEK"),
- exp.DayOfMonth: rename_func("DAYOFMONTH"),
- exp.DayOfYear: rename_func("DAYOFYEAR"),
- exp.WeekOfYear: rename_func("WEEKOFYEAR"),
- exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+ **Spark2.Parser.FUNCTIONS, # type: ignore
+ "DATEDIFF": _parse_datediff,
}
- TRANSFORMS.pop(exp.ArraySort)
- TRANSFORMS.pop(exp.ILike)
- WRAP_DERIVED_VALUES = False
- CREATE_FUNCTION_RETURN_AS = False
+ class Generator(Spark2.Generator):
+ TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
+ TRANSFORMS.pop(exp.DateDiff)
- def cast_sql(self, expression: exp.Cast) -> str:
- if isinstance(expression.this, exp.Cast) and expression.this.is_type(
- exp.DataType.Type.JSON
- ):
- schema = f"'{self.sql(expression, 'to')}'"
- return self.func("FROM_JSON", expression.this.this, schema)
- if expression.to.is_type(exp.DataType.Type.JSON):
- return self.func("TO_JSON", expression.this)
+ def datediff_sql(self, expression: exp.DateDiff) -> str:
+ unit = self.sql(expression, "unit")
+ end = self.sql(expression, "this")
+ start = self.sql(expression, "expression")
- return super(Spark.Generator, self).cast_sql(expression)
+ if unit:
+ return self.func("DATEDIFF", unit, start, end)
- class Tokenizer(Hive.Tokenizer):
- HEX_STRINGS = [("X'", "'")]
+ return self.func("DATEDIFF", end, start)
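Note: `datediff_sql` emits the three-argument form only when a unit is present, so unit-less input stays Spark 2 compatible. Illustrative sketch:

```python
import sqlglot

# Both Spark 3 variants parse; the unit is preserved only when given.
for sql in (
    "SELECT DATEDIFF('2020-01-01', '2020-01-05')",
    "SELECT DATEDIFF(day, '2020-01-01', '2020-01-05')",
):
    print(sqlglot.transpile(sql, read="spark", write="spark")[0])
```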
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
new file mode 100644
index 0000000..584671f
--- /dev/null
+++ b/sqlglot/dialects/spark2.py
@@ -0,0 +1,238 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp, parser, transforms
+from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
+from sqlglot.dialects.hive import Hive
+from sqlglot.helper import seq_get
+
+
+def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
+ kind = e.args["kind"]
+ properties = e.args.get("properties")
+
+ if kind.upper() == "TABLE" and any(
+ isinstance(prop, exp.TemporaryProperty)
+ for prop in (properties.expressions if properties else [])
+ ):
+ return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
+ return create_with_partitions_sql(self, e)
+
+
+def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
+ keys = self.sql(expression.args["keys"])
+ values = self.sql(expression.args["values"])
+ return f"MAP_FROM_ARRAYS({keys}, {values})"
+
+
+def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
+ return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
+
+
+def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
+ this = self.sql(expression, "this")
+ time_format = self.format_time(expression)
+ if time_format == Hive.date_format:
+ return f"TO_DATE({this})"
+ return f"TO_DATE({this}, {time_format})"
+
+
+def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
+ scale = expression.args.get("scale")
+ timestamp = self.sql(expression, "this")
+ if scale is None:
+ return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)"
+ if scale == exp.UnixToTime.SECONDS:
+ return f"TIMESTAMP_SECONDS({timestamp})"
+ if scale == exp.UnixToTime.MILLIS:
+ return f"TIMESTAMP_MILLIS({timestamp})"
+ if scale == exp.UnixToTime.MICROS:
+ return f"TIMESTAMP_MICROS({timestamp})"
+
+ raise ValueError("Improper scale for timestamp")
+
+
+class Spark2(Hive):
+ class Parser(Hive.Parser):
+ FUNCTIONS = {
+ **Hive.Parser.FUNCTIONS, # type: ignore
+ "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
+ "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
+ "LEFT": lambda args: exp.Substring(
+ this=seq_get(args, 0),
+ start=exp.Literal.number(1),
+ length=seq_get(args, 1),
+ ),
+ "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
+ ),
+ "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
+ ),
+ "RIGHT": lambda args: exp.Substring(
+ this=seq_get(args, 0),
+ start=exp.Sub(
+ this=exp.Length(this=seq_get(args, 0)),
+ expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
+ ),
+ length=seq_get(args, 1),
+ ),
+ "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+ "IIF": exp.If.from_arg_list,
+ "AGGREGATE": exp.Reduce.from_arg_list,
+ "DAYOFWEEK": lambda args: exp.DayOfWeek(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
+ "DAYOFMONTH": lambda args: exp.DayOfMonth(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
+ "DAYOFYEAR": lambda args: exp.DayOfYear(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
+ "WEEKOFYEAR": lambda args: exp.WeekOfYear(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
+ "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
+ "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+ this=seq_get(args, 1),
+ unit=exp.var(seq_get(args, 0)),
+ ),
+ "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
+ "BOOLEAN": _parse_as_cast("boolean"),
+ "DOUBLE": _parse_as_cast("double"),
+ "FLOAT": _parse_as_cast("float"),
+ "INT": _parse_as_cast("int"),
+ "STRING": _parse_as_cast("string"),
+ "TIMESTAMP": _parse_as_cast("timestamp"),
+ }
+
+ FUNCTION_PARSERS = {
+ **parser.Parser.FUNCTION_PARSERS, # type: ignore
+ "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
+ "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
+ "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
+ "MERGE": lambda self: self._parse_join_hint("MERGE"),
+ "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
+ "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
+ "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
+ "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
+ }
+
+ def _parse_add_column(self) -> t.Optional[exp.Expression]:
+ return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
+
+ def _parse_drop_column(self) -> t.Optional[exp.Expression]:
+ return self._match_text_seq("DROP", "COLUMNS") and self.expression(
+ exp.Drop,
+ this=self._parse_schema(),
+ kind="COLUMNS",
+ )
+
+ def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
+ # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
+ if len(pivot_columns) == 1:
+ return [""]
+
+ names = []
+ for agg in pivot_columns:
+ if isinstance(agg, exp.Alias):
+ names.append(agg.alias)
+ else:
+ """
+ This case corresponds to aggregations without aliases being used as suffixes
+ (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
+ be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
+ Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
+
+ Moreover, function names are lowercased in order to mimic Spark's naming scheme.
+ """
+ agg_all_unquoted = agg.transform(
+ lambda node: exp.Identifier(this=node.name, quoted=False)
+ if isinstance(node, exp.Identifier)
+ else node
+ )
+ names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
+
+ return names
+
+ class Generator(Hive.Generator):
+ TYPE_MAPPING = {
+ **Hive.Generator.TYPE_MAPPING, # type: ignore
+ exp.DataType.Type.TINYINT: "BYTE",
+ exp.DataType.Type.SMALLINT: "SHORT",
+ exp.DataType.Type.BIGINT: "LONG",
+ }
+
+ PROPERTIES_LOCATION = {
+ **Hive.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
+ }
+
+ TRANSFORMS = {
+ **Hive.Generator.TRANSFORMS, # type: ignore
+ exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
+ exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
+ exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+ exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
+ exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
+ exp.Create: _create_sql,
+ exp.DateFromParts: rename_func("MAKE_DATE"),
+ exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
+ exp.DayOfMonth: rename_func("DAYOFMONTH"),
+ exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.DayOfYear: rename_func("DAYOFYEAR"),
+ exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
+ exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
+ exp.LogicalAnd: rename_func("BOOL_AND"),
+ exp.LogicalOr: rename_func("BOOL_OR"),
+ exp.Map: _map_sql,
+ exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
+ exp.Reduce: rename_func("AGGREGATE"),
+ exp.StrToDate: _str_to_date,
+ exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimestampTrunc: lambda self, e: self.func(
+ "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
+ ),
+ exp.Trim: trim_sql,
+ exp.UnixToTime: _unix_to_time_sql,
+ exp.VariancePop: rename_func("VAR_POP"),
+ exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+ exp.WithinGroup: transforms.preprocess(
+ [transforms.remove_within_group_for_percentiles]
+ ),
+ }
+ TRANSFORMS.pop(exp.ArrayJoin)
+ TRANSFORMS.pop(exp.ArraySort)
+ TRANSFORMS.pop(exp.ILike)
+
+ WRAP_DERIVED_VALUES = False
+ CREATE_FUNCTION_RETURN_AS = False
+
+ def cast_sql(self, expression: exp.Cast) -> str:
+ if isinstance(expression.this, exp.Cast) and expression.this.is_type(
+ exp.DataType.Type.JSON
+ ):
+ schema = f"'{self.sql(expression, 'to')}'"
+ return self.func("FROM_JSON", expression.this.this, schema)
+ if expression.to.is_type(exp.DataType.Type.JSON):
+ return self.func("TO_JSON", expression.this)
+
+ return super(Hive.Generator, self).cast_sql(expression)
+
+ def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
+ return super().columndef_sql(
+ expression,
+ sep=": "
+ if isinstance(expression.parent, exp.DataType)
+ and expression.parent.is_type(exp.DataType.Type.STRUCT)
+ else sep,
+ )
+
+ class Tokenizer(Hive.Tokenizer):
+ HEX_STRINGS = [("X'", "'")]
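Note: `cast_sql` special-cases JSON — a cast of a JSON-typed value to a struct becomes `FROM_JSON`, and a cast to JSON becomes `TO_JSON`. A hedged sketch:

```python
import sqlglot

# CAST(CAST(x AS JSON) AS STRUCT<...>) -> FROM_JSON(x, '...') (roughly).
sql = "SELECT CAST(CAST(payload AS JSON) AS STRUCT<a: INT>)"
print(sqlglot.transpile(sql, read="spark2", write="spark2")[0])
```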
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 4437f82..f2efe32 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -22,6 +22,40 @@ def _date_add_sql(self, expression):
return self.func("DATE", expression.this, modifier)
+def _transform_create(expression: exp.Expression) -> exp.Expression:
+ """Move primary key to a column and enforce auto_increment on primary keys."""
+ schema = expression.this
+
+ if isinstance(expression, exp.Create) and isinstance(schema, exp.Schema):
+ defs = {}
+ primary_key = None
+
+ for e in schema.expressions:
+ if isinstance(e, exp.ColumnDef):
+ defs[e.name] = e
+ elif isinstance(e, exp.PrimaryKey):
+ primary_key = e
+
+ if primary_key and len(primary_key.expressions) == 1:
+ column = defs[primary_key.expressions[0].name]
+ column.append(
+ "constraints", exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint())
+ )
+ schema.expressions.remove(primary_key)
+ else:
+ for column in defs.values():
+ auto_increment = None
+ for constraint in column.constraints.copy():
+ if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint):
+ break
+ if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint):
+ auto_increment = constraint
+ if auto_increment:
+ column.constraints.remove(auto_increment)
+
+ return expression
+
+
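A hedged sketch of the effect (exact type rendering depends on the dialect's type mapping):

    import sqlglot

    sqlglot.transpile("CREATE TABLE t (x INT, PRIMARY KEY (x))", write="sqlite")[0]
    # -> roughly 'CREATE TABLE t (x INTEGER PRIMARY KEY)'; an AUTOINCREMENT
    # constraint on a non-primary-key column would be dropped instead.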
class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
@@ -65,8 +99,8 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
- **transforms.ELIMINATE_QUALIFY, # type: ignore
exp.CountIf: count_if_to_sum,
+ exp.Create: transforms.preprocess([_transform_create]),
exp.CurrentDate: lambda *_: "CURRENT_DATE",
exp.CurrentTime: lambda *_: "CURRENT_TIME",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
@@ -80,14 +114,17 @@ class SQLite(Dialect):
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
+ exp.Select: transforms.preprocess(
+ [transforms.eliminate_distinct_on, transforms.eliminate_qualify]
+ ),
exp.TableSample: no_tablesample_sql,
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
exp.TryCast: no_trycast_sql,
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
- exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ k: exp.Properties.Location.UNSUPPORTED
+ for k in generator.Generator.PROPERTIES_LOCATION
}
LIMIT_FETCH = "LIMIT"
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index ff19dab..895588a 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -34,6 +34,7 @@ class StarRocks(MySQL):
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),
+ exp.RegexpLike: rename_func("REGEXP"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 792c2b4..51e685b 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from sqlglot import exp, generator, parser
+from sqlglot import exp, generator, parser, transforms
from sqlglot.dialects.dialect import Dialect
@@ -29,6 +29,7 @@ class Tableau(Dialect):
exp.If: _if_sql,
exp.Coalesce: _coalesce_sql,
exp.Count: _count_sql,
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
}
PROPERTIES_LOCATION = {
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 331e105..a79eaeb 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@@ -148,6 +148,7 @@ class Teradata(Dialect):
**generator.Generator.TRANSFORMS,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 9cf56e1..03de99c 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import re
import typing as t
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
max_or_greatest,
@@ -259,8 +259,8 @@ class TSQL(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
-
QUOTES = ["'", '"']
+ HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -463,17 +463,18 @@ class TSQL(Dialect):
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
- exp.If: rename_func("IIF"),
- exp.NumberToStr: _format_sql,
- exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
+ exp.If: rename_func("IIF"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
+ exp.NumberToStr: _format_sql,
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
),
+ exp.TimeToStr: _format_sql,
}
TRANSFORMS.pop(exp.ReturnsProperty)
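Since SQLite, Tableau, Teradata and T-SQL now all preprocess exp.Select with eliminate_distinct_on, Postgres-style DISTINCT ON can be transpiled into them; a rough sketch of the output shape (not the exact SQL):

    import sqlglot

    sqlglot.transpile(
        "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b",
        read="postgres",
        write="tsql",
    )[0]
    # -> a subquery projecting ROW_NUMBER() OVER (PARTITION BY a ORDER BY a, b),
    #    filtered to the first row of each partition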
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 49d3ff6..9e7379d 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -64,6 +64,13 @@ class Expression(metaclass=_Expression):
and representing expressions as strings.
arg_types: determines what arguments (child nodes) are supported by an expression. It
maps arg keys to booleans that indicate whether the corresponding args are required.
+ parent: a reference to the parent expression (or None, in case of root expressions).
+ arg_key: the arg key an expression is associated with, i.e. the name its parent expression
+ uses to refer to it.
+ comments: a list of comments that are associated with a given expression. This is used in
+ order to preserve comments when transpiling SQL code.
+ _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
+ optimizer, in order to enable some transformations that require type information.
Example:
>>> class Foo(Expression):
@@ -74,13 +81,6 @@ class Expression(metaclass=_Expression):
Args:
args: a mapping used for retrieving the arguments of an expression, given their arg keys.
- parent: a reference to the parent expression (or None, in case of root expressions).
- arg_key: the arg key an expression is associated with, i.e. the name its parent expression
- uses to refer to it.
- comments: a list of comments that are associated with a given expression. This is used in
- order to preserve comments when transpiling SQL code.
- _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
- optimizer, in order to enable some transformations that require type information.
"""
key = "expression"
@@ -258,6 +258,12 @@ class Expression(metaclass=_Expression):
new.parent = self.parent
return new
+ def add_comments(self, comments: t.Optional[t.List[str]]) -> None:
+ if self.comments is None:
+ self.comments = []
+ if comments:
+ self.comments.extend(comments)
+
def append(self, arg_key, value):
"""
Appends value to arg_key if it's a list or sets it as a new list.
@@ -650,7 +656,7 @@ ExpOrStr = t.Union[str, Expression]
class Condition(Expression):
- def and_(self, *expressions, dialect=None, **opts):
+ def and_(self, *expressions, dialect=None, copy=True, **opts):
"""
AND this condition with one or multiple expressions.
@@ -662,14 +668,15 @@ class Condition(Expression):
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
+ copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
opts (kwargs): other options to use to parse the input expressions.
Returns:
And: the new condition.
"""
- return and_(self, *expressions, dialect=dialect, **opts)
+ return and_(self, *expressions, dialect=dialect, copy=copy, **opts)
- def or_(self, *expressions, dialect=None, **opts):
+ def or_(self, *expressions, dialect=None, copy=True, **opts):
"""
OR this condition with one or multiple expressions.
@@ -681,14 +688,15 @@ class Condition(Expression):
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
+ copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
opts (kwargs): other options to use to parse the input expressions.
Returns:
Or: the new condition.
"""
- return or_(self, *expressions, dialect=dialect, **opts)
+ return or_(self, *expressions, dialect=dialect, copy=copy, **opts)
- def not_(self):
+ def not_(self, copy=True):
"""
Wrap this condition with NOT.
@@ -696,14 +704,17 @@ class Condition(Expression):
>>> condition("x=1").not_().sql()
'NOT x = 1'
+ Args:
+ copy (bool): whether or not to copy this object.
+
Returns:
Not: the new condition.
"""
- return not_(self)
+ return not_(self, copy=copy)
def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
- this = self
- other = convert(other)
+ this = self.copy()
+ other = convert(other, copy=True)
if not isinstance(this, klass) and not isinstance(other, klass):
this = _wrap(this, Binary)
other = _wrap(other, Binary)
@@ -711,20 +722,25 @@ class Condition(Expression):
return klass(this=other, expression=this)
return klass(this=this, expression=other)
- def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
- if isinstance(other, slice):
- return Between(
- this=self,
- low=convert(other.start),
- high=convert(other.stop),
- )
- return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])
+ def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]):
+ return Bracket(
+ this=self.copy(), expressions=[convert(e, copy=True) for e in ensure_list(other)]
+ )
- def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
+ def isin(
+ self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
+ ) -> In:
return In(
- this=self,
- expressions=[convert(e) for e in expressions],
- query=maybe_parse(query, **opts) if query else None,
+ this=_maybe_copy(self, copy),
+ expressions=[convert(e, copy=copy) for e in expressions],
+ query=maybe_parse(query, copy=copy, **opts) if query else None,
+ )
+
+ def between(self, low: t.Any, high: t.Any, copy=True, **opts) -> Between:
+ return Between(
+ this=_maybe_copy(self, copy),
+ low=convert(low, copy=copy, **opts),
+ high=convert(high, copy=copy, **opts),
)
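A quick sketch of the builder helpers; both copy their operands unless copy=False is passed:

    from sqlglot import exp

    exp.column("x").between(1, 10).sql()  # 'x BETWEEN 1 AND 10'
    exp.column("x").isin(1, 2, 3).sql()   # 'x IN (1, 2, 3)'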
def like(self, other: ExpOrStr) -> Like:
@@ -809,10 +825,10 @@ class Condition(Expression):
return self._binop(Or, other, reverse=True)
def __neg__(self) -> Neg:
- return Neg(this=_wrap(self, Binary))
+ return Neg(this=_wrap(self.copy(), Binary))
def __invert__(self) -> Not:
- return not_(self)
+ return not_(self.copy())
class Predicate(Condition):
@@ -830,11 +846,7 @@ class DerivedTable(Expression):
@property
def selects(self):
- alias = self.args.get("alias")
-
- if alias:
- return alias.columns
- return []
+ return self.this.selects if isinstance(self.this, Subqueryable) else []
@property
def named_selects(self):
@@ -904,7 +916,10 @@ class Unionable(Expression):
class UDTF(DerivedTable, Unionable):
- pass
+ @property
+ def selects(self):
+ alias = self.args.get("alias")
+ return alias.columns if alias else []
class Cache(Expression):
@@ -1073,6 +1088,10 @@ class ColumnDef(Expression):
"position": False,
}
+ @property
+ def constraints(self) -> t.List[ColumnConstraint]:
+ return self.args.get("constraints") or []
+
class AlterColumn(Expression):
arg_types = {
@@ -1100,6 +1119,10 @@ class Comment(Expression):
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
+ @property
+ def kind(self) -> ColumnConstraintKind:
+ return self.args["kind"]
+
class ColumnConstraintKind(Expression):
pass
@@ -1937,6 +1960,15 @@ class Reference(Expression):
class Tuple(Expression):
arg_types = {"expressions": False}
+ def isin(
+ self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
+ ) -> In:
+ return In(
+ this=_maybe_copy(self, copy),
+ expressions=[convert(e, copy=copy) for e in expressions],
+ query=maybe_parse(query, copy=copy, **opts) if query else None,
+ )
+
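The Tuple counterpart enables multi-column membership tests, e.g.:

    from sqlglot import exp

    exp.Tuple(expressions=[exp.column("a"), exp.column("b")]).isin((1, 2), (3, 4)).sql()
    # -> '(a, b) IN ((1, 2), (3, 4))'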
class Subqueryable(Unionable):
def subquery(self, alias=None, copy=True) -> Subquery:
@@ -2236,6 +2268,8 @@ class Select(Subqueryable):
"expressions": False,
"hint": False,
"distinct": False,
+ "struct": False, # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#return_query_results_as_a_value_table
+ "value": False,
"into": False,
"from": False,
**QUERY_MODIFIERS,
@@ -2611,7 +2645,7 @@ class Select(Subqueryable):
join.set("kind", kind.text)
if on:
- on = and_(*ensure_collection(on), dialect=dialect, **opts)
+ on = and_(*ensure_collection(on), dialect=dialect, copy=copy, **opts)
join.set("on", on)
if using:
@@ -2723,7 +2757,7 @@ class Select(Subqueryable):
**opts,
)
- def distinct(self, distinct=True, copy=True) -> Select:
+ def distinct(self, *ons: ExpOrStr, distinct: bool = True, copy: bool = True) -> Select:
"""
Set the SELECT to be DISTINCT, optionally on a set of expressions.
@@ -2732,14 +2766,16 @@ class Select(Subqueryable):
'SELECT DISTINCT x FROM tbl'
Args:
- distinct (bool): whether the Select should be distinct
- copy (bool): if `False`, modify this expression instance in-place.
+ ons: the expressions to distinct on
+ distinct: whether the Select should be distinct
+ copy: if `False`, modify this expression instance in-place.
Returns:
Select: the modified expression.
"""
instance = _maybe_copy(self, copy)
- instance.set("distinct", Distinct() if distinct else None)
+ on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons]) if ons else None
+ instance.set("distinct", Distinct(on=on) if distinct else None)
return instance
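For example, a hedged sketch of the new DISTINCT ON support:

    from sqlglot import exp

    exp.select("a", "b").from_("t").distinct("a").sql(dialect="postgres")
    # -> roughly 'SELECT DISTINCT ON (a) a, b FROM t'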
def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
@@ -2969,6 +3005,10 @@ class DataType(Expression):
USMALLINT = auto()
BIGINT = auto()
UBIGINT = auto()
+ INT128 = auto()
+ UINT128 = auto()
+ INT256 = auto()
+ UINT256 = auto()
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
@@ -3022,6 +3062,8 @@ class DataType(Expression):
Type.TINYINT,
Type.SMALLINT,
Type.BIGINT,
+ Type.INT128,
+ Type.INT256,
}
FLOAT_TYPES = {
@@ -3069,10 +3111,6 @@ class PseudoType(Expression):
pass
-class StructKwarg(Expression):
- arg_types = {"this": True, "expression": True}
-
-
# WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...)
class SubqueryPredicate(Predicate):
pass
@@ -3538,14 +3576,20 @@ class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}
def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
- this = self.copy() if copy else self
- this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts)))
- return this
+ instance = _maybe_copy(self, copy)
+ instance.append(
+ "ifs",
+ If(
+ this=maybe_parse(condition, copy=copy, **opts),
+ true=maybe_parse(then, copy=copy, **opts),
+ ),
+ )
+ return instance
def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
- this = self.copy() if copy else self
- this.set("default", maybe_parse(condition, **opts))
- return this
+ instance = _maybe_copy(self, copy)
+ instance.set("default", maybe_parse(condition, copy=copy, **opts))
+ return instance
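Chained construction still reads the same, now with copy threaded through each step; a minimal sketch:

    from sqlglot import exp

    exp.Case().when("x = 1", "'one'").else_("'many'").sql()
    # -> "CASE WHEN x = 1 THEN 'one' ELSE 'many' END"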
class Cast(Func):
@@ -3760,6 +3804,14 @@ class Floor(Func):
arg_types = {"this": True, "decimals": False}
+class FromBase64(Func):
+ pass
+
+
+class ToBase64(Func):
+ pass
+
+
class Greatest(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@@ -3930,11 +3982,11 @@ class Pow(Binary, Func):
class PercentileCont(AggFunc):
- pass
+ arg_types = {"this": True, "expression": False}
class PercentileDisc(AggFunc):
- pass
+ arg_types = {"this": True, "expression": False}
class Quantile(AggFunc):
@@ -4405,14 +4457,16 @@ def _apply_conjunction_builder(
if append and existing is not None:
expressions = [existing.this if into else existing] + list(expressions)
- node = and_(*expressions, dialect=dialect, **opts)
+ node = and_(*expressions, dialect=dialect, copy=copy, **opts)
inst.set(arg, into(this=node) if into else node)
return inst
-def _combine(expressions, operator, dialect=None, **opts):
- expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
+def _combine(expressions, operator, dialect=None, copy=True, **opts):
+ expressions = [
+ condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions
+ ]
this = expressions[0]
if expressions[1:]:
this = _wrap(this, Connector)
@@ -4626,7 +4680,7 @@ def delete(
return delete_expr
-def condition(expression, dialect=None, **opts) -> Condition:
+def condition(expression, dialect=None, copy=True, **opts) -> Condition:
"""
Initialize a logical condition expression.
@@ -4645,6 +4699,7 @@ def condition(expression, dialect=None, **opts) -> Condition:
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression (in the case that the
input expression is a SQL string).
+ copy (bool): Whether or not to copy `expression` (only applies to expressions).
**opts: other options to use to parse the input expressions (again, in the case
that the input expression is a SQL string).
@@ -4655,11 +4710,12 @@ def condition(expression, dialect=None, **opts) -> Condition:
expression,
into=Condition,
dialect=dialect,
+ copy=copy,
**opts,
)
-def and_(*expressions, dialect=None, **opts) -> And:
+def and_(*expressions, dialect=None, copy=True, **opts) -> And:
"""
Combine multiple conditions with an AND logical operator.
@@ -4671,15 +4727,16 @@ def and_(*expressions, dialect=None, **opts) -> And:
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
+ copy (bool): whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
And: the new condition
"""
- return _combine(expressions, And, dialect, **opts)
+ return _combine(expressions, And, dialect, copy=copy, **opts)
-def or_(*expressions, dialect=None, **opts) -> Or:
+def or_(*expressions, dialect=None, copy=True, **opts) -> Or:
"""
Combine multiple conditions with an OR logical operator.
@@ -4691,15 +4748,16 @@ def or_(*expressions, dialect=None, **opts) -> Or:
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
+ copy (bool): whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
Or: the new condition
"""
- return _combine(expressions, Or, dialect, **opts)
+ return _combine(expressions, Or, dialect, copy=copy, **opts)
-def not_(expression, dialect=None, **opts) -> Not:
+def not_(expression, dialect=None, copy=True, **opts) -> Not:
"""
Wrap a condition with a NOT operator.
@@ -4719,13 +4777,14 @@ def not_(expression, dialect=None, **opts) -> Not:
this = condition(
expression,
dialect=dialect,
+ copy=copy,
**opts,
)
return Not(this=_wrap(this, Connector))
-def paren(expression) -> Paren:
- return Paren(this=expression)
+def paren(expression, copy=True) -> Paren:
+ return Paren(this=_maybe_copy(expression, copy))
SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
@@ -4998,29 +5057,20 @@ def values(
alias: optional alias
columns: Optional list of ordered column names or ordered dictionary of column names to types.
If either is provided then an alias is also required.
- If a dictionary is provided then the first column of the values will be casted to the expected type
- in order to help with type inference.
Returns:
Values: the Values expression object
"""
if columns and not alias:
raise ValueError("Alias is required when providing columns")
- table_alias = (
- TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns])
- if columns
- else TableAlias(this=to_identifier(alias) if alias else None)
- )
- expressions = [convert(tup) for tup in values]
- if columns and isinstance(columns, dict):
- types = list(columns.values())
- expressions[0].set(
- "expressions",
- [cast(x, types[i]) for i, x in enumerate(expressions[0].expressions)],
- )
+
return Values(
- expressions=expressions,
- alias=table_alias,
+ expressions=[convert(tup) for tup in values],
+ alias=(
+ TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns])
+ if columns
+ else (TableAlias(this=to_identifier(alias)) if alias else None)
+ ),
)
@@ -5068,19 +5118,20 @@ def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
)
-def convert(value) -> Expression:
+def convert(value: t.Any, copy: bool = False) -> Expression:
"""Convert a python value into an expression object.
Raises an error if a conversion is not possible.
Args:
- value (Any): a python object
+ value: A python object.
+ copy: Whether or not to copy `value` (only applies to Expressions and collections).
Returns:
- Expression: the equivalent expression object
+ Expression: the equivalent expression object.
"""
if isinstance(value, Expression):
- return value
+ return _maybe_copy(value, copy)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, bool):
@@ -5098,13 +5149,13 @@ def convert(value) -> Expression:
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
- return Tuple(expressions=[convert(v) for v in value])
+ return Tuple(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, list):
- return Array(expressions=[convert(v) for v in value])
+ return Array(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, dict):
return Map(
- keys=[convert(k) for k in value],
- values=[convert(v) for v in value.values()],
+ keys=[convert(k, copy=copy) for k in value],
+ values=[convert(v, copy=copy) for v in value.values()],
)
raise ValueError(f"Cannot convert {value}")
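A short sketch of the new copy semantics (collections now propagate copy recursively):

    from sqlglot import exp
    from sqlglot.expressions import convert

    convert((1, "a")).sql()  # "(1, 'a')"
    convert([1, 2]).sql()    # 'ARRAY(1, 2)'

    node = exp.column("x")
    convert(node) is node             # True: Expressions pass through by default
    convert(node, copy=True) is node  # False: a copy is returned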
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index bd12d54..d7dcea0 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -25,6 +25,12 @@ class Generator:
quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
+ bit_start (str): specifies which starting character to use to delimit bit literals. Default: None.
+ bit_end (str): specifies which ending character to use to delimit bit literals. Default: None.
+ hex_start (str): specifies which starting character to use to delimit hex literals. Default: None.
+ hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
+ byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
+ byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an uppercase character. True defaults to 'always'.
normalize (bool): if set to True, all identifiers will be lowercased.
string_escape (str): specifies a string escape character. Default: '.
@@ -227,6 +233,12 @@ class Generator:
"quote_end",
"identifier_start",
"identifier_end",
+ "bit_start",
+ "bit_end",
+ "hex_start",
+ "hex_end",
+ "byte_start",
+ "byte_end",
"identify",
"normalize",
"string_escape",
@@ -258,6 +270,12 @@ class Generator:
quote_end=None,
identifier_start=None,
identifier_end=None,
+ bit_start=None,
+ bit_end=None,
+ hex_start=None,
+ hex_end=None,
+ byte_start=None,
+ byte_end=None,
identify=False,
normalize=False,
string_escape=None,
@@ -284,6 +302,12 @@ class Generator:
self.quote_end = quote_end or "'"
self.identifier_start = identifier_start or '"'
self.identifier_end = identifier_end or '"'
+ self.bit_start = bit_start
+ self.bit_end = bit_end
+ self.hex_start = hex_start
+ self.hex_end = hex_end
+ self.byte_start = byte_start
+ self.byte_end = byte_end
self.identify = identify
self.normalize = normalize
self.string_escape = string_escape or "'"
@@ -361,7 +385,7 @@ class Generator:
expression: t.Optional[exp.Expression] = None,
comments: t.Optional[t.List[str]] = None,
) -> str:
- comments = (comments or (expression and expression.comments)) if self._comments else None # type: ignore
+ comments = ((expression and expression.comments) if comments is None else comments) if self._comments else None # type: ignore
if not comments or isinstance(expression, exp.Binary):
return sql
@@ -510,12 +534,12 @@ class Generator:
position = self.sql(expression, "position")
return f"{position}{this}"
- def columndef_sql(self, expression: exp.ColumnDef) -> str:
+ def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
- kind = f" {kind}" if kind else ""
+ kind = f"{sep}{kind}" if kind else ""
constraints = f" {constraints}" if constraints else ""
position = self.sql(expression, "position")
position = f" {position}" if position else ""
@@ -524,7 +548,7 @@ class Generator:
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
- kind_sql = self.sql(expression, "kind")
+ kind_sql = self.sql(expression, "kind").strip()
return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql
def autoincrementcolumnconstraint_sql(self, _) -> str:
@@ -716,13 +740,22 @@ class Generator:
return f"{alias}{columns}"
def bitstring_sql(self, expression: exp.BitString) -> str:
- return self.sql(expression, "this")
+ this = self.sql(expression, "this")
+ if self.bit_start:
+ return f"{self.bit_start}{this}{self.bit_end}"
+ return f"{int(this, 2)}"
def hexstring_sql(self, expression: exp.HexString) -> str:
- return self.sql(expression, "this")
+ this = self.sql(expression, "this")
+ if self.hex_start:
+ return f"{self.hex_start}{this}{self.hex_end}"
+ return f"{int(this, 16)}"
def bytestring_sql(self, expression: exp.ByteString) -> str:
- return self.sql(expression, "this")
+ this = self.sql(expression, "this")
+ if self.byte_start:
+ return f"{self.byte_start}{this}{self.byte_end}"
+ return this
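Assuming the dialects wire their tokenizer BIT_STRINGS/HEX_STRINGS into these new generator settings (as this diff does elsewhere), literals are re-wrapped for the target dialect or degraded to plain integers; a hedged sketch using the 0x prefix T-SQL gains above:

    import sqlglot

    sqlglot.transpile("SELECT 0xCC", read="tsql", write="tsql")[0]    # roughly 'SELECT 0xCC'
    sqlglot.transpile("SELECT 0xCC", read="tsql", write="presto")[0]  # roughly 'SELECT 204'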
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
@@ -1115,10 +1148,12 @@ class Generator:
return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"
- def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
+ def tablesample_sql(
+ self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
+ ) -> str:
if self.alias_post_tablesample and expression.this.alias:
this = self.sql(expression.this, "this")
- alias = f" AS {self.sql(expression.this, 'alias')}"
+ alias = f"{sep}{self.sql(expression.this, 'alias')}"
else:
this = self.sql(expression, "this")
alias = ""
@@ -1447,16 +1482,16 @@ class Generator:
)
def select_sql(self, expression: exp.Select) -> str:
- kind = expression.args.get("kind")
- kind = f" AS {kind}" if kind else ""
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
+ kind = expression.args.get("kind")
+ kind = f" AS {kind}" if kind else ""
expressions = self.expressions(expression)
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
- f"SELECT{kind}{hint}{distinct}{expressions}",
+ f"SELECT{hint}{distinct}{kind}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
@@ -1475,9 +1510,6 @@ class Generator:
replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else ""
return f"*{except_}{replace}"
- def structkwarg_sql(self, expression: exp.StructKwarg) -> str:
- return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
-
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}"
@@ -1806,7 +1838,7 @@ class Generator:
return self.binary(expression, op)
sqls = tuple(
- self.maybe_comment(self.sql(e), e, e.parent.comments) if i != 1 else self.sql(e)
+ self.maybe_comment(self.sql(e), e, e.parent.comments or []) if i != 1 else self.sql(e)
for i, e in enumerate(expression.flatten(unnest=False))
)
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index e0ddfa2..27de9c7 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -153,7 +153,7 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
- on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
+ on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)
for condition in on.flatten():
if isinstance(condition, exp.EQ):
diff --git a/sqlglot/optimizer/expand_laterals.py b/sqlglot/optimizer/expand_laterals.py
index 59f3fec..5b2f706 100644
--- a/sqlglot/optimizer/expand_laterals.py
+++ b/sqlglot/optimizer/expand_laterals.py
@@ -29,6 +29,6 @@ def expand_laterals(expression: exp.Expression) -> exp.Expression:
for column in projection.find_all(exp.Column):
if not column.table and column.name in alias_to_expression:
column.replace(alias_to_expression[column.name].copy())
- if isinstance(projection, exp.Alias):
- alias_to_expression[projection.alias] = projection.this
+ if isinstance(projection, exp.Alias):
+ alias_to_expression[projection.alias] = projection.this
return expression
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index 40668ef..b013312 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -152,12 +152,14 @@ def _distribute(a, b, from_func, to_func, cache):
lambda c: to_func(
uniq_sort(flatten(from_func(c, b.left)), cache),
uniq_sort(flatten(from_func(c, b.right)), cache),
+ copy=False,
),
)
else:
a = to_func(
uniq_sort(flatten(from_func(a, b.left)), cache),
uniq_sort(flatten(from_func(a, b.right)), cache),
+ copy=False,
)
return a
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 62eb11e..c165ffe 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -10,7 +10,6 @@ from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
-from sqlglot.optimizer.expand_laterals import expand_laterals
from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
@@ -30,7 +29,6 @@ RULES = (
qualify_tables,
isolate_table_selects,
qualify_columns,
- expand_laterals,
pushdown_projections,
validate_qualify_columns,
normalize,
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 0a31246..6ac39f0 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -3,11 +3,12 @@ import typing as t
from sqlglot import alias, exp
from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
-def qualify_columns(expression, schema):
+def qualify_columns(expression, schema, expand_laterals=True):
"""
Rewrite sqlglot AST to have fully qualified columns.
@@ -26,6 +27,9 @@ def qualify_columns(expression, schema):
"""
schema = ensure_schema(schema)
+ if not schema.mapping and expand_laterals:
+ expression = _expand_laterals(expression)
+
for scope in traverse_scope(expression):
resolver = Resolver(scope, schema)
_pop_table_column_aliases(scope.ctes)
@@ -39,6 +43,9 @@ def qualify_columns(expression, schema):
_expand_group_by(scope, resolver)
_expand_order_by(scope)
+ if schema.mapping and expand_laterals:
+ expression = _expand_laterals(expression)
+
return expression
@@ -124,7 +131,7 @@ def _expand_using(scope, resolver):
tables[join_table] = None
join.args.pop("using")
- join.set("on", exp.and_(*conditions))
+ join.set("on", exp.and_(*conditions, copy=False))
if column_tables:
for column in scope.columns:
@@ -240,7 +247,9 @@ def _qualify_columns(scope, resolver):
# column_table can be a '' because bigquery unnest has no table alias
if column_table:
column.set("table", column_table)
- elif column_table not in scope.sources:
+ elif column_table not in scope.sources and (
+ not scope.parent or column_table not in scope.parent.sources
+ ):
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
@@ -376,10 +385,13 @@ def _qualify_outputs(scope):
if not selection.output_name:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
elif not isinstance(selection, exp.Alias) and not selection.is_star:
- alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
- alias_.set("this", selection)
- selection = alias_
-
+ selection = alias(
+ selection,
+ alias=selection.output_name or f"_col_{i}",
+ quoted=True
+ if isinstance(selection, exp.Column) and selection.this.quoted
+ else None,
+ )
if aliased_column:
selection.set("alias", exp.to_identifier(aliased_column))
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index a719ebe..1b451a6 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -7,21 +7,29 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
- Rewrite sqlglot AST to have fully qualified tables.
+ Rewrite sqlglot AST to have fully qualified tables. Additionally, this
+ replaces "join constructs" (*) with equivalent SELECT * subqueries.
- Example:
+ Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
+ >>>
+ >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
+ >>> qualify_tables(expression).sql()
+ 'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
Args:
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
schema: A schema to populate
+
Returns:
sqlglot.Expression: qualified expression
+
+ (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
"""
sequence = itertools.count()
@@ -29,6 +37,12 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
+ # Expand join construct
+ if isinstance(derived_table, exp.Subquery):
+ unnested = derived_table.unnest()
+ if isinstance(unnested, exp.Table):
+ derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
+
if not derived_table.args.get("alias"):
alias_ = f"_q_{next(sequence)}"
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index b582eb0..e00b3c9 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -510,6 +510,9 @@ def _traverse_scope(scope):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
+ elif isinstance(scope.expression, exp.Table):
+ # This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
+ yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
pass
else:
@@ -587,6 +590,9 @@ def _traverse_tables(scope):
for join in scope.expression.args.get("joins") or []:
expressions.append(join.this)
+ if isinstance(scope.expression, exp.Table):
+ expressions.append(scope.expression)
+
expressions.extend(scope.expression.args.get("laterals") or [])
for expression in expressions:
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 4e6c910..0904189 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -60,6 +60,7 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
return exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
+ copy=False,
)
return expression
@@ -76,9 +77,17 @@ def simplify_not(expression):
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if isinstance(condition, exp.And):
- return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
+ return exp.or_(
+ exp.not_(condition.left, copy=False),
+ exp.not_(condition.right, copy=False),
+ copy=False,
+ )
if isinstance(condition, exp.Or):
- return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
+ return exp.and_(
+ exp.not_(condition.left, copy=False),
+ exp.not_(condition.right, copy=False),
+ copy=False,
+ )
if is_null(condition):
return exp.null()
if always_true(expression.this):
@@ -254,12 +263,12 @@ def uniq_sort(expression, cache=None, root=True):
# A AND C AND B -> A AND B AND C
for i, (sql, e) in enumerate(arr[1:]):
if sql < arr[i][0]:
- expression = result_func(*(e for _, e in sorted(arr)))
+ expression = result_func(*(e for _, e in sorted(arr)), copy=False)
break
else:
# we didn't have to sort but maybe we need to dedup
if len(deduped) < len(flattened):
- expression = result_func(*deduped.values())
+ expression = result_func(*deduped.values(), copy=False)
return expression
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index abb23ad..d8d9f88 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -126,9 +126,17 @@ class Parser(metaclass=_Parser):
TokenType.BIT,
TokenType.BOOLEAN,
TokenType.TINYINT,
+ TokenType.UTINYINT,
TokenType.SMALLINT,
+ TokenType.USMALLINT,
TokenType.INT,
+ TokenType.UINT,
TokenType.BIGINT,
+ TokenType.UBIGINT,
+ TokenType.INT128,
+ TokenType.UINT128,
+ TokenType.INT256,
+ TokenType.UINT256,
TokenType.FLOAT,
TokenType.DOUBLE,
TokenType.CHAR,
@@ -961,14 +969,15 @@ class Parser(metaclass=_Parser):
The target expression.
"""
instance = exp_class(**kwargs)
- if self._prev_comments:
- instance.comments = self._prev_comments
- self._prev_comments = None
- if comments:
- instance.comments = comments
+ if comments:
+     instance.add_comments(comments)
+ else:
+     self._add_comments(instance)
self.validate_expression(instance)
return instance
+ def _add_comments(self, expression: t.Optional[exp.Expression]) -> None:
+ if expression and self._prev_comments:
+ expression.add_comments(self._prev_comments)
+ self._prev_comments = None
+
def validate_expression(
self, expression: exp.Expression, args: t.Optional[t.List] = None
) -> None:
@@ -1567,7 +1576,7 @@ class Parser(metaclass=_Parser):
value = self.expression(
exp.Schema,
this="TABLE",
- expressions=self._parse_csv(self._parse_struct_kwargs),
+ expressions=self._parse_csv(self._parse_struct_types),
)
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
@@ -1802,14 +1811,15 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.SELECT):
comments = self._prev_comments
+ hint = self._parse_hint()
+ all_ = self._match(TokenType.ALL)
+ distinct = self._match(TokenType.DISTINCT)
+
kind = (
self._match(TokenType.ALIAS)
and self._match_texts(("STRUCT", "VALUE"))
and self._prev.text
)
- hint = self._parse_hint()
- all_ = self._match(TokenType.ALL)
- distinct = self._match(TokenType.DISTINCT)
if distinct:
distinct = self.expression(
@@ -2284,7 +2294,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.UNNEST):
return None
- expressions = self._parse_wrapped_csv(self._parse_column)
+ expressions = self._parse_wrapped_csv(self._parse_type)
ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
alias = self._parse_table_alias()
@@ -2333,7 +2343,9 @@ class Parser(metaclass=_Parser):
size = None
seed = None
- kind = "TABLESAMPLE" if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
+ kind = (
+ self._prev.text if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
+ )
method = self._parse_var(tokens=(TokenType.ROW,))
self._match(TokenType.L_PAREN)
@@ -2684,7 +2696,7 @@ class Parser(metaclass=_Parser):
else:
this = self.expression(exp.In, this=this, expressions=expressions)
- self._match_r_paren()
+ self._match_r_paren(this)
else:
this = self.expression(exp.In, this=this, field=self._parse_field())
@@ -2798,7 +2810,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.L_PAREN):
if is_struct:
- expressions = self._parse_csv(self._parse_struct_kwargs)
+ expressions = self._parse_csv(self._parse_struct_types)
elif nested:
expressions = self._parse_csv(self._parse_types)
else:
@@ -2833,7 +2845,7 @@ class Parser(metaclass=_Parser):
values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
if nested and self._match(TokenType.LT):
if is_struct:
- expressions = self._parse_csv(self._parse_struct_kwargs)
+ expressions = self._parse_csv(self._parse_struct_types)
else:
expressions = self._parse_csv(self._parse_types)
@@ -2891,16 +2903,10 @@ class Parser(metaclass=_Parser):
prefix=prefix,
)
- def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
- index = self._index
- this = self._parse_id_var()
+ def _parse_struct_types(self) -> t.Optional[exp.Expression]:
+ this = self._parse_type() or self._parse_id_var()
self._match(TokenType.COLON)
- data_type = self._parse_types()
-
- if not data_type:
- self._retreat(index)
- return self._parse_types()
- return self.expression(exp.StructKwarg, this=this, expression=data_type)
+ return self._parse_column_def(this)
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match(TokenType.AT_TIME_ZONE):
@@ -2932,7 +2938,11 @@ class Parser(metaclass=_Parser):
else exp.Literal.string(value)
)
else:
- field = self._parse_star() or self._parse_function() or self._parse_id_var()
+ field = (
+ self._parse_star()
+ or self._parse_function(anonymous=True)
+ or self._parse_id_var()
+ )
if isinstance(field, exp.Func):
# bigquery allows function calls like x.y.count(...)
@@ -2995,11 +3005,9 @@ class Parser(metaclass=_Parser):
else:
this = self.expression(exp.Paren, this=self._parse_set_operations(this))
- self._match_r_paren()
- comments.extend(self._prev_comments)
-
- if this and comments:
- this.comments = comments
+ if this:
+ this.add_comments(comments)
+ self._match_r_paren(expression=this)
return this
@@ -3017,7 +3025,7 @@ class Parser(metaclass=_Parser):
)
def _parse_function(
- self, functions: t.Optional[t.Dict[str, t.Callable]] = None
+ self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False
) -> t.Optional[exp.Expression]:
if not self._curr:
return None
@@ -3043,7 +3051,7 @@ class Parser(metaclass=_Parser):
parser = self.FUNCTION_PARSERS.get(upper)
- if parser:
+ if parser and not anonymous:
this = parser(self)
else:
subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
@@ -3059,7 +3067,7 @@ class Parser(metaclass=_Parser):
function = functions.get(upper)
args = self._parse_csv(self._parse_lambda)
- if function:
+ if function and not anonymous:
# Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the
# second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists.
if count_params(function) == 2:
@@ -3148,12 +3156,7 @@ class Parser(metaclass=_Parser):
if isinstance(left, exp.Column):
left.replace(exp.Var(this=left.text("this")))
- if self._match(TokenType.IGNORE_NULLS):
- this = self.expression(exp.IgnoreNulls, this=this)
- else:
- self._match(TokenType.RESPECT_NULLS)
-
- return self._parse_limit(self._parse_order(this))
+ return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this)))
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
@@ -3177,6 +3180,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Schema, this=this, expressions=args)
def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ # column defs are not really columns, they're identifiers
+ if isinstance(this, exp.Column):
+ this = this.this
kind = self._parse_types()
if self._match_text_seq("FOR", "ORDINALITY"):
@@ -3420,7 +3426,7 @@ class Parser(metaclass=_Parser):
elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE:
self.raise_error("Expected }")
- this.comments = self._prev_comments
+ self._add_comments(this)
return self._parse_bracket(this)
def _parse_slice(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
@@ -3584,7 +3590,9 @@ class Parser(metaclass=_Parser):
exp.and_(
exp.Is(this=expression.copy(), expression=exp.Null()),
exp.Is(this=search.copy(), expression=exp.Null()),
+ copy=False,
),
+ copy=False,
)
ifs.append(exp.If(this=cond, true=result))
@@ -3717,15 +3725,15 @@ class Parser(metaclass=_Parser):
if self._match_set(self.TRIM_TYPES):
position = self._prev.text.upper()
- expression = self._parse_term()
+ expression = self._parse_bitwise()
if self._match_set((TokenType.FROM, TokenType.COMMA)):
- this = self._parse_term()
+ this = self._parse_bitwise()
else:
this = expression
expression = None
if self._match(TokenType.COLLATE):
- collation = self._parse_term()
+ collation = self._parse_bitwise()
return self.expression(
exp.Trim,
@@ -3741,6 +3749,15 @@ class Parser(metaclass=_Parser):
def _parse_named_window(self) -> t.Optional[exp.Expression]:
return self._parse_window(self._parse_id_var(), alias=True)
+ def _parse_respect_or_ignore_nulls(
+ self, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ if self._match(TokenType.IGNORE_NULLS):
+ return self.expression(exp.IgnoreNulls, this=this)
+ if self._match(TokenType.RESPECT_NULLS):
+ return self.expression(exp.RespectNulls, this=this)
+ return this
+
def _parse_window(
self, this: t.Optional[exp.Expression], alias: bool = False
) -> t.Optional[exp.Expression]:
@@ -3768,10 +3785,7 @@ class Parser(metaclass=_Parser):
# (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html)
# and Snowflake chose to do the same for familiarity
# https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes
- if self._match(TokenType.IGNORE_NULLS):
- this = self.expression(exp.IgnoreNulls, this=this)
- elif self._match(TokenType.RESPECT_NULLS):
- this = self.expression(exp.RespectNulls, this=this)
+ this = self._parse_respect_or_ignore_nulls(this)
# bigquery select from window x AS (partition by ...)
if alias:
@@ -3975,9 +3989,7 @@ class Parser(metaclass=_Parser):
items = [parse_result] if parse_result is not None else []
while self._match(sep):
- if parse_result and self._prev_comments:
- parse_result.comments = self._prev_comments
-
+ self._add_comments(parse_result)
parse_result = parse_method()
if parse_result is not None:
items.append(parse_result)
@@ -4345,13 +4357,14 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return None
- def _match(self, token_type, advance=True):
+ def _match(self, token_type, advance=True, expression=None):
if not self._curr:
return None
if self._curr.token_type == token_type:
if advance:
self._advance()
+ self._add_comments(expression)
return True
return None
@@ -4379,16 +4392,12 @@ class Parser(metaclass=_Parser):
return None
def _match_l_paren(self, expression=None):
- if not self._match(TokenType.L_PAREN):
+ if not self._match(TokenType.L_PAREN, expression=expression):
self.raise_error("Expecting (")
- if expression and self._prev_comments:
- expression.comments = self._prev_comments
def _match_r_paren(self, expression=None):
- if not self._match(TokenType.R_PAREN):
+ if not self._match(TokenType.R_PAREN, expression=expression):
self.raise_error("Expecting )")
- if expression and self._prev_comments:
- expression.comments = self._prev_comments
def _match_texts(self, texts, advance=True):
if self._curr and self._curr.text.upper() in texts:
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 64c1f92..5e50b7c 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -84,6 +84,10 @@ class TokenType(AutoName):
UINT = auto()
BIGINT = auto()
UBIGINT = auto()
+ INT128 = auto()
+ UINT128 = auto()
+ INT256 = auto()
+ UINT256 = auto()
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
@@ -774,8 +778,6 @@ class Tokenizer(metaclass=_Tokenizer):
"_end",
"_peek",
"_prev_token_line",
- "_prev_token_comments",
- "_prev_token_type",
)
def __init__(self) -> None:
@@ -795,8 +797,6 @@ class Tokenizer(metaclass=_Tokenizer):
self._end = False
self._peek = ""
self._prev_token_line = -1
- self._prev_token_comments: t.List[str] = []
- self._prev_token_type: t.Optional[TokenType] = None
def tokenize(self, sql: str) -> t.List[Token]:
"""Returns a list of tokens corresponding to the SQL string `sql`."""
@@ -846,7 +846,7 @@ class Tokenizer(metaclass=_Tokenizer):
return self.sql[start:end]
return ""
- def _advance(self, i: int = 1) -> None:
+ def _advance(self, i: int = 1, alnum: bool = False) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
self._col = 1
self._line += 1
@@ -858,14 +858,30 @@ class Tokenizer(metaclass=_Tokenizer):
self._char = self.sql[self._current - 1]
self._peek = "" if self._end else self.sql[self._current]
+ if alnum and self._char.isalnum():
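+ # Fast path: scan the rest of an alphanumeric run using local variables,
+ # committing the cursor state once at the end instead of per character.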
+ _col = self._col
+ _current = self._current
+ _end = self._end
+ _peek = self._peek
+
+ while _peek.isalnum():
+ _col += 1
+ _current += 1
+ _end = _current >= self.size
+ _peek = "" if _end else self.sql[_current]
+
+ self._col = _col
+ self._current = _current
+ self._end = _end
+ self._peek = _peek
+ self._char = self.sql[_current - 1]
+
@property
def _text(self) -> str:
return self.sql[self._start : self._current]
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
- self._prev_token_comments = self._comments
- self._prev_token_type = token_type
self.tokens.append(
Token(
token_type,
@@ -966,13 +982,13 @@ class Tokenizer(metaclass=_Tokenizer):
comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
- self._advance()
+ self._advance(alnum=True)
self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
self._advance(comment_end_size - 1)
else:
while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
- self._advance()
+ self._advance(alnum=True)
self._comments.append(self._text[comment_start_size:])
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
@@ -988,9 +1004,9 @@ class Tokenizer(metaclass=_Tokenizer):
if self._char == "0":
peek = self._peek.upper()
if peek == "B":
- return self._scan_bits()
+ return self._scan_bits() if self._BIT_STRINGS else self._add(TokenType.NUMBER)
elif peek == "X":
- return self._scan_hex()
+ return self._scan_hex() if self._HEX_STRINGS else self._add(TokenType.NUMBER)
decimal = False
scientific = 0
@@ -1033,7 +1049,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
value = self._extract_value()
try:
- self._add(TokenType.BIT_STRING, f"{int(value, 2)}")
+ # If `value` can't be parsed as a binary number, fall back to tokenizing it as an identifier
+ int(value, 2)
+ self._add(TokenType.BIT_STRING, value[2:]) # Drop the 0b
except ValueError:
self._add(TokenType.IDENTIFIER)
@@ -1041,7 +1059,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
value = self._extract_value()
try:
- self._add(TokenType.HEX_STRING, f"{int(value, 16)}")
+ # If `value` can't be parsed as a hex number, fall back to tokenizing it as an identifier
+ int(value, 16)
+ self._add(TokenType.HEX_STRING, value[2:]) # Drop the 0x
except ValueError:
self._add(TokenType.IDENTIFIER)
@@ -1049,7 +1069,7 @@ class Tokenizer(metaclass=_Tokenizer):
while True:
char = self._peek.strip()
if char and char not in self.SINGLE_TOKENS:
- self._advance()
+ self._advance(alnum=True)
else:
break
@@ -1066,7 +1086,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True
- # X'1234, b'0110', E'\\\\\' etc.
+ # X'1234', b'0110', E'\\\\\' etc.
def _scan_formatted_string(self, string_start: str) -> bool:
if string_start in self._HEX_STRINGS:
delimiters = self._HEX_STRINGS
@@ -1087,60 +1107,43 @@ class Tokenizer(metaclass=_Tokenizer):
string_end = delimiters[string_start]
text = self._extract_string(string_end)
- if base is None:
- self._add(token_type, text)
- else:
+ if base:
try:
- self._add(token_type, f"{int(text, base)}")
+ int(text, base)
except ValueError:
raise RuntimeError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)
+ self._add(token_type, text)
return True
def _scan_identifier(self, identifier_end: str) -> None:
- text = ""
- identifier_end_is_escape = identifier_end in self._IDENTIFIER_ESCAPES
-
- while True:
- if self._end:
- raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
-
- self._advance()
- if self._char == identifier_end:
- if identifier_end_is_escape and self._peek == identifier_end:
- text += identifier_end
- self._advance()
- continue
-
- break
-
- text += self._char
-
+ self._advance()
+ text = self._extract_string(identifier_end, self._IDENTIFIER_ESCAPES)
self._add(TokenType.IDENTIFIER, text)
def _scan_var(self) -> None:
while True:
char = self._peek.strip()
if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS):
- self._advance()
+ self._advance(alnum=True)
else:
break
+
self._add(
TokenType.VAR
- if self._prev_token_type == TokenType.PARAMETER
+ if self.tokens and self.tokens[-1].token_type == TokenType.PARAMETER
else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
)
- def _extract_string(self, delimiter: str) -> str:
+ def _extract_string(self, delimiter: str, escapes=None) -> str:
text = ""
delim_size = len(delimiter)
+ escapes = self._STRING_ESCAPES if escapes is None else escapes
while True:
- if self._char in self._STRING_ESCAPES and (
- self._peek == delimiter or self._peek in self._STRING_ESCAPES
- ):
+ if self._char in escapes and (self._peek == delimiter or self._peek in escapes):
if self._peek == delimiter:
text += self._peek
else:
@@ -1158,7 +1161,9 @@ class Tokenizer(metaclass=_Tokenizer):
if self._end:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
- text += self._char
- self._advance()
+
+ current = self._current - 1
+ self._advance(alnum=True)
+ text += self.sql[current : self._current - 1]
return text
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 00f278e..3643cd7 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -121,20 +121,9 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
other expressions. This transform removes the precision from parameterized types in expressions.
"""
- return expression.transform(
- lambda node: exp.DataType(
- **{
- **node.args,
- "expressions": [
- node_expression
- for node_expression in node.expressions
- if isinstance(node_expression, exp.DataType)
- ],
- }
- )
- if isinstance(node, exp.DataType)
- else node,
- )
+ for node in expression.find_all(exp.DataType):
+ node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])
+ return expression
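E.g., a hedged sketch of the rewritten transform, which now mutates in place:

    import sqlglot
    from sqlglot.transforms import remove_precision_parameterized_types

    ast = sqlglot.parse_one("SELECT CAST(1 AS DECIMAL(10, 2))")
    remove_precision_parameterized_types(ast).sql()
    # -> roughly 'SELECT CAST(1 AS DECIMAL)'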
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
@@ -240,12 +229,36 @@ def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
return expression
+def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
+ if (
+ isinstance(expression, exp.WithinGroup)
+ and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
+ and isinstance(expression.expression, exp.Order)
+ ):
+ quantile = expression.this.this
+ input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
+ return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
+
+ return expression
+
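This is what lets the Spark-family dialects above accept WITHIN GROUP percentiles; a rough sketch (assumes exp.ApproxQuantile renders as PERCENTILE_APPROX there):

    import sqlglot

    sqlglot.transpile(
        "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t",
        write="spark",
    )[0]
    # -> roughly 'SELECT PERCENTILE_APPROX(x, 0.5) FROM t'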
+
+def unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
+ if isinstance(expression, exp.Pivot):
+ expression.args["field"].transform(
+ lambda node: exp.column(node.output_name) if isinstance(node, exp.Column) else node,
+ copy=False,
+ )
+
+ return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
- expression to SQL, using an appropriate `Generator.TRANSFORMS` function.
+ expression to SQL, using either the "_sql" method corresponding to the resulting expression,
+ or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Args:
transforms: sequence of transform functions. These will be called in order.
@@ -255,17 +268,28 @@ def preprocess(
"""
def _to_sql(self, expression: exp.Expression) -> str:
+ expression_type = type(expression)
+
expression = transforms[0](expression.copy())
for t in transforms[1:]:
expression = t(expression)
- return getattr(self, expression.key + "_sql")(expression)
- return _to_sql
+ _sql_handler = getattr(self, expression.key + "_sql", None)
+ if _sql_handler:
+ return _sql_handler(expression)
+
+ transforms_handler = self.TRANSFORMS.get(type(expression))
+ if transforms_handler:
+ # Ensures we don't enter an infinite loop. This can happen when the original expression
+ # has the same type as the final expression and there's no _sql method available for it,
+ # because then it'd re-enter _to_sql.
+ if expression_type is type(expression):
+ raise ValueError(
+ f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
+ )
+ return transforms_handler(self, expression)
-UNALIAS_GROUP = {exp.Group: preprocess([unalias_group])}
-ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on])}
-ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify])}
-REMOVE_PRECISION_PARAMETERIZED_TYPES = {
- exp.Cast: preprocess([remove_precision_parameterized_types])
-}
+ raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
+
+ return _to_sql
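With the UNALIAS_GROUP/ELIMINATE_* dict constants gone, dialects now register transform chains directly, as the dialect diffs above already do:

    # e.g. inside a dialect's Generator
    TRANSFORMS = {
        **generator.Generator.TRANSFORMS,
        exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
    }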