summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--sqlglot/dialects/bigquery.py86
-rw-r--r--sqlglot/dialects/clickhouse.py52
-rw-r--r--sqlglot/dialects/databricks.py15
-rw-r--r--sqlglot/dialects/dialect.py20
-rw-r--r--sqlglot/dialects/doris.py1
-rw-r--r--sqlglot/dialects/drill.py9
-rw-r--r--sqlglot/dialects/duckdb.py38
-rw-r--r--sqlglot/dialects/hive.py55
-rw-r--r--sqlglot/dialects/mysql.py32
-rw-r--r--sqlglot/dialects/oracle.py11
-rw-r--r--sqlglot/dialects/postgres.py38
-rw-r--r--sqlglot/dialects/presto.py54
-rw-r--r--sqlglot/dialects/redshift.py14
-rw-r--r--sqlglot/dialects/snowflake.py78
-rw-r--r--sqlglot/dialects/spark.py10
-rw-r--r--sqlglot/dialects/spark2.py31
-rw-r--r--sqlglot/dialects/sqlite.py5
-rw-r--r--sqlglot/dialects/teradata.py4
-rw-r--r--sqlglot/dialects/trino.py3
-rw-r--r--sqlglot/dialects/tsql.py157
20 files changed, 537 insertions, 176 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 71977dd..d763ed0 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
+ json_keyvalue_comma_sql,
max_or_greatest,
min_or_least,
no_ilike_sql,
@@ -29,8 +30,8 @@ logger = logging.getLogger("sqlglot")
def _date_add_sql(
data_type: str, kind: str
-) -> t.Callable[[generator.Generator, exp.Expression], str]:
- def func(self, expression):
+) -> t.Callable[[BigQuery.Generator, exp.Expression], str]:
+ def func(self: BigQuery.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.args.get("unit")
unit = exp.var(unit.name.upper() if unit else "DAY")
@@ -40,7 +41,7 @@ def _date_add_sql(
return func
-def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str:
+def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str:
if not expression.find_ancestor(exp.From, exp.Join):
return self.values_sql(expression)
@@ -64,7 +65,7 @@ def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.V
return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
-def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str:
+def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
this = expression.this
if isinstance(this, exp.Schema):
this = f"{this.this} <{self.expressions(this)}>"
@@ -73,7 +74,7 @@ def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsPrope
return f"RETURNS {this}"
-def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
+def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
kind = expression.args["kind"]
returns = expression.find(exp.ReturnsProperty)
@@ -94,14 +95,20 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
These are added by the optimizer's qualify_column step.
"""
- from sqlglot.optimizer.scope import Scope
+ from sqlglot.optimizer.scope import find_all_in_scope
if isinstance(expression, exp.Select):
- for unnest in expression.find_all(exp.Unnest):
- if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias:
- for column in Scope(expression).find_all(exp.Column):
- if column.table == unnest.alias:
- column.set("table", None)
+ unnest_aliases = {
+ unnest.alias
+ for unnest in find_all_in_scope(expression, exp.Unnest)
+ if isinstance(unnest.parent, (exp.From, exp.Join))
+ }
+ if unnest_aliases:
+ for column in expression.find_all(exp.Column):
+ if column.table in unnest_aliases:
+ column.set("table", None)
+ elif column.db in unnest_aliases:
+ column.set("db", None)
return expression
@@ -261,6 +268,7 @@ class BigQuery(Dialect):
"TIMESTAMP": TokenType.TIMESTAMPTZ,
"NOT DETERMINISTIC": TokenType.VOLATILE,
"UNKNOWN": TokenType.NULL,
+ "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
}
KEYWORDS.pop("DIV")
@@ -270,6 +278,8 @@ class BigQuery(Dialect):
LOG_BASE_FIRST = False
LOG_DEFAULTS_TO_LN = True
+ SUPPORTS_USER_DEFINED_TYPES = False
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE": _parse_date,
@@ -299,6 +309,8 @@ class BigQuery(Dialect):
if re.compile(str(seq_get(args, 1))).groups == 1
else None,
),
+ "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)),
+ "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)),
"SPLIT": lambda args: exp.Split(
# https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split
this=seq_get(args, 0),
@@ -346,7 +358,7 @@ class BigQuery(Dialect):
}
def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
- this = super()._parse_table_part(schema=schema)
+ this = super()._parse_table_part(schema=schema) or self._parse_number()
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names
if isinstance(this, exp.Identifier):
@@ -356,6 +368,17 @@ class BigQuery(Dialect):
table_name += f"-{self._prev.text}"
this = exp.Identifier(this=table_name, quoted=this.args.get("quoted"))
+ elif isinstance(this, exp.Literal):
+ table_name = this.name
+
+ if (
+ self._curr
+ and self._prev.end == self._curr.start - 1
+ and self._parse_var(any_token=True)
+ ):
+ table_name += self._prev.text
+
+ this = exp.Identifier(this=table_name, quoted=True)
return this
@@ -374,6 +397,27 @@ class BigQuery(Dialect):
return table
+ def _parse_json_object(self) -> exp.JSONObject:
+ json_object = super()._parse_json_object()
+ array_kv_pair = seq_get(json_object.expressions, 0)
+
+ # Converts BQ's "signature 2" of JSON_OBJECT into SQLGlot's canonical representation
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_object_signature2
+ if (
+ array_kv_pair
+ and isinstance(array_kv_pair.this, exp.Array)
+ and isinstance(array_kv_pair.expression, exp.Array)
+ ):
+ keys = array_kv_pair.this.expressions
+ values = array_kv_pair.expression.expressions
+
+ json_object.set(
+ "expressions",
+ [exp.JSONKeyValue(this=k, expression=v) for k, v in zip(keys, values)],
+ )
+
+ return json_object
+
class Generator(generator.Generator):
EXPLICIT_UNION = True
INTERVAL_ALLOWS_PLURAL_FORM = False
@@ -383,6 +427,7 @@ class BigQuery(Dialect):
LIMIT_FETCH = "LIMIT"
RENAME_TABLE_WITH_DB = False
ESCAPE_LINE_BREAK = True
+ NVL2_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -405,6 +450,7 @@ class BigQuery(Dialect):
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
+ exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
@@ -428,6 +474,9 @@ class BigQuery(Dialect):
_alias_ordered_group,
]
),
+ exp.SHA2: lambda self, e: self.func(
+ f"SHA256" if e.text("length") == "256" else "SHA512", e.this
+ ),
exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
@@ -591,6 +640,13 @@ class BigQuery(Dialect):
return super().attimezone_sql(expression)
+ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#json_literals
+ if expression.is_type("json"):
+ return f"JSON {self.sql(expression, 'this')}"
+
+ return super().cast_sql(expression, safe_prefix=safe_prefix)
+
def trycast_sql(self, expression: exp.TryCast) -> str:
return self.cast_sql(expression, safe_prefix="SAFE_")
@@ -630,3 +686,9 @@ class BigQuery(Dialect):
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("OPTIONS"))
+
+ def version_sql(self, expression: exp.Version) -> str:
+ if expression.name == "TIMESTAMP":
+ expression = expression.copy()
+ expression.set("this", "SYSTEM_TIME")
+ return super().version_sql(expression)
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index cfde5fd..a38a239 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
var_map_sql,
)
from sqlglot.errors import ParseError
+from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import Token, TokenType
@@ -63,9 +64,23 @@ class ClickHouse(Dialect):
}
class Parser(parser.Parser):
+ SUPPORTS_USER_DEFINED_TYPES = False
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ANY": exp.AnyValue.from_arg_list,
+ "DATE_ADD": lambda args: exp.DateAdd(
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+ ),
+ "DATEADD": lambda args: exp.DateAdd(
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+ ),
+ "DATE_DIFF": lambda args: exp.DateDiff(
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+ ),
+ "DATEDIFF": lambda args: exp.DateDiff(
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+ ),
"MAP": parse_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
"UNIQ": exp.ApproxDistinct.from_arg_list,
@@ -147,7 +162,7 @@ class ClickHouse(Dialect):
this = self._parse_id_var()
self._match(TokenType.COLON)
- kind = self._parse_types(check_func=False) or (
+ kind = self._parse_types(check_func=False, allow_identifiers=False) or (
self._match_text_seq("IDENTIFIER") and "Identifier"
)
@@ -249,7 +264,7 @@ class ClickHouse(Dialect):
def _parse_func_params(
self, this: t.Optional[exp.Func] = None
- ) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
+ ) -> t.Optional[t.List[exp.Expression]]:
if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
return self._parse_csv(self._parse_lambda)
@@ -267,9 +282,7 @@ class ClickHouse(Dialect):
return self.expression(exp.Quantile, this=params[0], quantile=this)
return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5))
- def _parse_wrapped_id_vars(
- self, optional: bool = False
- ) -> t.List[t.Optional[exp.Expression]]:
+ def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]:
return super()._parse_wrapped_id_vars(optional=True)
def _parse_primary_key(
@@ -292,9 +305,22 @@ class ClickHouse(Dialect):
class Generator(generator.Generator):
QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
+ NVL2_SUPPORTED = False
+
+ STRING_TYPE_MAPPING = {
+ exp.DataType.Type.CHAR: "String",
+ exp.DataType.Type.LONGBLOB: "String",
+ exp.DataType.Type.LONGTEXT: "String",
+ exp.DataType.Type.MEDIUMBLOB: "String",
+ exp.DataType.Type.MEDIUMTEXT: "String",
+ exp.DataType.Type.TEXT: "String",
+ exp.DataType.Type.VARBINARY: "String",
+ exp.DataType.Type.VARCHAR: "String",
+ }
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
+ **STRING_TYPE_MAPPING,
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.BIGINT: "Int64",
exp.DataType.Type.DATETIME64: "DateTime64",
@@ -328,6 +354,12 @@ class ClickHouse(Dialect):
exp.ApproxDistinct: rename_func("uniq"),
exp.Array: inline_array_sql,
exp.CastToStrType: rename_func("CAST"),
+ exp.DateAdd: lambda self, e: self.func(
+ "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+ ),
+ exp.DateDiff: lambda self, e: self.func(
+ "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+ ),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -364,6 +396,16 @@ class ClickHouse(Dialect):
"NAMED COLLECTION",
}
+ def datatype_sql(self, expression: exp.DataType) -> str:
+ # String is the standard ClickHouse type, every other variant is just an alias.
+ # Additionally, any supplied length parameter will be ignored.
+ #
+ # https://clickhouse.com/docs/en/sql-reference/data-types/string
+ if expression.this in self.STRING_TYPE_MAPPING:
+ return "String"
+
+ return super().datatype_sql(expression)
+
def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
# Clickhouse errors out if we try to cast a NULL value to TEXT
expression = expression.copy()
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 2149aca..6ec0487 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, transforms
-from sqlglot.dialects.dialect import parse_date_delta
+from sqlglot.dialects.dialect import parse_date_delta, timestamptrunc_sql
from sqlglot.dialects.spark import Spark
from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql
from sqlglot.tokens import TokenType
@@ -28,6 +28,19 @@ class Databricks(Spark):
**Spark.Generator.TRANSFORMS,
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
+ exp.DatetimeAdd: lambda self, e: self.func(
+ "TIMESTAMPADD", e.text("unit"), e.expression, e.this
+ ),
+ exp.DatetimeSub: lambda self, e: self.func(
+ "TIMESTAMPADD",
+ e.text("unit"),
+ exp.Mul(this=e.expression.copy(), expression=exp.Literal.number(-1)),
+ e.this,
+ ),
+ exp.DatetimeDiff: lambda self, e: self.func(
+ "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+ ),
+ exp.DatetimeTrunc: timestamptrunc_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.Select: transforms.preprocess(
[
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 132496f..1bfbfef 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -109,8 +109,7 @@ class _Dialect(type):
for k, v in vars(klass).items()
if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__")
},
- "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
- "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
+ "TOKENIZER_CLASS": klass.tokenizer_class,
}
if enum not in ("", "bigquery"):
@@ -345,7 +344,7 @@ def arrow_json_extract_scalar_sql(
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
- return f"[{self.expressions(expression)}]"
+ return f"[{self.expressions(expression, flat=True)}]"
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
@@ -415,9 +414,9 @@ def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
- this = self.sql(expression, "this")
- struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True))
- return f"{this}.{struct_key}"
+ return (
+ f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}"
+ )
def var_map_sql(
@@ -722,3 +721,12 @@ def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
+
+
+def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
+ return self.func("MAX", expression.this)
+
+
+# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
+def json_keyvalue_comma_sql(self, expression: exp.JSONKeyValue) -> str:
+ return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index 160c23c..4b8919c 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -37,7 +37,6 @@ class Doris(MySQL):
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
- exp.Coalesce: rename_func("NVL"),
exp.CurrentTimestamp: lambda *_: "NOW()",
exp.DateTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 1b2681d..c811c86 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -16,8 +16,8 @@ from sqlglot.dialects.dialect import (
)
-def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = exp.var(expression.text("unit").upper() or "DAY")
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@@ -25,7 +25,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
return func
-def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.DATE_FORMAT:
@@ -73,7 +73,6 @@ class Drill(Dialect):
}
class Tokenizer(tokens.Tokenizer):
- QUOTES = ["'"]
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
@@ -81,6 +80,7 @@ class Drill(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
CONCAT_NULL_OUTPUTS_STRING = True
+ SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -95,6 +95,7 @@ class Drill(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
+ NVL2_SUPPORTED = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 8253b52..684e35e 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
encode_decode_sql,
format_time_lambda,
+ inline_array_sql,
no_comment_column_constraint_sql,
no_properties_sql,
no_safe_divide_sql,
@@ -30,13 +31,13 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
+def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"
-def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = self.sql(expression, "unit").strip("'") or "DAY"
op = "+" if isinstance(expression, exp.DateAdd) else "-"
@@ -44,7 +45,7 @@ def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.Dat
# BigQuery -> DuckDB conversion for the DATE function
-def _date_sql(self: generator.Generator, expression: exp.Date) -> str:
+def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str:
result = f"CAST({self.sql(expression, 'this')} AS DATE)"
zone = self.sql(expression, "zone")
@@ -58,13 +59,13 @@ def _date_sql(self: generator.Generator, expression: exp.Date) -> str:
return result
-def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
+def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
return f"ARRAY_SORT({self.sql(expression, 'this')})"
-def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str:
+def _sort_array_sql(self: DuckDB.Generator, expression: exp.SortArray) -> str:
this = self.sql(expression, "this")
if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
@@ -79,14 +80,14 @@ def _parse_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
-def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
+def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
args = [
f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions
]
return f"{{{', '.join(args)}}}"
-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
+def _datatype_sql(self: DuckDB.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
@@ -97,7 +98,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
-def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
+def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str:
sql = self.func("TO_JSON", expression.this, expression.args.get("options"))
return f"CAST({sql} AS TEXT)"
@@ -134,6 +135,7 @@ class DuckDB(Dialect):
class Parser(parser.Parser):
CONCAT_NULL_OUTPUTS_STRING = True
+ SUPPORTS_USER_DEFINED_TYPES = False
BITWISE = {
**parser.Parser.BITWISE,
@@ -183,18 +185,12 @@ class DuckDB(Dialect):
),
}
- TYPE_TOKENS = {
- *parser.Parser.TYPE_TOKENS,
- TokenType.UBIGINT,
- TokenType.UINT,
- TokenType.USMALLINT,
- TokenType.UTINYINT,
- }
-
def _parse_types(
- self, check_func: bool = False, schema: bool = False
+ self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
- this = super()._parse_types(check_func=check_func, schema=schema)
+ this = super()._parse_types(
+ check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+ )
# DuckDB treats NUMERIC and DECIMAL without precision as DECIMAL(18, 3)
# See: https://duckdb.org/docs/sql/data_types/numeric
@@ -207,6 +203,9 @@ class DuckDB(Dialect):
return this
+ def _parse_struct_types(self) -> t.Optional[exp.Expression]:
+ return self._parse_field_def()
+
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
if len(aggregations) == 1:
return super()._pivot_column_names(aggregations)
@@ -219,13 +218,14 @@ class DuckDB(Dialect):
LIMIT_FETCH = "LIMIT"
STRUCT_DELIMITER = ("(", ")")
RENAME_TABLE_WITH_DB = False
+ NVL2_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
if e.expressions and e.expressions[0].find(exp.Select)
- else rename_func("LIST_VALUE")(self, e),
+ else inline_array_sql(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 584acc6..8b17c06 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -50,7 +50,7 @@ TIME_DIFF_FACTOR = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
-def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
@@ -69,7 +69,7 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateS
return self.func(func, expression.this, modified_increment)
-def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
+def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = TIME_DIFF_FACTOR.get(unit)
@@ -87,7 +87,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
return f"{diff_sql}{multiplier_sql}"
-def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
+def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str:
this = expression.this
if isinstance(this, exp.Cast) and this.is_type("json") and this.this.is_string:
# Since FROM_JSON requires a nested type, we always wrap the json string with
@@ -103,21 +103,21 @@ def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> s
return self.func("TO_JSON", this, expression.args.get("options"))
-def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str:
+def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return f"SORT_ARRAY({self.sql(expression, 'this')})"
-def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
+def _property_sql(self: Hive.Generator, expression: exp.Property) -> str:
return f"'{expression.name}'={self.sql(expression, 'value')}"
-def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str:
+def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression))
-def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date_sql(self: Hive.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@@ -125,7 +125,7 @@ def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> st
return f"CAST({this} AS DATE)"
-def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str:
+def _str_to_time_sql(self: Hive.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@@ -133,13 +133,13 @@ def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> st
return f"CAST({this} AS TIMESTAMP)"
-def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
+def _time_to_str(self: Hive.Generator, expression: exp.TimeToStr) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
return f"DATE_FORMAT({this}, {time_format})"
-def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
+def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
@@ -206,6 +206,8 @@ class Hive(Dialect):
"MSCK REPAIR": TokenType.COMMAND,
"REFRESH": TokenType.COMMAND,
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
+ "TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT,
+ "VERSION AS OF": TokenType.VERSION_SNAPSHOT,
}
NUMERIC_LITERALS = {
@@ -220,6 +222,7 @@ class Hive(Dialect):
class Parser(parser.Parser):
LOG_DEFAULTS_TO_LN = True
STRICT_CAST = False
+ SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -257,6 +260,11 @@ class Hive(Dialect):
),
"SIZE": exp.ArraySize.from_arg_list,
"SPLIT": exp.RegexpSplit.from_arg_list,
+ "STR_TO_MAP": lambda args: exp.StrToMap(
+ this=seq_get(args, 0),
+ pair_delim=seq_get(args, 1) or exp.Literal.string(","),
+ key_value_delim=seq_get(args, 2) or exp.Literal.string(":"),
+ ),
"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
@@ -313,7 +321,7 @@ class Hive(Dialect):
)
def _parse_types(
- self, check_func: bool = False, schema: bool = False
+ self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
"""
Spark (and most likely Hive) treats casts to CHAR(length) and VARCHAR(length) as casts to
@@ -333,7 +341,9 @@ class Hive(Dialect):
Reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
"""
- this = super()._parse_types(check_func=check_func, schema=schema)
+ this = super()._parse_types(
+ check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+ )
if this and not schema:
return this.transform(
@@ -345,6 +355,16 @@ class Hive(Dialect):
return this
+ def _parse_partition_and_order(
+ self,
+ ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
+ return (
+ self._parse_csv(self._parse_conjunction)
+ if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
+ else [],
+ super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)),
+ )
+
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
@@ -354,6 +374,7 @@ class Hive(Dialect):
QUERY_HINTS = False
INDEX_ON = "ON TABLE"
EXTRACT_ALLOWS_QUOTES = False
+ NVL2_SUPPORTED = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -376,6 +397,7 @@ class Hive(Dialect):
]
),
exp.Property: _property_sql,
+ exp.AnyValue: rename_func("FIRST"),
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
@@ -402,6 +424,9 @@ class Hive(Dialect):
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.Min: min_or_least,
exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
+ exp.NotNullColumnConstraint: lambda self, e: ""
+ if e.args.get("allow_null")
+ else "NOT NULL",
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,
exp.Quantile: rename_func("PERCENTILE"),
@@ -472,7 +497,7 @@ class Hive(Dialect):
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
elif expression.is_type("float"):
- size_expression = expression.find(exp.DataTypeSize)
+ size_expression = expression.find(exp.DataTypeParam)
if size_expression:
size = int(size_expression.name)
expression = (
@@ -480,3 +505,7 @@ class Hive(Dialect):
)
return super().datatype_sql(expression)
+
+ def version_sql(self, expression: exp.Version) -> str:
+ sql = super().version_sql(expression)
+ return sql.replace("FOR ", "", 1)
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 9ab4ce8..f9249eb 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_scalar_sql,
datestrtodate_sql,
format_time_lambda,
+ json_keyvalue_comma_sql,
locate_to_strposition,
max_or_greatest,
min_or_least,
@@ -32,7 +33,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex
return _parse
-def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str:
+def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
expr = self.sql(expression, "this")
unit = expression.text("unit")
@@ -63,12 +64,12 @@ def _str_to_date(args: t.List) -> exp.StrToDate:
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
-def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
+def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
date_format = self.format_time(expression)
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
-def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
+def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
remove_chars = self.sql(expression, "expression")
@@ -83,8 +84,8 @@ def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
-def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@@ -93,6 +94,9 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
class MySQL(Dialect):
+ # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html
+ IDENTIFIERS_CAN_START_WITH_DIGIT = True
+
TIME_FORMAT = "'%Y-%m-%d %T'"
DPIPE_IS_STRING_CONCAT = False
@@ -129,6 +133,7 @@ class MySQL(Dialect):
"LONGTEXT": TokenType.LONGTEXT,
"MEDIUMBLOB": TokenType.MEDIUMBLOB,
"MEDIUMTEXT": TokenType.MEDIUMTEXT,
+ "MEDIUMINT": TokenType.MEDIUMINT,
"MEMBER OF": TokenType.MEMBER_OF,
"SEPARATOR": TokenType.SEPARATOR,
"START": TokenType.BEGIN,
@@ -136,6 +141,7 @@ class MySQL(Dialect):
"SIGNED INTEGER": TokenType.BIGINT,
"UNSIGNED": TokenType.UBIGINT,
"UNSIGNED INTEGER": TokenType.UBIGINT,
+ "YEAR": TokenType.YEAR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
"_BIG5": TokenType.INTRODUCER,
@@ -185,6 +191,8 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
class Parser(parser.Parser):
+ SUPPORTS_USER_DEFINED_TYPES = False
+
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
TokenType.DATABASE,
@@ -492,6 +500,17 @@ class MySQL(Dialect):
return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES")
+ def _parse_type(self) -> t.Optional[exp.Expression]:
+ # mysql binary is special and can work anywhere, even in order by operations
+ # it operates like a no paren func
+ if self._match(TokenType.BINARY, advance=False):
+ data_type = self._parse_types(check_func=True, allow_identifiers=False)
+
+ if isinstance(data_type, exp.DataType):
+ return self.expression(exp.Cast, this=self._parse_column(), to=data_type)
+
+ return super()._parse_type()
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
@@ -500,6 +519,7 @@ class MySQL(Dialect):
DUPLICATE_KEY_UPDATE_WITH_SET = False
QUERY_HINT_SEP = " "
VALUES_AS_TABLE = False
+ NVL2_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -515,6 +535,7 @@ class MySQL(Dialect):
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+ exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
@@ -524,6 +545,7 @@ class MySQL(Dialect):
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
+ exp.Stuff: rename_func("INSERT"),
exp.TableSample: no_tablesample_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 1f63e9f..279ed31 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -8,7 +8,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
+def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable:
this = self._parse_string()
passing = None
@@ -22,7 +22,7 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable:
by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
if self._match_text_seq("COLUMNS"):
- columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True)))
+ columns = self._parse_csv(self._parse_field_def)
return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref)
@@ -78,6 +78,10 @@ class Oracle(Dialect):
)
}
+ # SELECT UNIQUE .. is old-style Oracle syntax for SELECT DISTINCT ..
+ # Reference: https://stackoverflow.com/a/336455
+ DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}
+
def _parse_column(self) -> t.Optional[exp.Expression]:
column = super()._parse_column()
if column:
@@ -129,7 +133,6 @@ class Oracle(Dialect):
),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.ILike: no_ilike_sql,
- exp.Coalesce: rename_func("NVL"),
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
@@ -162,7 +165,7 @@ class Oracle(Dialect):
return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}"
class Tokenizer(tokens.Tokenizer):
- VAR_SINGLE_TOKENS = {"@"}
+ VAR_SINGLE_TOKENS = {"@", "$", "#"}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 73ca4e5..c26e121 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
datestrtodate_sql,
@@ -39,8 +40,8 @@ DATE_DIFF_FACTOR = {
}
-def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
expression = expression.copy()
this = self.sql(expression, "this")
@@ -56,7 +57,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
return func
-def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
+def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = DATE_DIFF_FACTOR.get(unit)
@@ -82,7 +83,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
return f"CAST({unit} AS BIGINT)"
-def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
+def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str:
this = self.sql(expression, "this")
start = self.sql(expression, "start")
length = self.sql(expression, "length")
@@ -93,7 +94,7 @@ def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str:
return f"SUBSTRING({this}{from_part}{for_part})"
-def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
+def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
@@ -107,7 +108,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
return f"STRING_AGG({self.format_args(this, separator)}{order})"
-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
+def _datatype_sql(self: Postgres.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
@@ -254,6 +255,7 @@ class Postgres(Dialect):
"~~*": TokenType.ILIKE,
"~*": TokenType.IRLIKE,
"~": TokenType.RLIKE,
+ "@@": TokenType.DAT,
"@>": TokenType.AT_GT,
"<@": TokenType.LT_AT,
"BEGIN": TokenType.COMMAND,
@@ -273,6 +275,18 @@ class Postgres(Dialect):
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"CSTRING": TokenType.PSEUDO_TYPE,
+ "OID": TokenType.OBJECT_IDENTIFIER,
+ "REGCLASS": TokenType.OBJECT_IDENTIFIER,
+ "REGCOLLATION": TokenType.OBJECT_IDENTIFIER,
+ "REGCONFIG": TokenType.OBJECT_IDENTIFIER,
+ "REGDICTIONARY": TokenType.OBJECT_IDENTIFIER,
+ "REGNAMESPACE": TokenType.OBJECT_IDENTIFIER,
+ "REGOPER": TokenType.OBJECT_IDENTIFIER,
+ "REGOPERATOR": TokenType.OBJECT_IDENTIFIER,
+ "REGPROC": TokenType.OBJECT_IDENTIFIER,
+ "REGPROCEDURE": TokenType.OBJECT_IDENTIFIER,
+ "REGROLE": TokenType.OBJECT_IDENTIFIER,
+ "REGTYPE": TokenType.OBJECT_IDENTIFIER,
}
SINGLE_TOKENS = {
@@ -312,6 +326,9 @@ class Postgres(Dialect):
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
+ TokenType.DAT: lambda self, this: self.expression(
+ exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this]
+ ),
TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
}
@@ -343,6 +360,7 @@ class Postgres(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
+ NVL2_SUPPORTED = False
PARAMETER_TOKEN = "$"
TYPE_MAPPING = {
@@ -357,6 +375,8 @@ class Postgres(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.AnyValue: any_value_to_max_sql,
+ exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.Explode: rename_func("UNNEST"),
@@ -416,3 +436,9 @@ class Postgres(Dialect):
expression.set("this", exp.paren(expression.this, copy=False))
return super().bracket_sql(expression)
+
+ def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
+ this = self.sql(expression, "this")
+ expressions = [f"{self.sql(e)} @@ {this}" for e in expression.expressions]
+ sql = " OR ".join(expressions)
+ return f"({sql})" if len(expressions) > 1 else sql
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 078da0b..4b54e95 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -26,13 +26,13 @@ from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
-def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
+def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) -> str:
accuracy = expression.args.get("accuracy")
accuracy = ", " + self.sql(accuracy) if accuracy else ""
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
-def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str:
+def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
expression = expression.copy()
return self.sql(
@@ -48,12 +48,12 @@ def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -
return self.lateral_sql(expression)
-def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str:
+def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str:
regex = r"(\w)(\w*)"
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
-def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
+def _no_sort_array(self: Presto.Generator, expression: exp.SortArray) -> str:
if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
@@ -61,7 +61,7 @@ def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str:
return self.func("ARRAY_SORT", expression.this, comparator)
-def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
+def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str:
if isinstance(expression.parent, exp.Property):
columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
@@ -75,25 +75,25 @@ def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str:
return self.schema_sql(expression)
-def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str:
+def _quantile_sql(self: Presto.Generator, expression: exp.Quantile) -> str:
self.unsupported("Presto does not support exact quantiles")
return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
def _str_to_time_sql(
- self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
+ self: Presto.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
) -> str:
return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
-def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
+def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto")
return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto")
-def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
+def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
this = expression.this
if not isinstance(this, exp.CurrentDate):
@@ -153,6 +153,20 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
return expression
+def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> str:
+ """
+ Trino doesn't support FIRST / LAST as functions, but they're valid in the context
+ of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases
+ they're converted into an ARBITRARY call.
+
+ Reference: https://trino.io/docs/current/sql/match-recognize.html#logical-navigation-functions
+ """
+ if isinstance(expression.find_ancestor(exp.MatchRecognize, exp.Select), exp.MatchRecognize):
+ return self.function_fallback_sql(expression)
+
+ return rename_func("ARBITRARY")(self, expression)
+
+
class Presto(Dialect):
INDEX_OFFSET = 1
NULL_ORDERING = "nulls_are_last"
@@ -178,6 +192,7 @@ class Presto(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
+ "ARBITRARY": exp.AnyValue.from_arg_list,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"BITWISE_AND": binary_from_function(exp.BitwiseAnd),
@@ -205,7 +220,14 @@ class Presto(Dialect):
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
),
+ "REGEXP_REPLACE": lambda args: exp.RegexpReplace(
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
+ replacement=seq_get(args, 2) or exp.Literal.string(""),
+ ),
+ "ROW": exp.Struct.from_arg_list,
"SEQUENCE": exp.GenerateSeries.from_arg_list,
+ "SPLIT_TO_MAP": exp.StrToMap.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
),
@@ -225,6 +247,7 @@ class Presto(Dialect):
QUERY_HINTS = False
IS_BOOL_ALLOWED = False
TZ_TO_WITH_TIME_ZONE = True
+ NVL2_SUPPORTED = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
@@ -242,10 +265,13 @@ class Presto(Dialect):
exp.DataType.Type.TIMETZ: "TIME",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.STRUCT: "ROW",
+ exp.DataType.Type.DATETIME: "TIMESTAMP",
+ exp.DataType.Type.DATETIME64: "TIMESTAMP",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.AnyValue: rename_func("ARBITRARY"),
exp.ApproxDistinct: _approx_distinct_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
@@ -268,15 +294,23 @@ class Presto(Dialect):
),
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
+ exp.DateSub: lambda self, e: self.func(
+ "DATE_ADD",
+ exp.Literal.string(e.text("unit") or "day"),
+ e.expression * -1,
+ e.this,
+ ),
exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"),
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)",
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
+ exp.First: _first_last_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
+ exp.Last: _first_last_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
@@ -301,8 +335,10 @@ class Presto(Dialect):
exp.SortArray: _no_sort_array,
exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
+ exp.StrToMap: rename_func("SPLIT_TO_MAP"),
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
+ exp.Struct: rename_func("ROW"),
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.TimestampTrunc: timestamptrunc_sql,
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 30731e1..351c5df 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -13,7 +13,7 @@ from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
-def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
+def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
return f'{self.sql(expression, "this")}."{expression.expression.name}"'
@@ -37,6 +37,8 @@ class Redshift(Postgres):
}
class Parser(Postgres.Parser):
+ SUPPORTS_USER_DEFINED_TYPES = False
+
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS,
"ADD_MONTHS": lambda args: exp.DateAdd(
@@ -55,9 +57,11 @@ class Redshift(Postgres):
}
def _parse_types(
- self, check_func: bool = False, schema: bool = False
+ self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
- this = super()._parse_types(check_func=check_func, schema=schema)
+ this = super()._parse_types(
+ check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
+ )
if (
isinstance(this, exp.DataType)
@@ -100,6 +104,7 @@ class Redshift(Postgres):
QUERY_HINTS = False
VALUES_AS_TABLE = False
TZ_TO_WITH_TIME_ZONE = True
+ NVL2_SUPPORTED = True
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@@ -142,6 +147,9 @@ class Redshift(Postgres):
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
TRANSFORMS.pop(exp.Pow)
+ # Redshift supports ANY_VALUE(..)
+ TRANSFORMS.pop(exp.AnyValue)
+
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
def with_properties(self, properties: exp.Properties) -> str:
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 9733a85..8d8183c 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -90,7 +90,7 @@ def _parse_datediff(args: t.List) -> exp.DateDiff:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
-def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
+def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@@ -105,7 +105,7 @@ def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) ->
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
-def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
+def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
this = self._parse_var() or self._parse_type()
if not this:
@@ -156,7 +156,7 @@ def _nullifzero_to_if(args: t.List) -> exp.If:
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
-def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
+def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str:
if expression.is_type("array"):
return "ARRAY"
elif expression.is_type("map"):
@@ -164,6 +164,17 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
return self.datatype_sql(expression)
+def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str:
+ flag = expression.text("flag")
+
+ if "i" not in flag:
+ flag += "i"
+
+ return self.func(
+ "REGEXP_LIKE", expression.this, expression.expression, exp.Literal.string(flag)
+ )
+
+
def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
if len(args) == 3:
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
@@ -179,6 +190,13 @@ def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
return regexp_replace
+def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser], exp.Show]:
+ def _parse(self: Snowflake.Parser) -> exp.Show:
+ return self._parse_show_snowflake(*args, **kwargs)
+
+ return _parse
+
+
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
@@ -216,6 +234,7 @@ class Snowflake(Dialect):
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
+ SUPPORTS_USER_DEFINED_TYPES = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -230,6 +249,7 @@ class Snowflake(Dialect):
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
+ "LISTAGG": exp.GroupConcat.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": _parse_object_construct,
"REGEXP_REPLACE": _parse_regexp_replace,
@@ -250,11 +270,6 @@ class Snowflake(Dialect):
}
FUNCTION_PARSERS.pop("TRIM")
- FUNC_TOKENS = {
- *parser.Parser.FUNC_TOKENS,
- TokenType.TABLE,
- }
-
COLUMN_OPERATORS = {
**parser.Parser.COLUMN_OPERATORS,
TokenType.COLON: lambda self, this, path: self.expression(
@@ -281,6 +296,16 @@ class Snowflake(Dialect):
),
}
+ STATEMENT_PARSERS = {
+ **parser.Parser.STATEMENT_PARSERS,
+ TokenType.SHOW: lambda self: self._parse_show(),
+ }
+
+ SHOW_PARSERS = {
+ "PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
+ "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
+ }
+
def _parse_id_var(
self,
any_token: bool = True,
@@ -296,8 +321,24 @@ class Snowflake(Dialect):
return super()._parse_id_var(any_token=any_token, tokens=tokens)
+ def _parse_show_snowflake(self, this: str) -> exp.Show:
+ scope = None
+ scope_kind = None
+
+ if self._match(TokenType.IN):
+ if self._match_text_seq("ACCOUNT"):
+ scope_kind = "ACCOUNT"
+ elif self._match_set(self.DB_CREATABLES):
+ scope_kind = self._prev.text
+ if self._curr:
+ scope = self._parse_table()
+ elif self._curr:
+ scope_kind = "TABLE"
+ scope = self._parse_table()
+
+ return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind)
+
class Tokenizer(tokens.Tokenizer):
- QUOTES = ["'"]
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
RAW_STRINGS = ["$$"]
@@ -331,6 +372,8 @@ class Snowflake(Dialect):
VAR_SINGLE_TOKENS = {"$"}
+ COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
+
class Generator(generator.Generator):
PARAMETER_TOKEN = "$"
MATCHED_BY_SOURCE = False
@@ -355,6 +398,7 @@ class Snowflake(Dialect):
exp.DataType: _datatype_sql,
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.Extract: rename_func("DATE_PART"),
+ exp.GroupConcat: rename_func("LISTAGG"),
exp.If: rename_func("IFF"),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
@@ -362,6 +406,7 @@ class Snowflake(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.RegexpILike: _regexpilike_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StarMap: rename_func("OBJECT_CONSTRUCT"),
exp.StartsWith: rename_func("STARTSWITH"),
@@ -373,6 +418,7 @@ class Snowflake(Dialect):
"OBJECT_CONSTRUCT",
*(arg for expression in e.expressions for arg in expression.flatten()),
),
+ exp.Stuff: rename_func("INSERT"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func(
@@ -403,6 +449,16 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def show_sql(self, expression: exp.Show) -> str:
+ scope = self.sql(expression, "scope")
+ scope = f" {scope}" if scope else ""
+
+ scope_kind = self.sql(expression, "scope_kind")
+ if scope_kind:
+ scope_kind = f" IN {scope_kind}"
+
+ return f"SHOW {expression.name}{scope_kind}{scope}"
+
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to
# generate default values as necessary to ensure the transpilation is correct
@@ -436,7 +492,9 @@ class Snowflake(Dialect):
kind_value = expression.args.get("kind") or "TABLE"
kind = f" {kind_value}" if kind_value else ""
this = f" {self.sql(expression, 'this')}"
- return f"DESCRIBE{kind}{this}"
+ expressions = self.expressions(expression, flat=True)
+ expressions = f" {expressions}" if expressions else ""
+ return f"DESCRIBE{kind}{this}{expressions}"
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 7c8982b..a4435f6 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -38,9 +38,15 @@ class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
+ "ANY_VALUE": lambda args: exp.AnyValue(
+ this=seq_get(args, 0), ignore_nulls=seq_get(args, 1)
+ ),
"DATEDIFF": _parse_datediff,
}
+ FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy()
+ FUNCTION_PARSERS.pop("ANY_VALUE")
+
class Generator(Spark2.Generator):
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
@@ -56,9 +62,13 @@ class Spark(Spark2):
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
),
}
+ TRANSFORMS.pop(exp.AnyValue)
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)
+ def anyvalue_sql(self, expression: exp.AnyValue) -> str:
+ return self.function_fallback_sql(expression)
+
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = self.sql(expression, "unit")
end = self.sql(expression, "this")
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index ceb48f8..4489b6b 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -15,7 +15,7 @@ from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
-def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
+def _create_sql(self: Spark2.Generator, e: exp.Create) -> str:
kind = e.args["kind"]
properties = e.args.get("properties")
@@ -31,17 +31,21 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str:
return create_with_partitions_sql(self, e)
-def _map_sql(self: Hive.Generator, expression: exp.Map) -> str:
- keys = self.sql(expression.args["keys"])
- values = self.sql(expression.args["values"])
- return f"MAP_FROM_ARRAYS({keys}, {values})"
+def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
+ keys = expression.args.get("keys")
+ values = expression.args.get("values")
+
+ if not keys or not values:
+ return "MAP()"
+
+ return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})"
def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
-def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
+def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.DATE_FORMAT:
@@ -49,7 +53,7 @@ def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str:
return f"TO_DATE({this}, {time_format})"
-def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str:
+def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale is None:
@@ -110,6 +114,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
return expression
+def _insert_sql(self: Spark2.Generator, expression: exp.Insert) -> str:
+ if expression.expression.args.get("with"):
+ expression = expression.copy()
+ expression.set("with", expression.expression.args.pop("with"))
+ return self.insert_sql(expression)
+
+
class Spark2(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
@@ -169,10 +180,7 @@ class Spark2(Hive):
class Generator(Hive.Generator):
QUERY_HINTS = True
-
- TYPE_MAPPING = {
- **Hive.Generator.TYPE_MAPPING,
- }
+ NVL2_SUPPORTED = True
PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION,
@@ -197,6 +205,7 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
+ exp.Insert: _insert_sql,
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 90b774e..7bfdf1c 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
concat_to_dpipe_sql,
@@ -18,7 +19,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
-def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
+def _date_add_sql(self: SQLite.Generator, expression: exp.DateAdd) -> str:
modifier = expression.expression
modifier = modifier.name if modifier.is_string else self.sql(modifier)
unit = expression.args.get("unit")
@@ -78,6 +79,7 @@ class SQLite(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
+ NVL2_SUPPORTED = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -103,6 +105,7 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.AnyValue: any_value_to_max_sql,
exp.Concat: concat_to_dpipe_sql,
exp.CountIf: count_if_to_sum,
exp.Create: transforms.preprocess([_transform_create]),
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 2be1a62..163cc13 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -95,6 +95,9 @@ class Teradata(Dialect):
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS,
+ TokenType.DATABASE: lambda self: self.expression(
+ exp.Use, this=self._parse_table(schema=False)
+ ),
TokenType.REPLACE: lambda self: self._parse_create(),
}
@@ -165,6 +168,7 @@ class Teradata(Dialect):
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+ exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}
def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index af0f78d..0c953a1 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -13,3 +13,6 @@ class Trino(Presto):
class Tokenizer(Presto.Tokenizer):
HEX_STRINGS = [("X'", "'")]
+
+ class Parser(Presto.Parser):
+ SUPPORTS_USER_DEFINED_TYPES = False
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 131307f..b26f499 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -7,6 +7,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ any_value_to_max_sql,
max_or_greatest,
min_or_least,
parse_date_delta,
@@ -79,22 +80,23 @@ def _format_time_lambda(
def _parse_format(args: t.List) -> exp.Expression:
- assert len(args) == 2
+ this = seq_get(args, 0)
+ fmt = seq_get(args, 1)
+ culture = seq_get(args, 2)
- fmt = args[1]
- number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)
+ number_fmt = fmt and (fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name))
if number_fmt:
- return exp.NumberToStr(this=args[0], format=fmt)
+ return exp.NumberToStr(this=this, format=fmt, culture=culture)
- return exp.TimeToStr(
- this=args[0],
- format=exp.Literal.string(
+ if fmt:
+ fmt = exp.Literal.string(
format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING)
if len(fmt.name) == 1
else format_time(fmt.name, TSQL.TIME_MAPPING)
- ),
- )
+ )
+
+ return exp.TimeToStr(this=this, format=fmt, culture=culture)
def _parse_eomonth(args: t.List) -> exp.Expression:
@@ -130,13 +132,13 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
def generate_date_delta_with_unit_sql(
- self: generator.Generator, expression: exp.DateAdd | exp.DateDiff
+ self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff
) -> str:
func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
return self.func(func, expression.text("unit"), expression.expression, expression.this)
-def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
+def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
fmt = (
expression.args["format"]
if isinstance(expression, exp.NumberToStr)
@@ -147,10 +149,10 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
)
)
)
- return self.func("FORMAT", expression.this, fmt)
+ return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))
-def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
+def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
this = expression.this
@@ -332,10 +334,12 @@ class TSQL(Dialect):
"SQL_VARIANT": TokenType.VARIANT,
"TOP": TokenType.TOP,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
+ "UPDATE STATISTICS": TokenType.COMMAND,
"VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
"OUTPUT": TokenType.RETURNING,
"SYSTEM_USER": TokenType.CURRENT_USER,
+ "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
}
class Parser(parser.Parser):
@@ -395,7 +399,9 @@ class TSQL(Dialect):
CONCAT_NULL_OUTPUTS_STRING = True
- def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
+ ALTER_TABLE_ADD_COLUMN_KEYWORD = False
+
+ def _parse_projections(self) -> t.List[exp.Expression]:
"""
T-SQL supports the syntax alias = expression in the SELECT's projection list,
so we transform all parsed Selects to convert their EQ projections into Aliases.
@@ -458,43 +464,6 @@ class TSQL(Dialect):
return self._parse_as_command(self._prev)
- def _parse_system_time(self) -> t.Optional[exp.Expression]:
- if not self._match_text_seq("FOR", "SYSTEM_TIME"):
- return None
-
- if self._match_text_seq("AS", "OF"):
- system_time = self.expression(
- exp.SystemTime, this=self._parse_bitwise(), kind="AS OF"
- )
- elif self._match_set((TokenType.FROM, TokenType.BETWEEN)):
- kind = self._prev.text
- this = self._parse_bitwise()
- self._match_texts(("TO", "AND"))
- expression = self._parse_bitwise()
- system_time = self.expression(
- exp.SystemTime, this=this, expression=expression, kind=kind
- )
- elif self._match_text_seq("CONTAINED", "IN"):
- args = self._parse_wrapped_csv(self._parse_bitwise)
- system_time = self.expression(
- exp.SystemTime,
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- kind="CONTAINED IN",
- )
- elif self._match(TokenType.ALL):
- system_time = self.expression(exp.SystemTime, kind="ALL")
- else:
- system_time = None
- self.raise_error("Unable to parse FOR SYSTEM_TIME clause")
-
- return system_time
-
- def _parse_table_parts(self, schema: bool = False) -> exp.Table:
- table = super()._parse_table_parts(schema=schema)
- table.set("system_time", self._parse_system_time())
- return table
-
def _parse_returns(self) -> exp.ReturnsProperty:
table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
returns = super()._parse_returns()
@@ -589,14 +558,36 @@ class TSQL(Dialect):
return create
+ def _parse_if(self) -> t.Optional[exp.Expression]:
+ index = self._index
+
+ if self._match_text_seq("OBJECT_ID"):
+ self._parse_wrapped_csv(self._parse_string)
+ if self._match_text_seq("IS", "NOT", "NULL") and self._match(TokenType.DROP):
+ return self._parse_drop(exists=True)
+ self._retreat(index)
+
+ return super()._parse_if()
+
+ def _parse_unique(self) -> exp.UniqueColumnConstraint:
+ return self.expression(
+ exp.UniqueColumnConstraint,
+ this=None
+ if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"}
+ else self._parse_schema(self._parse_id_var(any_token=False)),
+ )
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
LIMIT_IS_TOP = True
QUERY_HINTS = False
RETURNING_END = False
+ NVL2_SUPPORTED = False
+ ALTER_TABLE_ADD_COLUMN_KEYWORD = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
+ exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
exp.DataType.Type.INT: "INTEGER",
@@ -607,6 +598,8 @@ class TSQL(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.AnyValue: any_value_to_max_sql,
+ exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
@@ -651,25 +644,44 @@ class TSQL(Dialect):
return sql
- def offset_sql(self, expression: exp.Offset) -> str:
- return f"{super().offset_sql(expression)} ROWS"
+ def create_sql(self, expression: exp.Create) -> str:
+ expression = expression.copy()
+ kind = self.sql(expression, "kind").upper()
+ exists = expression.args.pop("exists", None)
+ sql = super().create_sql(expression)
+
+ if exists:
+ table = expression.find(exp.Table)
+ identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else ""))
+ if kind == "SCHEMA":
+ sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC('{sql}')"""
+ elif kind == "TABLE":
+ sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = {identifier}) EXEC('{sql}')"""
+ elif kind == "INDEX":
+ index = self.sql(exp.Literal.string(expression.this.text("this")))
+ sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC('{sql}')"""
+ elif expression.args.get("replace"):
+ sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1)
- def systemtime_sql(self, expression: exp.SystemTime) -> str:
- kind = expression.args["kind"]
- if kind == "ALL":
- return "FOR SYSTEM_TIME ALL"
+ return sql
- start = self.sql(expression, "this")
- if kind == "AS OF":
- return f"FOR SYSTEM_TIME AS OF {start}"
+ def offset_sql(self, expression: exp.Offset) -> str:
+ return f"{super().offset_sql(expression)} ROWS"
- end = self.sql(expression, "expression")
- if kind == "FROM":
- return f"FOR SYSTEM_TIME FROM {start} TO {end}"
- if kind == "BETWEEN":
- return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}"
+ def version_sql(self, expression: exp.Version) -> str:
+ name = "SYSTEM_TIME" if expression.name == "TIMESTAMP" else expression.name
+ this = f"FOR {name}"
+ expr = expression.expression
+ kind = expression.text("kind")
+ if kind in ("FROM", "BETWEEN"):
+ args = expr.expressions
+ sep = "TO" if kind == "FROM" else "AND"
+ expr_sql = f"{self.sql(seq_get(args, 0))} {sep} {self.sql(seq_get(args, 1))}"
+ else:
+ expr_sql = self.sql(expr)
- return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})"
+ expr_sql = f" {expr_sql}" if expr_sql else ""
+ return f"{this} {kind}{expr_sql}"
def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str:
table = expression.args.get("table")
@@ -713,3 +725,16 @@ class TSQL(Dialect):
identifier = f"#{identifier}"
return identifier
+
+ def constraint_sql(self, expression: exp.Constraint) -> str:
+ this = self.sql(expression, "this")
+ expressions = self.expressions(expression, flat=True, sep=" ")
+ return f"CONSTRAINT {this} {expressions}"
+
+ # https://learn.microsoft.com/en-us/answers/questions/448821/create-table-in-sql-server
+ def generatedasidentitycolumnconstraint_sql(
+ self, expression: exp.GeneratedAsIdentityColumnConstraint
+ ) -> str:
+ start = self.sql(expression, "start") or "1"
+ increment = self.sql(expression, "increment") or "1"
+ return f"IDENTITY({start}, {increment})"