summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--sqlglot/dialects/__init__.py1
-rw-r--r--sqlglot/dialects/bigquery.py51
-rw-r--r--sqlglot/dialects/clickhouse.py20
-rw-r--r--sqlglot/dialects/databricks.py11
-rw-r--r--sqlglot/dialects/dialect.py42
-rw-r--r--sqlglot/dialects/drill.py3
-rw-r--r--sqlglot/dialects/duckdb.py66
-rw-r--r--sqlglot/dialects/hive.py45
-rw-r--r--sqlglot/dialects/mysql.py3
-rw-r--r--sqlglot/dialects/oracle.py13
-rw-r--r--sqlglot/dialects/postgres.py18
-rw-r--r--sqlglot/dialects/presto.py30
-rw-r--r--sqlglot/dialects/redshift.py4
-rw-r--r--sqlglot/dialects/snowflake.py4
-rw-r--r--sqlglot/dialects/spark.py240
-rw-r--r--sqlglot/dialects/spark2.py238
-rw-r--r--sqlglot/dialects/sqlite.py43
-rw-r--r--sqlglot/dialects/starrocks.py1
-rw-r--r--sqlglot/dialects/tableau.py3
-rw-r--r--sqlglot/dialects/teradata.py3
-rw-r--r--sqlglot/dialects/tsql.py11
21 files changed, 536 insertions, 314 deletions
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 191e703..fc34262 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -70,6 +70,7 @@ from sqlglot.dialects.presto import Presto
from sqlglot.dialects.redshift import Redshift
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark
+from sqlglot.dialects.spark2 import Spark2
from sqlglot.dialects.sqlite import SQLite
from sqlglot.dialects.starrocks import StarRocks
from sqlglot.dialects.tableau import Tableau
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 1a88654..9705b35 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -39,18 +39,26 @@ def _date_add_sql(
def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
if not isinstance(expression.unnest().parent, exp.From):
- expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression))
return self.values_sql(expression)
- rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
- structs = []
- for row in rows:
- aliases = [
- exp.alias_(value, column_name)
- for value, column_name in zip(row, expression.args["alias"].args["columns"])
- ]
- structs.append(exp.Struct(expressions=aliases))
- unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
- return self.unnest_sql(unnest_exp)
+
+ alias = expression.args.get("alias")
+
+ structs = [
+ exp.Struct(
+ expressions=[
+ exp.alias_(value, column_name)
+ for value, column_name in zip(
+ t.expressions,
+ alias.columns
+ if alias and alias.columns
+ else (f"_c{i}" for i in range(len(t.expressions))),
+ )
+ ]
+ )
+ for t in expression.find_all(exp.Tuple)
+ ]
+
+ return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
@@ -128,6 +136,7 @@ class BigQuery(Dialect):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
+ BYTE_STRINGS = [("b'", "'"), ("B'", "'")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -139,6 +148,7 @@ class BigQuery(Dialect):
"GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
+ "BYTES": TokenType.BINARY,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
}
@@ -153,7 +163,7 @@ class BigQuery(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"DATE_TRUNC": lambda args: exp.DateTrunc(
- unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore
+ unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
),
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
@@ -206,6 +216,12 @@ class BigQuery(Dialect):
"NOT DETERMINISTIC": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("VOLATILE")
),
+ "OPTIONS": lambda self: self._parse_with_property(),
+ }
+
+ CONSTRAINT_PARSERS = {
+ **parser.Parser.CONSTRAINT_PARSERS, # type: ignore
+ "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()),
}
class Generator(generator.Generator):
@@ -217,11 +233,11 @@ class BigQuery(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
- **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.AtTimeZone: lambda self, e: self.func(
"TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
),
+ exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
@@ -234,7 +250,9 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
- exp.Select: transforms.preprocess([_unqualify_unnest]),
+ exp.Select: transforms.preprocess(
+ [_unqualify_unnest, transforms.eliminate_distinct_on]
+ ),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -259,6 +277,7 @@ class BigQuery(Dialect):
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
exp.DataType.Type.BIGINT: "INT64",
+ exp.DataType.Type.BINARY: "BYTES",
exp.DataType.Type.BOOLEAN: "BOOL",
exp.DataType.Type.CHAR: "STRING",
exp.DataType.Type.DECIMAL: "NUMERIC",
@@ -272,6 +291,7 @@ class BigQuery(Dialect):
exp.DataType.Type.TIMESTAMP: "DATETIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.TINYINT: "INT64",
+ exp.DataType.Type.VARBINARY: "BYTES",
exp.DataType.Type.VARCHAR: "STRING",
exp.DataType.Type.VARIANT: "ANY TYPE",
}
@@ -310,3 +330,6 @@ class BigQuery(Dialect):
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
+
+ def with_properties(self, properties: exp.Properties) -> str:
+ return self.properties(properties, prefix=self.seg("OPTIONS"))
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index e91b0bf..2a49066 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -22,6 +22,8 @@ class ClickHouse(Dialect):
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
+ BIT_STRINGS = [("0b", "")]
+ HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -31,10 +33,18 @@ class ClickHouse(Dialect):
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
+ "INT8": TokenType.TINYINT,
+ "UINT8": TokenType.UTINYINT,
"INT16": TokenType.SMALLINT,
+ "UINT16": TokenType.USMALLINT,
"INT32": TokenType.INT,
+ "UINT32": TokenType.UINT,
"INT64": TokenType.BIGINT,
- "INT8": TokenType.TINYINT,
+ "UINT64": TokenType.UBIGINT,
+ "INT128": TokenType.INT128,
+ "UINT128": TokenType.UINT128,
+ "INT256": TokenType.INT256,
+ "UINT256": TokenType.UINT256,
"TUPLE": TokenType.STRUCT,
}
@@ -121,9 +131,17 @@ class ClickHouse(Dialect):
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.STRUCT: "Tuple",
exp.DataType.Type.TINYINT: "Int8",
+ exp.DataType.Type.UTINYINT: "UInt8",
exp.DataType.Type.SMALLINT: "Int16",
+ exp.DataType.Type.USMALLINT: "UInt16",
exp.DataType.Type.INT: "Int32",
+ exp.DataType.Type.UINT: "UInt32",
exp.DataType.Type.BIGINT: "Int64",
+ exp.DataType.Type.UBIGINT: "UInt64",
+ exp.DataType.Type.INT128: "Int128",
+ exp.DataType.Type.UINT128: "UInt128",
+ exp.DataType.Type.INT256: "Int256",
+ exp.DataType.Type.UINT256: "UInt256",
exp.DataType.Type.FLOAT: "Float32",
exp.DataType.Type.DOUBLE: "Float64",
}
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 138f26c..51112a0 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from sqlglot import exp
+from sqlglot import exp, transforms
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
@@ -29,13 +29,20 @@ class Databricks(Spark):
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
+ exp.Select: transforms.preprocess(
+ [
+ transforms.eliminate_distinct_on,
+ transforms.unnest_to_explode,
+ ]
+ ),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
- TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation
PARAMETER_TOKEN = "$"
class Tokenizer(Spark.Tokenizer):
+ HEX_STRINGS = []
+
SINGLE_TOKENS = {
**Spark.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 19c6f73..71269f2 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -28,6 +28,7 @@ class Dialects(str, Enum):
REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
+ SPARK2 = "spark2"
SQLITE = "sqlite"
STARROCKS = "starrocks"
TABLEAU = "tableau"
@@ -69,30 +70,17 @@ class _Dialect(type):
klass.tokenizer_class._IDENTIFIERS.items()
)[0]
- if (
- klass.tokenizer_class._BIT_STRINGS
- and exp.BitString not in klass.generator_class.TRANSFORMS
- ):
- bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
- klass.generator_class.TRANSFORMS[
- exp.BitString
- ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
- if (
- klass.tokenizer_class._HEX_STRINGS
- and exp.HexString not in klass.generator_class.TRANSFORMS
- ):
- hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
- klass.generator_class.TRANSFORMS[
- exp.HexString
- ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
- if (
- klass.tokenizer_class._BYTE_STRINGS
- and exp.ByteString not in klass.generator_class.TRANSFORMS
- ):
- be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
- klass.generator_class.TRANSFORMS[
- exp.ByteString
- ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"
+ klass.bit_start, klass.bit_end = seq_get(
+ list(klass.tokenizer_class._BIT_STRINGS.items()), 0
+ ) or (None, None)
+
+ klass.hex_start, klass.hex_end = seq_get(
+ list(klass.tokenizer_class._HEX_STRINGS.items()), 0
+ ) or (None, None)
+
+ klass.byte_start, klass.byte_end = seq_get(
+ list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
+ ) or (None, None)
return klass
@@ -198,6 +186,12 @@ class Dialect(metaclass=_Dialect):
**{
"quote_start": self.quote_start,
"quote_end": self.quote_end,
+ "bit_start": self.bit_start,
+ "bit_end": self.bit_end,
+ "hex_start": self.hex_start,
+ "hex_end": self.hex_end,
+ "byte_start": self.byte_start,
+ "byte_end": self.byte_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index d7e2d88..7ad555e 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
create_with_partitions_sql,
@@ -145,6 +145,7 @@ class Drill(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToDate: _str_to_date,
exp.Pow: rename_func("POW"),
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 9454db6..bce956e 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
@@ -23,52 +25,61 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _ts_or_ds_add(self, expression):
+def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
-def _date_add(self, expression):
+def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
-def _array_sort_sql(self, expression):
+def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
return f"ARRAY_SORT({self.sql(expression, 'this')})"
-def _sort_array_sql(self, expression):
+def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str:
this = self.sql(expression, "this")
if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
return f"ARRAY_SORT({this})"
-def _sort_array_reverse(args):
+def _sort_array_reverse(args: t.Sequence) -> exp.Expression:
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
-def _struct_sql(self, expression):
+def _parse_date_diff(args: t.Sequence) -> exp.Expression:
+ return exp.DateDiff(
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=seq_get(args, 0),
+ )
+
+
+def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
args = [
f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
]
return f"{{{', '.join(args)}}}"
-def _datatype_sql(self, expression):
+def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
if expression.this == exp.DataType.Type.ARRAY:
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
-def _regexp_extract_sql(self, expression):
+def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str:
bad_args = list(filter(expression.args.get, ("position", "occurrence")))
if bad_args:
self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}")
+
return self.func(
"REGEXP_EXTRACT",
expression.args.get("this"),
@@ -108,6 +119,8 @@ class DuckDB(Dialect):
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _sort_array_reverse,
+ "DATEDIFF": _parse_date_diff,
+ "DATE_DIFF": _parse_date_diff,
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(
@@ -115,18 +128,18 @@ class DuckDB(Dialect):
expression=exp.Literal.number(1000),
)
),
- "LIST_SORT": exp.SortArray.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
+ "LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
"REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
"STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
- "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
- "STR_SPLIT": exp.Split.from_arg_list,
"STRING_SPLIT": exp.Split.from_arg_list,
- "STRING_TO_ARRAY": exp.Split.from_arg_list,
- "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
+ "STRING_TO_ARRAY": exp.Split.from_arg_list,
+ "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
"STRUCT_PACK": exp.Struct.from_arg_list,
+ "STR_SPLIT": exp.Split.from_arg_list,
+ "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
}
@@ -142,10 +155,11 @@ class DuckDB(Dialect):
class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
+ LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if isinstance(seq_get(e.expressions, 0), exp.Select)
@@ -154,13 +168,16 @@ class DuckDB(Dialect):
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
+ exp.CurrentDate: lambda self, e: "CURRENT_DATE",
+ exp.CurrentTime: lambda self, e: "CURRENT_TIME",
+ exp.CurrentTimestamp: lambda self, e: "CURRENT_TIMESTAMP",
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.DataType: _datatype_sql,
- exp.DateAdd: _date_add,
+ exp.DateAdd: _date_add_sql,
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this
+ "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
@@ -192,7 +209,7 @@ class DuckDB(Dialect):
exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
- exp.TsOrDsAdd: _ts_or_ds_add,
+ exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
@@ -201,7 +218,7 @@ class DuckDB(Dialect):
}
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.CHAR: "TEXT",
exp.DataType.Type.FLOAT: "REAL",
@@ -212,17 +229,14 @@ class DuckDB(Dialect):
exp.DataType.Type.VARCHAR: "TEXT",
}
- STAR_MAPPING = {
- **generator.Generator.STAR_MAPPING,
- "except": "EXCLUDE",
- }
+ STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- LIMIT_FETCH = "LIMIT"
-
- def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str:
- return super().tablesample_sql(expression, seed_prefix="REPEATABLE")
+ def tablesample_sql(
+ self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
+ ) -> str:
+ return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 6746fcf..871a180 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -81,7 +81,20 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
return f"{diff_sql}{multiplier_sql}"
-def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str:
+def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
+ this = expression.this
+
+ if not this.type:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ annotate_types(this)
+
+ if this.type.is_type(exp.DataType.Type.JSON):
+ return self.sql(this)
+ return self.func("TO_JSON", this, expression.args.get("options"))
+
+
+def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return f"SORT_ARRAY({self.sql(expression, 'this')})"
@@ -91,11 +104,11 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
return f"'{expression.name}'={self.sql(expression, 'value')}"
-def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str:
+def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
-def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -103,7 +116,7 @@ def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
return f"CAST({this} AS DATE)"
-def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str:
+def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -214,6 +227,7 @@ class Hive(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ "BASE64": exp.ToBase64.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0),
@@ -251,6 +265,7 @@ class Hive(Dialect):
"SPLIT": exp.RegexpSplit.from_arg_list,
"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
+ "UNBASE64": exp.FromBase64.from_arg_list,
"UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
@@ -280,16 +295,20 @@ class Hive(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
- **transforms.UNALIAS_GROUP, # type: ignore
- **transforms.ELIMINATE_QUALIFY, # type: ignore
+ exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Select: transforms.preprocess(
- [transforms.eliminate_qualify, transforms.unnest_to_explode]
+ [
+ transforms.eliminate_qualify,
+ transforms.eliminate_distinct_on,
+ transforms.unnest_to_explode,
+ ]
),
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayConcat: rename_func("CONCAT"),
+ exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArraySize: rename_func("SIZE"),
- exp.ArraySort: _array_sort,
+ exp.ArraySort: _array_sort_sql,
exp.With: no_recursive_cte_sql,
exp.DateAdd: _add_date_sql,
exp.DateDiff: _date_diff_sql,
@@ -298,12 +317,13 @@ class Hive(Dialect):
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
+ exp.FromBase64: rename_func("UNBASE64"),
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
- exp.JSONFormat: rename_func("TO_JSON"),
+ exp.JSONFormat: _json_format_sql,
exp.Map: var_map_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
@@ -318,9 +338,9 @@ class Hive(Dialect):
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
exp.StrPosition: strposition_to_locate_sql,
- exp.StrToDate: _str_to_date,
- exp.StrToTime: _str_to_time,
- exp.StrToUnix: _str_to_unix,
+ exp.StrToDate: _str_to_date_sql,
+ exp.StrToTime: _str_to_time_sql,
+ exp.StrToUnix: _str_to_unix_sql,
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}",
exp.TimeStrToDate: rename_func("TO_DATE"),
@@ -328,6 +348,7 @@ class Hive(Dialect):
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: _time_to_str,
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.ToBase64: rename_func("BASE64"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.TsOrDsToDate: _to_date_sql,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 666e740..5342624 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -403,6 +403,7 @@ class MySQL(Dialect):
exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 9ccd02e..c8af1c6 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -34,6 +34,8 @@ def _parse_xml_table(self) -> exp.XMLTable:
class Oracle(Dialect):
+ alias_post_tablesample = True
+
# https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212
# https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes
time_mapping = {
@@ -121,21 +123,23 @@ class Oracle(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
- **transforms.UNALIAS_GROUP, # type: ignore
exp.DateStrToDate: lambda self, e: self.func(
"TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD")
),
+ exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
+ exp.IfNull: rename_func("NVL"),
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
+ exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
- exp.IfNull: rename_func("NVL"),
}
PROPERTIES_LOCATION = {
@@ -164,14 +168,19 @@ class Oracle(Dialect):
return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
class Tokenizer(tokens.Tokenizer):
+ VAR_SINGLE_TOKENS = {"@"}
+
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"(+)": TokenType.JOIN_MARKER,
+ "BINARY_DOUBLE": TokenType.DOUBLE,
+ "BINARY_FLOAT": TokenType.FLOAT,
"COLUMNS": TokenType.COLUMN,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NVARCHAR2": TokenType.NVARCHAR,
"RETURNING": TokenType.RETURNING,
+ "SAMPLE": TokenType.TABLE_SAMPLE,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index c47ff51..2132778 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -1,6 +1,8 @@
from __future__ import annotations
-from sqlglot import exp, generator, parser, tokens
+import typing as t
+
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -20,7 +22,6 @@ from sqlglot.dialects.dialect import (
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
from sqlglot.tokens import TokenType
-from sqlglot.transforms import preprocess, remove_target_from_merge
DATE_DIFF_FACTOR = {
"MICROSECOND": " * 1000000",
@@ -274,8 +275,7 @@ class Postgres(Dialect):
TokenType.HASH: exp.BitwiseXor,
}
- FACTOR = {
- **parser.Parser.FACTOR,
+ EXPONENT = {
TokenType.CARET: exp.Pow,
}
@@ -286,6 +286,12 @@ class Postgres(Dialect):
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
+ def _parse_factor(self) -> t.Optional[exp.Expression]:
+ return self._parse_tokens(self._parse_exponent, self.FACTOR)
+
+ def _parse_exponent(self) -> t.Optional[exp.Expression]:
+ return self._parse_tokens(self._parse_unary, self.EXPONENT)
+
def _parse_date_part(self) -> exp.Expression:
part = self._parse_type()
self._match(TokenType.COMMA)
@@ -316,7 +322,7 @@ class Postgres(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
- exp.ColumnDef: preprocess(
+ exp.ColumnDef: transforms.preprocess(
[
_auto_increment_to_serial,
_serial_to_generated,
@@ -341,7 +347,7 @@ class Postgres(Dialect):
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
- exp.Merge: preprocess([remove_target_from_merge]),
+ exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 489d439..6133a27 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -130,7 +130,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
- step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
+ step = expression.args.get("step")
target_type = None
@@ -147,7 +147,11 @@ def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) ->
else:
start = exp.Cast(this=start, to=to)
- return self.func("SEQUENCE", start, end, step)
+ sql = self.func("SEQUENCE", start, end, step)
+ if isinstance(expression.parent, exp.Table):
+ sql = f"UNNEST({sql})"
+
+ return sql
def _ensure_utf8(charset: exp.Literal) -> None:
@@ -204,6 +208,7 @@ class Presto(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ "APPROX_PERCENTILE": _approx_percentile,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
@@ -219,23 +224,23 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
+ "FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
+ "FROM_UTF8": lambda args: exp.Decode(
+ this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
+ ),
"NOW": exp.CurrentTimestamp.from_arg_list,
+ "SEQUENCE": exp.GenerateSeries.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
instance=seq_get(args, 2),
),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
- "APPROX_PERCENTILE": _approx_percentile,
- "FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
- "FROM_UTF8": lambda args: exp.Decode(
- this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
- ),
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("TRIM")
@@ -264,7 +269,6 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
- **transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@@ -290,6 +294,7 @@ class Presto(Dialect):
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
exp.Encode: _encode_sql,
exp.GenerateSeries: _sequence_sql,
+ exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
@@ -303,7 +308,11 @@ class Presto(Dialect):
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
- [transforms.eliminate_qualify, transforms.explode_to_unnest]
+ [
+ transforms.eliminate_qualify,
+ transforms.eliminate_distinct_on,
+ transforms.explode_to_unnest,
+ ]
),
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
@@ -327,6 +336,9 @@ class Presto(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
exp.VariancePop: rename_func("VAR_POP"),
+ exp.WithinGroup: transforms.preprocess(
+ [transforms.remove_within_group_for_percentiles]
+ ),
}
def interval_sql(self, expression: exp.Interval) -> str:
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index a9c4f62..1b7cf31 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -52,6 +52,8 @@ class Redshift(Postgres):
return this
class Tokenizer(Postgres.Tokenizer):
+ BIT_STRINGS = []
+ HEX_STRINGS = []
STRING_ESCAPES = ["\\"]
KEYWORDS = {
@@ -90,7 +92,6 @@ class Redshift(Postgres):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
- **transforms.ELIMINATE_DISTINCT_ON, # type: ignore
exp.CurrentTimestamp: lambda self, e: "SYSDATE",
exp.DateAdd: lambda self, e: self.func(
"DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@@ -102,6 +103,7 @@ class Redshift(Postgres):
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 0829669..70dcaa9 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
date_trunc_to_time,
@@ -252,6 +252,7 @@ class Snowflake(Dialect):
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
STRING_ESCAPES = ["\\", "'"]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -305,6 +306,7 @@ class Snowflake(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index a3e4cce..939f2fd 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -2,222 +2,54 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, parser
-from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
-from sqlglot.dialects.hive import Hive
+from sqlglot import exp
+from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
-def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
- kind = e.args["kind"]
- properties = e.args.get("properties")
+def _parse_datediff(args: t.Sequence) -> exp.Expression:
+ """
+ Although Spark docs don't mention the "unit" argument, Spark3 added support for
+ it at some point. Databricks also supports this variation (see below).
- if kind.upper() == "TABLE" and any(
- isinstance(prop, exp.TemporaryProperty)
- for prop in (properties.expressions if properties else [])
- ):
- return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
- return create_with_partitions_sql(self, e)
+ For example, in spark-sql (v3.3.1):
+ - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
+ - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4
+ See also:
+ - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
+ - https://docs.databricks.com/sql/language-manual/functions/datediff.html
+ """
+ unit = None
+ this = seq_get(args, 0)
+ expression = seq_get(args, 1)
-def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
- keys = self.sql(expression.args["keys"])
- values = self.sql(expression.args["values"])
- return f"MAP_FROM_ARRAYS({keys}, {values})"
+ if len(args) == 3:
+ unit = this
+ this = args[2]
+ return exp.DateDiff(
+ this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit
+ )
-def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
- this = self.sql(expression, "this")
- time_format = self.format_time(expression)
- if time_format == Hive.date_format:
- return f"TO_DATE({this})"
- return f"TO_DATE({this}, {time_format})"
-
-def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
- scale = expression.args.get("scale")
- timestamp = self.sql(expression, "this")
- if scale is None:
- return f"FROM_UNIXTIME({timestamp})"
- if scale == exp.UnixToTime.SECONDS:
- return f"TIMESTAMP_SECONDS({timestamp})"
- if scale == exp.UnixToTime.MILLIS:
- return f"TIMESTAMP_MILLIS({timestamp})"
- if scale == exp.UnixToTime.MICROS:
- return f"TIMESTAMP_MICROS({timestamp})"
-
- raise ValueError("Improper scale for timestamp")
-
-
-class Spark(Hive):
- class Parser(Hive.Parser):
+class Spark(Spark2):
+ class Parser(Spark2.Parser):
FUNCTIONS = {
- **Hive.Parser.FUNCTIONS, # type: ignore
- "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
- "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
- "LEFT": lambda args: exp.Substring(
- this=seq_get(args, 0),
- start=exp.Literal.number(1),
- length=seq_get(args, 1),
- ),
- "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- ),
- "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- ),
- "RIGHT": lambda args: exp.Substring(
- this=seq_get(args, 0),
- start=exp.Sub(
- this=exp.Length(this=seq_get(args, 0)),
- expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
- ),
- length=seq_get(args, 1),
- ),
- "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
- "BOOLEAN": lambda args: exp.Cast(
- this=seq_get(args, 0), to=exp.DataType.build("boolean")
- ),
- "IIF": exp.If.from_arg_list,
- "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")),
- "AGGREGATE": exp.Reduce.from_arg_list,
- "DAYOFWEEK": lambda args: exp.DayOfWeek(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DAYOFMONTH": lambda args: exp.DayOfMonth(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DAYOFYEAR": lambda args: exp.DayOfYear(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "WEEKOFYEAR": lambda args: exp.WeekOfYear(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- ),
- "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
- "DATE_TRUNC": lambda args: exp.TimestampTrunc(
- this=seq_get(args, 1),
- unit=exp.var(seq_get(args, 0)),
- ),
- "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")),
- "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
- "TIMESTAMP": lambda args: exp.Cast(
- this=seq_get(args, 0), to=exp.DataType.build("timestamp")
- ),
- }
-
- FUNCTION_PARSERS = {
- **parser.Parser.FUNCTION_PARSERS, # type: ignore
- "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
- "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
- "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
- "MERGE": lambda self: self._parse_join_hint("MERGE"),
- "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
- "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
- "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
- "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
- }
-
- def _parse_add_column(self) -> t.Optional[exp.Expression]:
- return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
-
- def _parse_drop_column(self) -> t.Optional[exp.Expression]:
- return self._match_text_seq("DROP", "COLUMNS") and self.expression(
- exp.Drop,
- this=self._parse_schema(),
- kind="COLUMNS",
- )
-
- def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
- # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
- if len(pivot_columns) == 1:
- return [""]
-
- names = []
- for agg in pivot_columns:
- if isinstance(agg, exp.Alias):
- names.append(agg.alias)
- else:
- """
- This case corresponds to aggregations without aliases being used as suffixes
- (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
- be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
- Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
-
- Moreover, function names are lowercased in order to mimic Spark's naming scheme.
- """
- agg_all_unquoted = agg.transform(
- lambda node: exp.Identifier(this=node.name, quoted=False)
- if isinstance(node, exp.Identifier)
- else node
- )
- names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
-
- return names
-
- class Generator(Hive.Generator):
- TYPE_MAPPING = {
- **Hive.Generator.TYPE_MAPPING, # type: ignore
- exp.DataType.Type.TINYINT: "BYTE",
- exp.DataType.Type.SMALLINT: "SHORT",
- exp.DataType.Type.BIGINT: "LONG",
- }
-
- PROPERTIES_LOCATION = {
- **Hive.Generator.PROPERTIES_LOCATION, # type: ignore
- exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
- exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
- exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
- exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
- }
-
- TRANSFORMS = {
- **Hive.Generator.TRANSFORMS, # type: ignore
- exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
- exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
- exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
- exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
- exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
- exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
- exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
- exp.StrToDate: _str_to_date,
- exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.UnixToTime: _unix_to_time_sql,
- exp.Create: _create_sql,
- exp.Map: _map_sql,
- exp.Reduce: rename_func("AGGREGATE"),
- exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
- exp.TimestampTrunc: lambda self, e: self.func(
- "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
- ),
- exp.Trim: trim_sql,
- exp.VariancePop: rename_func("VAR_POP"),
- exp.DateFromParts: rename_func("MAKE_DATE"),
- exp.LogicalOr: rename_func("BOOL_OR"),
- exp.LogicalAnd: rename_func("BOOL_AND"),
- exp.DayOfWeek: rename_func("DAYOFWEEK"),
- exp.DayOfMonth: rename_func("DAYOFMONTH"),
- exp.DayOfYear: rename_func("DAYOFYEAR"),
- exp.WeekOfYear: rename_func("WEEKOFYEAR"),
- exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+ **Spark2.Parser.FUNCTIONS, # type: ignore
+ "DATEDIFF": _parse_datediff,
}
- TRANSFORMS.pop(exp.ArraySort)
- TRANSFORMS.pop(exp.ILike)
- WRAP_DERIVED_VALUES = False
- CREATE_FUNCTION_RETURN_AS = False
+ class Generator(Spark2.Generator):
+ TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
+ TRANSFORMS.pop(exp.DateDiff)
- def cast_sql(self, expression: exp.Cast) -> str:
- if isinstance(expression.this, exp.Cast) and expression.this.is_type(
- exp.DataType.Type.JSON
- ):
- schema = f"'{self.sql(expression, 'to')}'"
- return self.func("FROM_JSON", expression.this.this, schema)
- if expression.to.is_type(exp.DataType.Type.JSON):
- return self.func("TO_JSON", expression.this)
+ def datediff_sql(self, expression: exp.DateDiff) -> str:
+ unit = self.sql(expression, "unit")
+ end = self.sql(expression, "this")
+ start = self.sql(expression, "expression")
- return super(Spark.Generator, self).cast_sql(expression)
+ if unit:
+ return self.func("DATEDIFF", unit, start, end)
- class Tokenizer(Hive.Tokenizer):
- HEX_STRINGS = [("X'", "'")]
+ return self.func("DATEDIFF", end, start)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
new file mode 100644
index 0000000..584671f
--- /dev/null
+++ b/sqlglot/dialects/spark2.py
@@ -0,0 +1,238 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp, parser, transforms
+from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql
+from sqlglot.dialects.hive import Hive
+from sqlglot.helper import seq_get
+
+
+def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
+ kind = e.args["kind"]
+ properties = e.args.get("properties")
+
+ if kind.upper() == "TABLE" and any(
+ isinstance(prop, exp.TemporaryProperty)
+ for prop in (properties.expressions if properties else [])
+ ):
+ return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
+ return create_with_partitions_sql(self, e)
+
+
+def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
+ keys = self.sql(expression.args["keys"])
+ values = self.sql(expression.args["values"])
+ return f"MAP_FROM_ARRAYS({keys}, {values})"
+
+
+def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]:
+ return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
+
+
+def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
+ this = self.sql(expression, "this")
+ time_format = self.format_time(expression)
+ if time_format == Hive.date_format:
+ return f"TO_DATE({this})"
+ return f"TO_DATE({this}, {time_format})"
+
+
+def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
+ scale = expression.args.get("scale")
+ timestamp = self.sql(expression, "this")
+ if scale is None:
+ return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)"
+ if scale == exp.UnixToTime.SECONDS:
+ return f"TIMESTAMP_SECONDS({timestamp})"
+ if scale == exp.UnixToTime.MILLIS:
+ return f"TIMESTAMP_MILLIS({timestamp})"
+ if scale == exp.UnixToTime.MICROS:
+ return f"TIMESTAMP_MICROS({timestamp})"
+
+ raise ValueError("Improper scale for timestamp")
+
+
+class Spark2(Hive):
+ class Parser(Hive.Parser):
+ FUNCTIONS = {
+ **Hive.Parser.FUNCTIONS, # type: ignore
+ "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
+ "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
+ "LEFT": lambda args: exp.Substring(
+ this=seq_get(args, 0),
+ start=exp.Literal.number(1),
+ length=seq_get(args, 1),
+ ),
+ "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
+ ),
+ "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
+ ),
+ "RIGHT": lambda args: exp.Substring(
+ this=seq_get(args, 0),
+ start=exp.Sub(
+ this=exp.Length(this=seq_get(args, 0)),
+ expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
+ ),
+ length=seq_get(args, 1),
+ ),
+ "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+ "IIF": exp.If.from_arg_list,
+ "AGGREGATE": exp.Reduce.from_arg_list,
+ "DAYOFWEEK": lambda args: exp.DayOfWeek(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
+ "DAYOFMONTH": lambda args: exp.DayOfMonth(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
+ "DAYOFYEAR": lambda args: exp.DayOfYear(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
+ "WEEKOFYEAR": lambda args: exp.WeekOfYear(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ ),
+ "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")),
+ "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+ this=seq_get(args, 1),
+ unit=exp.var(seq_get(args, 0)),
+ ),
+ "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
+ "BOOLEAN": _parse_as_cast("boolean"),
+ "DOUBLE": _parse_as_cast("double"),
+ "FLOAT": _parse_as_cast("float"),
+ "INT": _parse_as_cast("int"),
+ "STRING": _parse_as_cast("string"),
+ "TIMESTAMP": _parse_as_cast("timestamp"),
+ }
+
+ FUNCTION_PARSERS = {
+ **parser.Parser.FUNCTION_PARSERS, # type: ignore
+ "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
+ "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
+ "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
+ "MERGE": lambda self: self._parse_join_hint("MERGE"),
+ "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"),
+ "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"),
+ "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"),
+ "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"),
+ }
+
+ def _parse_add_column(self) -> t.Optional[exp.Expression]:
+ return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema()
+
+ def _parse_drop_column(self) -> t.Optional[exp.Expression]:
+ return self._match_text_seq("DROP", "COLUMNS") and self.expression(
+ exp.Drop,
+ this=self._parse_schema(),
+ kind="COLUMNS",
+ )
+
+ def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
+ # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
+ if len(pivot_columns) == 1:
+ return [""]
+
+ names = []
+ for agg in pivot_columns:
+ if isinstance(agg, exp.Alias):
+ names.append(agg.alias)
+ else:
+ """
+ This case corresponds to aggregations without aliases being used as suffixes
+ (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
+ be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
+ Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
+
+ Moreover, function names are lowercased in order to mimic Spark's naming scheme.
+ """
+ agg_all_unquoted = agg.transform(
+ lambda node: exp.Identifier(this=node.name, quoted=False)
+ if isinstance(node, exp.Identifier)
+ else node
+ )
+ names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower"))
+
+ return names
+
+ class Generator(Hive.Generator):
+ TYPE_MAPPING = {
+ **Hive.Generator.TYPE_MAPPING, # type: ignore
+ exp.DataType.Type.TINYINT: "BYTE",
+ exp.DataType.Type.SMALLINT: "SHORT",
+ exp.DataType.Type.BIGINT: "LONG",
+ }
+
+ PROPERTIES_LOCATION = {
+ **Hive.Generator.PROPERTIES_LOCATION, # type: ignore
+ exp.EngineProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.CollateProperty: exp.Properties.Location.UNSUPPORTED,
+ }
+
+ TRANSFORMS = {
+ **Hive.Generator.TRANSFORMS, # type: ignore
+ exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
+ exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
+ exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+ exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
+ exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
+ exp.Create: _create_sql,
+ exp.DateFromParts: rename_func("MAKE_DATE"),
+ exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
+ exp.DayOfMonth: rename_func("DAYOFMONTH"),
+ exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.DayOfYear: rename_func("DAYOFYEAR"),
+ exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
+ exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
+ exp.LogicalAnd: rename_func("BOOL_AND"),
+ exp.LogicalOr: rename_func("BOOL_OR"),
+ exp.Map: _map_sql,
+ exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]),
+ exp.Reduce: rename_func("AGGREGATE"),
+ exp.StrToDate: _str_to_date,
+ exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimestampTrunc: lambda self, e: self.func(
+ "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
+ ),
+ exp.Trim: trim_sql,
+ exp.UnixToTime: _unix_to_time_sql,
+ exp.VariancePop: rename_func("VAR_POP"),
+ exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+ exp.WithinGroup: transforms.preprocess(
+ [transforms.remove_within_group_for_percentiles]
+ ),
+ }
+ TRANSFORMS.pop(exp.ArrayJoin)
+ TRANSFORMS.pop(exp.ArraySort)
+ TRANSFORMS.pop(exp.ILike)
+
+ WRAP_DERIVED_VALUES = False
+ CREATE_FUNCTION_RETURN_AS = False
+
+ def cast_sql(self, expression: exp.Cast) -> str:
+ if isinstance(expression.this, exp.Cast) and expression.this.is_type(
+ exp.DataType.Type.JSON
+ ):
+ schema = f"'{self.sql(expression, 'to')}'"
+ return self.func("FROM_JSON", expression.this.this, schema)
+ if expression.to.is_type(exp.DataType.Type.JSON):
+ return self.func("TO_JSON", expression.this)
+
+ return super(Hive.Generator, self).cast_sql(expression)
+
+ def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
+ return super().columndef_sql(
+ expression,
+ sep=": "
+ if isinstance(expression.parent, exp.DataType)
+ and expression.parent.is_type(exp.DataType.Type.STRUCT)
+ else sep,
+ )
+
+ class Tokenizer(Hive.Tokenizer):
+ HEX_STRINGS = [("X'", "'")]
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 4437f82..f2efe32 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -22,6 +22,40 @@ def _date_add_sql(self, expression):
return self.func("DATE", expression.this, modifier)
+def _transform_create(expression: exp.Expression) -> exp.Expression:
+ """Move primary key to a column and enforce auto_increment on primary keys."""
+ schema = expression.this
+
+ if isinstance(expression, exp.Create) and isinstance(schema, exp.Schema):
+ defs = {}
+ primary_key = None
+
+ for e in schema.expressions:
+ if isinstance(e, exp.ColumnDef):
+ defs[e.name] = e
+ elif isinstance(e, exp.PrimaryKey):
+ primary_key = e
+
+ if primary_key and len(primary_key.expressions) == 1:
+ column = defs[primary_key.expressions[0].name]
+ column.append(
+ "constraints", exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint())
+ )
+ schema.expressions.remove(primary_key)
+ else:
+ for column in defs.values():
+ auto_increment = None
+ for constraint in column.constraints.copy():
+ if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint):
+ break
+ if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint):
+ auto_increment = constraint
+ if auto_increment:
+ column.constraints.remove(auto_increment)
+
+ return expression
+
+
class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
@@ -65,8 +99,8 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
- **transforms.ELIMINATE_QUALIFY, # type: ignore
exp.CountIf: count_if_to_sum,
+ exp.Create: transforms.preprocess([_transform_create]),
exp.CurrentDate: lambda *_: "CURRENT_DATE",
exp.CurrentTime: lambda *_: "CURRENT_TIME",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
@@ -80,14 +114,17 @@ class SQLite(Dialect):
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
+ exp.Select: transforms.preprocess(
+ [transforms.eliminate_distinct_on, transforms.eliminate_qualify]
+ ),
exp.TableSample: no_tablesample_sql,
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
exp.TryCast: no_trycast_sql,
}
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
- exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ k: exp.Properties.Location.UNSUPPORTED
+ for k, v in generator.Generator.PROPERTIES_LOCATION.items()
}
LIMIT_FETCH = "LIMIT"
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index ff19dab..895588a 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -34,6 +34,7 @@ class StarRocks(MySQL):
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),
+ exp.RegexpLike: rename_func("REGEXP"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 792c2b4..51e685b 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from sqlglot import exp, generator, parser
+from sqlglot import exp, generator, parser, transforms
from sqlglot.dialects.dialect import Dialect
@@ -29,6 +29,7 @@ class Tableau(Dialect):
exp.If: _if_sql,
exp.Coalesce: _coalesce_sql,
exp.Count: _count_sql,
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
}
PROPERTIES_LOCATION = {
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 331e105..a79eaeb 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@@ -148,6 +148,7 @@ class Teradata(Dialect):
**generator.Generator.TRANSFORMS,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 9cf56e1..03de99c 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import re
import typing as t
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
max_or_greatest,
@@ -259,8 +259,8 @@ class TSQL(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
-
QUOTES = ["'", '"']
+ HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -463,17 +463,18 @@ class TSQL(Dialect):
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
- exp.If: rename_func("IIF"),
- exp.NumberToStr: _format_sql,
- exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
+ exp.If: rename_func("IIF"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
+ exp.NumberToStr: _format_sql,
+ exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
"HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
),
+ exp.TimeToStr: _format_sql,
}
TRANSFORMS.pop(exp.ReturnsProperty)