Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dataframe/sql/functions.py      |   2
-rw-r--r--  sqlglot/dialects/bigquery.py            | 145
-rw-r--r--  sqlglot/dialects/clickhouse.py          |  37
-rw-r--r--  sqlglot/dialects/databricks.py          |  29
-rw-r--r--  sqlglot/dialects/dialect.py             |  88
-rw-r--r--  sqlglot/dialects/doris.py               |  15
-rw-r--r--  sqlglot/dialects/drill.py               |  24
-rw-r--r--  sqlglot/dialects/duckdb.py              |  84
-rw-r--r--  sqlglot/dialects/hive.py                |  61
-rw-r--r--  sqlglot/dialects/mysql.py               |  31
-rw-r--r--  sqlglot/dialects/oracle.py              |  62
-rw-r--r--  sqlglot/dialects/postgres.py            |  55
-rw-r--r--  sqlglot/dialects/presto.py              |  80
-rw-r--r--  sqlglot/dialects/redshift.py            |  18
-rw-r--r--  sqlglot/dialects/snowflake.py           | 196
-rw-r--r--  sqlglot/dialects/spark.py               |   8
-rw-r--r--  sqlglot/dialects/spark2.py              |  64
-rw-r--r--  sqlglot/dialects/starrocks.py           |  10
-rw-r--r--  sqlglot/dialects/tableau.py             |   4
-rw-r--r--  sqlglot/dialects/teradata.py            |   2
-rw-r--r--  sqlglot/dialects/tsql.py                |  63
-rw-r--r--  sqlglot/expressions.py                  |  74
-rw-r--r--  sqlglot/generator.py                    | 134
-rw-r--r--  sqlglot/optimizer/normalize.py          |   4
-rw-r--r--  sqlglot/optimizer/qualify.py            |  14
-rw-r--r--  sqlglot/optimizer/qualify_columns.py    |  20
-rw-r--r--  sqlglot/optimizer/scope.py              |   2
-rw-r--r--  sqlglot/optimizer/simplify.py           |   2
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py  |   2
-rw-r--r--  sqlglot/parser.py                       |  49
-rw-r--r--  sqlglot/schema.py                       |   6
-rw-r--r--  sqlglot/tokens.py                       |   3
32 files changed, 705 insertions, 683 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 133979a..308b639 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -1030,7 +1030,7 @@ def posexplode_outer(col: ColumnOrName) -> Column:
def get_json_object(col: ColumnOrName, path: str) -> Column:
- return Column.invoke_expression_over_column(col, expression.JSONExtract, path=lit(path))
+ return Column.invoke_expression_over_column(col, expression.JSONExtract, expression=lit(path))
def json_tuple(col: ColumnOrName, *fields: str) -> Column:
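The hunk above is a one-line fix: exp.JSONExtract stores the JSON path under its `expression` arg, so passing it as `path=` attached it to the wrong key. A minimal sketch of the corrected construction (column name illustrative, output approximate):

    from sqlglot import exp

    # Attach the path literal as `expression`, as the fixed get_json_object does.
    node = exp.JSONExtract(
        this=exp.column("data"),
        expression=exp.Literal.string("$.name"),
    )
    print(node.sql(dialect="hive"))  # roughly: GET_JSON_OBJECT(data, '$.name')
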
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index c0191b2..f867617 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -12,13 +12,14 @@ from sqlglot.dialects.dialect import (
binary_from_function,
date_add_interval_sql,
datestrtodate_sql,
- format_time_lambda,
+ build_formatted_time,
+ filter_array_using_unnest,
if_sql,
inline_array_sql,
max_or_greatest,
min_or_least,
no_ilike_sql,
- parse_date_delta_with_interval,
+ build_date_delta_with_interval,
regexp_replace_sql,
rename_func,
timestrtotime_sql,
@@ -37,56 +38,33 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
if not expression.find_ancestor(exp.From, exp.Join):
return self.values_sql(expression)
+ structs = []
alias = expression.args.get("alias")
+ for tup in expression.find_all(exp.Tuple):
+ field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions)))
+ expressions = [exp.alias_(fld, name) for fld, name in zip(tup.expressions, field_aliases)]
+ structs.append(exp.Struct(expressions=expressions))
- return self.unnest_sql(
- exp.Unnest(
- expressions=[
- exp.array(
- *(
- exp.Struct(
- expressions=[
- exp.alias_(value, column_name)
- for value, column_name in zip(
- t.expressions,
- (
- alias.columns
- if alias and alias.columns
- else (f"_c{i}" for i in range(len(t.expressions)))
- ),
- )
- ]
- )
- for t in expression.find_all(exp.Tuple)
- ),
- copy=False,
- )
- ]
- )
- )
+ return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)]))
def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
this = expression.this
if isinstance(this, exp.Schema):
- this = f"{this.this} <{self.expressions(this)}>"
+ this = f"{self.sql(this, 'this')} <{self.expressions(this)}>"
else:
this = self.sql(this)
return f"RETURNS {this}"
def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
- kind = expression.args["kind"]
returns = expression.find(exp.ReturnsProperty)
-
- if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
+ if expression.kind == "FUNCTION" and returns and returns.args.get("is_table"):
expression.set("kind", "TABLE FUNCTION")
if isinstance(expression.expression, (exp.Subquery, exp.Literal)):
expression.set("expression", expression.expression.this)
- return self.create_sql(expression)
-
return self.create_sql(expression)
@@ -132,11 +110,10 @@ def _alias_ordered_group(expression: exp.Expression) -> exp.Expression:
if isinstance(select, exp.Alias)
}
- for e in group.expressions:
- alias = aliases.get(e)
-
+ for grouped in group.expressions:
+ alias = aliases.get(grouped)
if alias:
- e.replace(exp.column(alias))
+ grouped.replace(exp.column(alias))
return expression
@@ -168,24 +145,24 @@ def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
return expression
-def _parse_parse_timestamp(args: t.List) -> exp.StrToTime:
- this = format_time_lambda(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
+def _build_parse_timestamp(args: t.List) -> exp.StrToTime:
+ this = build_formatted_time(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
this.set("zone", seq_get(args, 2))
return this
-def _parse_timestamp(args: t.List) -> exp.Timestamp:
+def _build_timestamp(args: t.List) -> exp.Timestamp:
timestamp = exp.Timestamp.from_arg_list(args)
timestamp.set("with_tz", True)
return timestamp
-def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts:
+def _build_date(args: t.List) -> exp.Date | exp.DateFromParts:
expr_type = exp.DateFromParts if len(args) == 3 else exp.Date
return expr_type.from_arg_list(args)
-def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
+def _build_to_hex(args: t.List) -> exp.Hex | exp.MD5:
# TO_HEX(MD5(..)) is common in BigQuery, so it's parsed into MD5 to simplify its transpilation
arg = seq_get(args, 0)
return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.Hex(this=arg)
@@ -214,18 +191,20 @@ def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) ->
def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
- timestamp = self.sql(expression, "this")
+ timestamp = expression.this
+
if scale in (None, exp.UnixToTime.SECONDS):
- return f"TIMESTAMP_SECONDS({timestamp})"
+ return self.func("TIMESTAMP_SECONDS", timestamp)
if scale == exp.UnixToTime.MILLIS:
- return f"TIMESTAMP_MILLIS({timestamp})"
+ return self.func("TIMESTAMP_MILLIS", timestamp)
if scale == exp.UnixToTime.MICROS:
- return f"TIMESTAMP_MICROS({timestamp})"
+ return self.func("TIMESTAMP_MICROS", timestamp)
- return f"TIMESTAMP_SECONDS(CAST({timestamp} / POW(10, {scale}) AS INT64))"
+ unix_seconds = exp.cast(exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), "int64")
+ return self.func("TIMESTAMP_SECONDS", unix_seconds)
-def _parse_time(args: t.List) -> exp.Func:
+def _build_time(args: t.List) -> exp.Func:
if len(args) == 1:
return exp.TsOrDsToTime(this=args[0])
if len(args) == 3:
@@ -323,6 +302,7 @@ class BigQuery(Dialect):
"BYTES": TokenType.BINARY,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"DECLARE": TokenType.COMMAND,
+ "EXCEPTION": TokenType.COMMAND,
"FLOAT64": TokenType.DOUBLE,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
"MODEL": TokenType.MODEL,
@@ -340,15 +320,15 @@ class BigQuery(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
- "DATE": _parse_date,
- "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
- "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
+ "DATE": _build_date,
+ "DATE_ADD": build_date_delta_with_interval(exp.DateAdd),
+ "DATE_SUB": build_date_delta_with_interval(exp.DateSub),
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
),
- "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
- "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
+ "DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd),
+ "DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub),
"DIV": binary_from_function(exp.IntDiv),
"FORMAT_DATE": lambda args: exp.TimeToStr(
this=exp.TsOrDsToDate(this=seq_get(args, 1)), format=seq_get(args, 0)
@@ -358,11 +338,11 @@ class BigQuery(Dialect):
this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$")
),
"MD5": exp.MD5Digest.from_arg_list,
- "TO_HEX": _parse_to_hex,
- "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
+ "TO_HEX": _build_to_hex,
+ "PARSE_DATE": lambda args: build_formatted_time(exp.StrToDate, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
),
- "PARSE_TIMESTAMP": _parse_parse_timestamp,
+ "PARSE_TIMESTAMP": _build_parse_timestamp,
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0),
@@ -378,12 +358,12 @@ class BigQuery(Dialect):
this=seq_get(args, 0),
expression=seq_get(args, 1) or exp.Literal.string(","),
),
- "TIME": _parse_time,
- "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
- "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
- "TIMESTAMP": _parse_timestamp,
- "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
- "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
+ "TIME": _build_time,
+ "TIME_ADD": build_date_delta_with_interval(exp.TimeAdd),
+ "TIME_SUB": build_date_delta_with_interval(exp.TimeSub),
+ "TIMESTAMP": _build_timestamp,
+ "TIMESTAMP_ADD": build_date_delta_with_interval(exp.TimestampAdd),
+ "TIMESTAMP_SUB": build_date_delta_with_interval(exp.TimestampSub),
"TIMESTAMP_MICROS": lambda args: exp.UnixToTime(
this=seq_get(args, 0), scale=exp.UnixToTime.MICROS
),
@@ -424,7 +404,7 @@ class BigQuery(Dialect):
}
RANGE_PARSERS = parser.Parser.RANGE_PARSERS.copy()
- RANGE_PARSERS.pop(TokenType.OVERLAPS, None)
+ RANGE_PARSERS.pop(TokenType.OVERLAPS)
NULL_TOKENS = {TokenType.NULL, TokenType.UNKNOWN}
@@ -551,6 +531,7 @@ class BigQuery(Dialect):
NULL_ORDERING_SUPPORTED = False
IGNORE_NULLS_IN_FUNC = True
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
+ CAN_IMPLEMENT_ARRAY_ANY = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -558,6 +539,7 @@ class BigQuery(Dialect):
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.ArrayContains: _array_contains_sql,
+ exp.ArrayFilter: filter_array_using_unnest,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.CollateProperty: lambda self, e: (
@@ -565,12 +547,14 @@ class BigQuery(Dialect):
if e.args.get("default")
else f"COLLATE {self.sql(e, 'this')}"
),
+ exp.Commit: lambda *_: "COMMIT TRANSACTION",
exp.CountIf: rename_func("COUNTIF"),
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
- exp.DateDiff: lambda self,
- e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
+ exp.DateDiff: lambda self, e: self.func(
+ "DATE_DIFF", e.this, e.expression, e.unit or "DAY"
+ ),
exp.DateFromParts: rename_func("DATE"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: date_add_interval_sql("DATE", "SUB"),
@@ -602,6 +586,7 @@ class BigQuery(Dialect):
exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.ReturnsProperty: _returnsproperty_sql,
+ exp.Rollback: lambda *_: "ROLLBACK TRANSACTION",
exp.Select: transforms.preprocess(
[
transforms.explode_to_unnest(),
@@ -617,8 +602,7 @@ class BigQuery(Dialect):
exp.StabilityProperty: lambda self, e: (
"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
),
- exp.StrToDate: lambda self,
- e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
+ exp.StrToDate: lambda self, e: self.func("PARSE_DATE", self.format_time(e), e.this),
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
),
@@ -629,6 +613,7 @@ class BigQuery(Dialect):
exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"),
exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
+ exp.Transaction: lambda *_: "BEGIN TRANSACTION",
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
@@ -778,12 +763,8 @@ class BigQuery(Dialect):
}
def timetostr_sql(self, expression: exp.TimeToStr) -> str:
- if isinstance(expression.this, exp.TsOrDsToDate):
- this: exp.Expression = expression.this
- else:
- this = expression
-
- return f"FORMAT_DATE({self.format_time(expression)}, {self.sql(this, 'this')})"
+ this = expression.this if isinstance(expression.this, exp.TsOrDsToDate) else expression
+ return self.func("FORMAT_DATE", self.format_time(expression), this.this)
def struct_sql(self, expression: exp.Struct) -> str:
args = []
@@ -820,11 +801,6 @@ class BigQuery(Dialect):
def trycast_sql(self, expression: exp.TryCast) -> str:
return self.cast_sql(expression, safe_prefix="SAFE_")
- def cte_sql(self, expression: exp.CTE) -> str:
- if expression.alias_column_names:
- self.unsupported("Column names in CTE definition are not supported.")
- return super().cte_sql(expression)
-
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
if isinstance(first_arg, exp.Subqueryable):
@@ -862,25 +838,16 @@ class BigQuery(Dialect):
return f"{this}[{expressions_sql}]"
- def transaction_sql(self, *_) -> str:
- return "BEGIN TRANSACTION"
-
- def commit_sql(self, *_) -> str:
- return "COMMIT TRANSACTION"
-
- def rollback_sql(self, *_) -> str:
- return "ROLLBACK TRANSACTION"
-
def in_unnest_op(self, expression: exp.Unnest) -> str:
return self.sql(expression)
def except_op(self, expression: exp.Except) -> str:
- if not expression.args.get("distinct", False):
+ if not expression.args.get("distinct"):
self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
def intersect_op(self, expression: exp.Intersect) -> str:
- if not expression.args.get("distinct", False):
+ if not expression.args.get("distinct"):
self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index d7be64c..05d6a03 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -11,13 +11,12 @@ from sqlglot.dialects.dialect import (
json_extract_segments,
json_path_key_only_name,
no_pivot_sql,
- parse_json_extract_path,
+ build_json_extract_path,
rename_func,
var_map_sql,
)
from sqlglot.errors import ParseError
from sqlglot.helper import is_int, seq_get
-from sqlglot.parser import parse_var_map
from sqlglot.tokens import Token, TokenType
@@ -26,9 +25,9 @@ def _lower_func(sql: str) -> str:
return sql[:index].lower() + sql[index:]
-def _quantile_sql(self: ClickHouse.Generator, e: exp.Quantile) -> str:
- quantile = e.args["quantile"]
- args = f"({self.sql(e, 'this')})"
+def _quantile_sql(self: ClickHouse.Generator, expression: exp.Quantile) -> str:
+ quantile = expression.args["quantile"]
+ args = f"({self.sql(expression, 'this')})"
if isinstance(quantile, exp.Array):
func = self.func("quantiles", *quantile)
@@ -38,7 +37,7 @@ def _quantile_sql(self: ClickHouse.Generator, e: exp.Quantile) -> str:
return func + args
-def _parse_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc:
+def _build_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc:
if len(args) == 1:
return exp.CountIf(this=seq_get(args, 0))
@@ -111,7 +110,7 @@ class ClickHouse(Dialect):
**parser.Parser.FUNCTIONS,
"ANY": exp.AnyValue.from_arg_list,
"ARRAYSUM": exp.ArraySum.from_arg_list,
- "COUNTIF": _parse_count_if,
+ "COUNTIF": _build_count_if,
"DATE_ADD": lambda args: exp.DateAdd(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
@@ -124,10 +123,10 @@ class ClickHouse(Dialect):
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
- "JSONEXTRACTSTRING": parse_json_extract_path(
+ "JSONEXTRACTSTRING": build_json_extract_path(
exp.JSONExtractScalar, zero_based_indexing=False
),
- "MAP": parse_var_map,
+ "MAP": parser.build_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
"RANDCANONICAL": exp.Rand.from_arg_list,
"UNIQ": exp.ApproxDistinct.from_arg_list,
@@ -417,9 +416,9 @@ class ClickHouse(Dialect):
self, skip_join_token: bool = False, parse_bracket: bool = False
) -> t.Optional[exp.Join]:
join = super()._parse_join(skip_join_token=skip_join_token, parse_bracket=True)
-
if join:
join.set("global", join.args.pop("method", None))
+
return join
def _parse_function(
@@ -516,6 +515,7 @@ class ClickHouse(Dialect):
TABLESAMPLE_SIZE_IS_ROWS = False
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
+ CAN_IMPLEMENT_ARRAY_ANY = True
STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
@@ -576,6 +576,8 @@ class ClickHouse(Dialect):
**generator.Generator.TRANSFORMS,
exp.AnyValue: rename_func("any"),
exp.ApproxDistinct: rename_func("uniq"),
+ exp.ArrayFilter: lambda self, e: self.func("arrayFilter", e.expression, e.this),
+ exp.ArraySize: rename_func("LENGTH"),
exp.ArraySum: rename_func("arraySum"),
exp.ArgMax: arg_max_or_min_no_count("argMax"),
exp.ArgMin: arg_max_or_min_no_count("argMin"),
@@ -597,12 +599,13 @@ class ClickHouse(Dialect):
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
- exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
+ exp.RegexpLike: lambda self, e: self.func("match", e.this, e.expression),
exp.Rand: rename_func("randCanonical"),
exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.StartsWith: rename_func("startsWith"),
- exp.StrPosition: lambda self,
- e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
+ exp.StrPosition: lambda self, e: self.func(
+ "position", e.this, e.args.get("substr"), e.args.get("position")
+ ),
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
}
@@ -652,6 +655,7 @@ class ClickHouse(Dialect):
this = expression.left
else:
return default(expression)
+
return prefix + self.func("has", arr.this.unnest(), this)
def eq_sql(self, expression: exp.EQ) -> str:
@@ -663,7 +667,7 @@ class ClickHouse(Dialect):
def regexpilike_sql(self, expression: exp.RegexpILike) -> str:
# Manually add a flag to make the search case-insensitive
regex = self.func("CONCAT", "'(?i)'", expression.expression)
- return f"match({self.format_args(expression.this, regex)})"
+ return self.func("match", expression.this, regex)
def datatype_sql(self, expression: exp.DataType) -> str:
# String is the standard ClickHouse type, every other variant is just an alias.
@@ -717,8 +721,9 @@ class ClickHouse(Dialect):
return f"ON CLUSTER {self.sql(expression, 'this')}"
def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
- kind = self.sql(expression, "kind").upper()
- if kind in self.ON_CLUSTER_TARGETS and locations.get(exp.Properties.Location.POST_NAME):
+ if expression.kind in self.ON_CLUSTER_TARGETS and locations.get(
+ exp.Properties.Location.POST_NAME
+ ):
this_name = self.sql(expression.this, "this")
this_properties = " ".join(
[self.sql(prop) for prop in locations[exp.Properties.Location.POST_NAME]]
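The new ClickHouse ArrayFilter transform maps to arrayFilter, whose signature takes the lambda before the array, hence the swapped `e.expression, e.this`. A sketch (identifiers illustrative, output approximate):

    import sqlglot

    # sqlglot stores the array in `this` and the lambda in `expression`;
    # ClickHouse's arrayFilter(lambda, array) wants them the other way around.
    print(sqlglot.transpile("SELECT FILTER(arr, x -> x > 1)", read="duckdb", write="clickhouse")[0])
    # roughly: SELECT arrayFilter(x -> x > 1, arr)
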
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 20907db..96eff18 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -3,13 +3,19 @@ from __future__ import annotations
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
date_delta_sql,
- parse_date_delta,
+ build_date_delta,
timestamptrunc_sql,
)
from sqlglot.dialects.spark import Spark
from sqlglot.tokens import TokenType
+def _timestamp_diff(
+ self: Databricks.Generator, expression: exp.DatetimeDiff | exp.TimestampDiff
+) -> str:
+ return self.func("TIMESTAMPDIFF", expression.unit, expression.expression, expression.this)
+
+
class Databricks(Spark):
SAFE_DIVISION = False
@@ -19,10 +25,10 @@ class Databricks(Spark):
FUNCTIONS = {
**Spark.Parser.FUNCTIONS,
- "DATEADD": parse_date_delta(exp.DateAdd),
- "DATE_ADD": parse_date_delta(exp.DateAdd),
- "DATEDIFF": parse_date_delta(exp.DateDiff),
- "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff),
+ "DATEADD": build_date_delta(exp.DateAdd),
+ "DATE_ADD": build_date_delta(exp.DateAdd),
+ "DATEDIFF": build_date_delta(exp.DateDiff),
+ "TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff),
}
FACTOR = {
@@ -38,20 +44,16 @@ class Databricks(Spark):
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DatetimeAdd: lambda self, e: self.func(
- "TIMESTAMPADD", e.text("unit"), e.expression, e.this
+ "TIMESTAMPADD", e.unit, e.expression, e.this
),
exp.DatetimeSub: lambda self, e: self.func(
"TIMESTAMPADD",
- e.text("unit"),
+ e.unit,
exp.Mul(this=e.expression, expression=exp.Literal.number(-1)),
e.this,
),
- exp.DatetimeDiff: lambda self, e: self.func(
- "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
- ),
- exp.TimestampDiff: lambda self, e: self.func(
- "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
- ),
+ exp.DatetimeDiff: _timestamp_diff,
+ exp.TimestampDiff: _timestamp_diff,
exp.DatetimeTrunc: timestamptrunc_sql,
exp.JSONExtract: lambda self, e: self.binary(e, ":"),
exp.Select: transforms.preprocess(
@@ -75,6 +77,7 @@ class Databricks(Spark):
):
# only BIGINT generated identity constraints are supported
expression.set("kind", exp.DataType.build("bigint"))
+
return super().columndef_sql(expression, sep)
def generatedasidentitycolumnconstraint_sql(
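The two identical TIMESTAMPDIFF lambdas collapse into _timestamp_diff, which also passes expression.unit as a node rather than its text. A round-trip sketch (identifiers illustrative):

    import sqlglot

    # Databricks' TIMESTAMPDIFF takes (unit, start, end), so the helper emits
    # the unit first, then `expression`, then `this`.
    sql = "SELECT TIMESTAMPDIFF(DAY, start_ts, end_ts)"
    print(sqlglot.transpile(sql, read="databricks", write="databricks")[0])
    # roughly: SELECT TIMESTAMPDIFF(DAY, start_ts, end_ts)
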
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 0440a99..b0a78d2 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -17,12 +17,12 @@ from sqlglot.trie import new_trie
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
+JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
+
if t.TYPE_CHECKING:
from sqlglot._typing import B, E, F
- JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
-
logger = logging.getLogger("sqlglot")
@@ -148,47 +148,53 @@ class _Dialect(type):
class Dialect(metaclass=_Dialect):
INDEX_OFFSET = 0
- """Determines the base index offset for arrays."""
+ """The base index offset for arrays."""
WEEK_OFFSET = 0
- """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
+ """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""
UNNEST_COLUMN_ONLY = False
- """Determines whether or not `UNNEST` table aliases are treated as column aliases."""
+ """Whether `UNNEST` table aliases are treated as column aliases."""
ALIAS_POST_TABLESAMPLE = False
- """Determines whether or not the table alias comes after tablesample."""
+ """Whether the table alias comes after tablesample."""
TABLESAMPLE_SIZE_IS_PERCENT = False
- """Determines whether or not a size in the table sample clause represents percentage."""
+ """Whether a size in the table sample clause represents percentage."""
NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
"""Specifies the strategy according to which identifiers should be normalized."""
IDENTIFIERS_CAN_START_WITH_DIGIT = False
- """Determines whether or not an unquoted identifier can start with a digit."""
+ """Whether an unquoted identifier can start with a digit."""
DPIPE_IS_STRING_CONCAT = True
- """Determines whether or not the DPIPE token (`||`) is a string concatenation operator."""
+ """Whether the DPIPE token (`||`) is a string concatenation operator."""
STRICT_STRING_CONCAT = False
- """Determines whether or not `CONCAT`'s arguments must be strings."""
+ """Whether `CONCAT`'s arguments must be strings."""
SUPPORTS_USER_DEFINED_TYPES = True
- """Determines whether or not user-defined data types are supported."""
+ """Whether user-defined data types are supported."""
SUPPORTS_SEMI_ANTI_JOIN = True
- """Determines whether or not `SEMI` or `ANTI` joins are supported."""
+ """Whether `SEMI` or `ANTI` joins are supported."""
NORMALIZE_FUNCTIONS: bool | str = "upper"
- """Determines how function names are going to be normalized."""
+ """
+ Determines how function names are going to be normalized.
+ Possible values:
+ "upper" or True: Convert names to uppercase.
+ "lower": Convert names to lowercase.
+ False: Disables function name normalization.
+ """
LOG_BASE_FIRST = True
- """Determines whether the base comes first in the `LOG` function."""
+ """Whether the base comes first in the `LOG` function."""
NULL_ORDERING = "nulls_are_small"
"""
- Indicates the default `NULL` ordering method to use if not explicitly set.
+ Default `NULL` ordering method to use if not explicitly set.
Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
"""
@@ -200,7 +206,7 @@ class Dialect(metaclass=_Dialect):
"""
SAFE_DIVISION = False
- """Determines whether division by zero throws an error (`False`) or returns NULL (`True`)."""
+ """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""
CONCAT_COALESCE = False
"""A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""
@@ -210,7 +216,7 @@ class Dialect(metaclass=_Dialect):
TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"
TIME_MAPPING: t.Dict[str, str] = {}
- """Associates this dialect's time formats with their equivalent Python `strftime` format."""
+ """Associates this dialect's time formats with their equivalent Python `strftime` formats."""
# https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
# https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
@@ -418,7 +424,7 @@ class Dialect(metaclass=_Dialect):
`"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
- Whether or not the given text can be identified.
+ Whether the given text can be identified.
"""
if identify is True or identify == "always":
return True
@@ -614,7 +620,7 @@ def var_map_sql(
return self.func(map_func_name, *args)
-def format_time_lambda(
+def build_formatted_time(
exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None
) -> t.Callable[[t.List], E]:
"""Helper used for time expressions.
@@ -628,7 +634,7 @@ def format_time_lambda(
A callable that can be used to return the appropriately formatted time expression.
"""
- def _format_time(args: t.List):
+ def _builder(args: t.List):
return exp_class(
this=seq_get(args, 0),
format=Dialect[dialect].format_time(
@@ -637,7 +643,7 @@ def format_time_lambda(
),
)
- return _format_time
+ return _builder
def time_format(
@@ -654,23 +660,23 @@ def time_format(
return _time_format
-def parse_date_delta(
+def build_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
- def inner_func(args: t.List) -> E:
+ def _builder(args: t.List) -> E:
unit_based = len(args) == 3
this = args[2] if unit_based else seq_get(args, 0)
unit = args[0] if unit_based else exp.Literal.string("DAY")
unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit
return exp_class(this=this, expression=seq_get(args, 1), unit=unit)
- return inner_func
+ return _builder
-def parse_date_delta_with_interval(
+def build_date_delta_with_interval(
expression_class: t.Type[E],
) -> t.Callable[[t.List], t.Optional[E]]:
- def func(args: t.List) -> t.Optional[E]:
+ def _builder(args: t.List) -> t.Optional[E]:
if len(args) < 2:
return None
@@ -687,7 +693,7 @@ def parse_date_delta_with_interval(
this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
)
- return func
+ return _builder
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
@@ -888,7 +894,7 @@ def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
-def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
+def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))
@@ -991,10 +997,10 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
return self.merge_sql(expression)
-def parse_json_extract_path(
+def build_json_extract_path(
expr_type: t.Type[F], zero_based_indexing: bool = True
) -> t.Callable[[t.List], F]:
- def _parse_json_extract_path(args: t.List) -> F:
+ def _builder(args: t.List) -> F:
segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
for arg in args[1:]:
if not isinstance(arg, exp.Literal):
@@ -1014,11 +1020,11 @@ def parse_json_extract_path(
del args[2:]
return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
- return _parse_json_extract_path
+ return _builder
def json_extract_segments(
- name: str, quoted_index: bool = True
+ name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
path = expression.expression
@@ -1036,6 +1042,8 @@ def json_extract_segments(
segments.append(path)
+ if op:
+ return f" {op} ".join([self.sql(expression.this), *segments])
return self.func(name, expression.this, *segments)
return _json_extract_segments
@@ -1046,3 +1054,19 @@ def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str
self.unsupported("Unsupported wildcard in JSONPathKey expression")
return expression.name
+
+
+def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
+ cond = expression.expression
+ if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
+ alias = cond.expressions[0]
+ cond = cond.this
+ elif isinstance(cond, exp.Predicate):
+ alias = "_u"
+ else:
+ self.unsupported("Unsupported filter condition")
+ return ""
+
+ unnest = exp.Unnest(expressions=[expression.this])
+ filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
+ return self.sql(exp.Array(expressions=[filtered]))
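The new filter_array_using_unnest helper gives dialects without a higher-order FILTER function an equivalent ARRAY(SELECT ... FROM UNNEST(...) WHERE ...) rewrite; BigQuery and Postgres adopt it in this commit. A sketch (identifiers illustrative, alias naming approximate):

    import sqlglot

    # The lambda parameter becomes the UNNEST alias and the lambda body
    # becomes the WHERE clause of a correlated array subquery.
    print(sqlglot.transpile("SELECT FILTER(arr, x -> x > 1)", read="duckdb", write="bigquery")[0])
    # roughly: SELECT ARRAY(SELECT x FROM UNNEST(arr) AS x WHERE x > 1)
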
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index 7a18e8e..067a045 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -4,7 +4,7 @@ from sqlglot import exp
from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_sql,
- parse_timestamp_trunc,
+ build_timestamp_trunc,
rename_func,
time_format,
)
@@ -20,7 +20,7 @@ class Doris(MySQL):
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
"COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
- "DATE_TRUNC": parse_timestamp_trunc,
+ "DATE_TRUNC": build_timestamp_trunc,
"REGEXP": exp.RegexpLike.from_arg_list,
"TO_DATE": exp.TsOrDsToDate.from_arg_list,
}
@@ -46,7 +46,7 @@ class Doris(MySQL):
exp.ArgMin: rename_func("MIN_BY"),
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
- exp.CurrentTimestamp: lambda *_: "NOW()",
+ exp.CurrentTimestamp: lambda self, _: self.func("NOW"),
exp.DateTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
),
@@ -55,14 +55,11 @@ class Doris(MySQL):
exp.Map: rename_func("ARRAY_MAP"),
exp.RegexpLike: rename_func("REGEXP"),
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
- exp.StrToUnix: lambda self,
- e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
exp.Split: rename_func("SPLIT_BY_STRING"),
exp.TimeStrToDate: rename_func("TO_DATE"),
- exp.ToChar: lambda self,
- e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.TsOrDsAdd: lambda self,
- e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level
+ exp.ToChar: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
+ exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression),
exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimestampTrunc: lambda self, e: self.func(
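In Doris, the f-string lambdas become self.func calls and exp.CurrentTimestamp renders through self.func("NOW") instead of a hard-coded string; the generated SQL is unchanged. A sketch:

    import sqlglot

    # self.func normalizes the function name but still yields a zero-argument call.
    print(sqlglot.transpile("SELECT CURRENT_TIMESTAMP", read="mysql", write="doris")[0])
    # roughly: SELECT NOW()
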
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 409e260..4e699f5 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -6,7 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
- format_time_lambda,
+ build_formatted_time,
no_trycast_sql,
rename_func,
str_position_sql,
@@ -19,9 +19,7 @@ def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.D
def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
this = self.sql(expression, "this")
unit = exp.var(expression.text("unit").upper() or "DAY")
- return (
- f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
- )
+ return self.func(f"DATE_{kind}", this, exp.Interval(this=expression.expression, unit=unit))
return func
@@ -30,8 +28,8 @@ def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.DATE_FORMAT:
- return f"CAST({this} AS DATE)"
- return f"TO_DATE({this}, {time_format})"
+ return self.sql(exp.cast(this, "date"))
+ return self.func("TO_DATE", this, time_format)
class Drill(Dialect):
@@ -86,9 +84,9 @@ class Drill(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
- "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"),
+ "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "drill"),
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
- "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
+ "TO_CHAR": build_formatted_time(exp.TimeToStr, "drill"),
}
LOG_DEFAULTS_TO_LN = True
@@ -135,8 +133,7 @@ class Drill(Dialect):
e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})",
exp.If: lambda self,
e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})",
- exp.ILike: lambda self,
- e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
+ exp.ILike: lambda self, e: self.binary(e, "`ILIKE`"),
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
@@ -146,12 +143,11 @@ class Drill(Dialect):
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
- exp.StrToTime: lambda self,
- e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
+ exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
- exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
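Drill's _str_to_date now emits a plain cast when the incoming format equals the dialect's default DATE_FORMAT, and TO_DATE otherwise. A sketch (identifier illustrative, format mapping approximate):

    import sqlglot

    # '%Y-%m-%d' maps onto Drill's default date format, so TO_DATE collapses
    # to a cast; any other format keeps TO_DATE(this, fmt).
    print(sqlglot.transpile("SELECT STR_TO_DATE(x, '%Y-%m-%d')", read="mysql", write="drill")[0])
    # roughly: SELECT CAST(x AS DATE)
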
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index e61ac4f..925c5ae 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -14,7 +14,7 @@ from sqlglot.dialects.dialect import (
date_trunc_to_time,
datestrtodate_sql,
encode_decode_sql,
- format_time_lambda,
+ build_formatted_time,
inline_array_sql,
no_comment_column_constraint_sql,
no_safe_divide_sql,
@@ -62,26 +62,24 @@ def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str:
def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
- self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
- return f"ARRAY_SORT({self.sql(expression, 'this')})"
+ self.unsupported("DuckDB ARRAY_SORT does not support a comparator")
+ return self.func("ARRAY_SORT", expression.this)
def _sort_array_sql(self: DuckDB.Generator, expression: exp.SortArray) -> str:
- this = self.sql(expression, "this")
- if expression.args.get("asc") == exp.false():
- return f"ARRAY_REVERSE_SORT({this})"
- return f"ARRAY_SORT({this})"
+ name = "ARRAY_REVERSE_SORT" if expression.args.get("asc") == exp.false() else "ARRAY_SORT"
+ return self.func(name, expression.this)
-def _sort_array_reverse(args: t.List) -> exp.Expression:
+def _build_sort_array_desc(args: t.List) -> exp.Expression:
return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
-def _parse_date_diff(args: t.List) -> exp.Expression:
+def _build_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
-def _parse_make_timestamp(args: t.List) -> exp.Expression:
+def _build_make_timestamp(args: t.List) -> exp.Expression:
if len(args) == 1:
return exp.UnixToTime(this=seq_get(args, 0), scale=exp.UnixToTime.MICROS)
@@ -103,10 +101,7 @@ def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str:
value = expr.this
else:
key = expr.name or expr.this.name
- if isinstance(expr, exp.Bracket):
- value = expr.expressions[0]
- else:
- value = expr.expression
+ value = expr.expression
args.append(f"{self.sql(exp.Literal.string(key))}: {self.sql(value)}")
@@ -131,15 +126,16 @@ def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str:
def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
- timestamp = self.sql(expression, "this")
+ timestamp = expression.this
+
if scale in (None, exp.UnixToTime.SECONDS):
- return f"TO_TIMESTAMP({timestamp})"
+ return self.func("TO_TIMESTAMP", timestamp)
if scale == exp.UnixToTime.MILLIS:
- return f"EPOCH_MS({timestamp})"
+ return self.func("EPOCH_MS", timestamp)
if scale == exp.UnixToTime.MICROS:
- return f"MAKE_TIMESTAMP({timestamp})"
+ return self.func("MAKE_TIMESTAMP", timestamp)
- return f"TO_TIMESTAMP({timestamp} / POW(10, {scale}))"
+ return self.func("TO_TIMESTAMP", exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)))
def _rename_unless_within_group(
@@ -152,7 +148,7 @@ def _rename_unless_within_group(
)
-def _parse_struct_pack(args: t.List) -> exp.Struct:
+def _build_struct_pack(args: t.List) -> exp.Struct:
args_with_columns_as_identifiers = [
exp.PropertyEQ(this=arg.this.this, expression=arg.expression) for arg in args
]
@@ -220,11 +216,10 @@ class DuckDB(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAY_HAS": exp.ArrayContains.from_arg_list,
- "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
- "ARRAY_REVERSE_SORT": _sort_array_reverse,
- "DATEDIFF": _parse_date_diff,
- "DATE_DIFF": _parse_date_diff,
+ "ARRAY_REVERSE_SORT": _build_sort_array_desc,
+ "DATEDIFF": _build_date_diff,
+ "DATE_DIFF": _build_date_diff,
"DATE_TRUNC": date_trunc_to_time,
"DATETRUNC": date_trunc_to_time,
"DECODE": lambda args: exp.Decode(
@@ -238,14 +233,14 @@ class DuckDB(Dialect):
this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
),
"JSON": exp.ParseJSON.from_arg_list,
- "JSON_EXTRACT_PATH": parser.parse_extract_json_with_path(exp.JSONExtract),
- "JSON_EXTRACT_STRING": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
+ "JSON_EXTRACT_PATH": parser.build_extract_json_with_path(exp.JSONExtract),
+ "JSON_EXTRACT_STRING": parser.build_extract_json_with_path(exp.JSONExtractScalar),
"LIST_HAS": exp.ArrayContains.from_arg_list,
- "LIST_REVERSE_SORT": _sort_array_reverse,
+ "LIST_REVERSE_SORT": _build_sort_array_desc,
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
"MAKE_TIME": exp.TimeFromParts.from_arg_list,
- "MAKE_TIMESTAMP": _parse_make_timestamp,
+ "MAKE_TIMESTAMP": _build_make_timestamp,
"MEDIAN": lambda args: exp.PercentileCont(
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
),
@@ -261,12 +256,12 @@ class DuckDB(Dialect):
replacement=seq_get(args, 2),
modifiers=seq_get(args, 3),
),
- "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
+ "STRFTIME": build_formatted_time(exp.TimeToStr, "duckdb"),
"STRING_SPLIT": exp.Split.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
- "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
- "STRUCT_PACK": _parse_struct_pack,
+ "STRPTIME": build_formatted_time(exp.StrToTime, "duckdb"),
+ "STRUCT_PACK": _build_struct_pack,
"STR_SPLIT": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
@@ -275,7 +270,7 @@ class DuckDB(Dialect):
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
- FUNCTION_PARSERS.pop("DECODE", None)
+ FUNCTION_PARSERS.pop("DECODE")
TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.SEMI,
@@ -334,6 +329,7 @@ class DuckDB(Dialect):
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False
MULTI_ARG_DISTINCT = False
+ CAN_IMPLEMENT_ARRAY_ANY = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -343,6 +339,7 @@ class DuckDB(Dialect):
if e.expressions and e.expressions[0].find(exp.Select)
else inline_array_sql(self, e)
),
+ exp.ArrayFilter: rename_func("LIST_FILTER"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
exp.ArgMin: arg_max_or_min_no_count("ARG_MIN"),
@@ -350,9 +347,9 @@ class DuckDB(Dialect):
exp.ArraySum: rename_func("LIST_SUM"),
exp.BitwiseXor: rename_func("XOR"),
exp.CommentColumnConstraint: no_comment_column_constraint_sql,
- exp.CurrentDate: lambda self, e: "CURRENT_DATE",
- exp.CurrentTime: lambda self, e: "CURRENT_TIME",
- exp.CurrentTimestamp: lambda self, e: "CURRENT_TIMESTAMP",
+ exp.CurrentDate: lambda *_: "CURRENT_DATE",
+ exp.CurrentTime: lambda *_: "CURRENT_TIME",
+ exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
@@ -409,19 +406,19 @@ class DuckDB(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToDate: lambda self, e: f"CAST({str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: str_to_time_sql,
- exp.StrToUnix: lambda self,
- e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
+ exp.StrToUnix: lambda self, e: self.func(
+ "EPOCH", self.func("STRPTIME", e.this, self.format_time(e))
+ ),
exp.Struct: _struct_sql,
exp.Timestamp: no_timestamp_sql,
exp.TimestampDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
),
exp.TimestampTrunc: timestamptrunc_sql,
- exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")),
exp.TimeStrToTime: timestrtotime_sql,
- exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
- exp.TimeToStr: lambda self,
- e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeStrToUnix: lambda self, e: self.func("EPOCH", exp.cast(e.this, "timestamp")),
+ exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.this, self.format_time(e)),
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
@@ -432,8 +429,9 @@ class DuckDB(Dialect):
exp.cast(e.expression, "TIMESTAMP"),
exp.cast(e.this, "TIMESTAMP"),
),
- exp.UnixToStr: lambda self,
- e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
+ exp.UnixToStr: lambda self, e: self.func(
+ "STRFTIME", self.func("TO_TIMESTAMP", e.this), self.format_time(e)
+ ),
exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
exp.VariancePop: rename_func("VAR_POP"),
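DuckDB's _unix_to_time_sql mirrors the BigQuery change: each scale has a native builder (TO_TIMESTAMP, EPOCH_MS, MAKE_TIMESTAMP), with a POW(10, scale) division fallback. A sketch (identifier illustrative):

    import sqlglot

    # A seconds-scale epoch takes the TO_TIMESTAMP branch.
    print(sqlglot.transpile("SELECT FROM_UNIXTIME(ts)", read="presto", write="duckdb")[0])
    # roughly: SELECT TO_TIMESTAMP(ts)
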
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index b1540bb..43211dc 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arg_max_or_min_no_count,
datestrtodate_sql,
- format_time_lambda,
+ build_formatted_time,
if_sql,
is_parse_json,
left_to_substring_sql,
@@ -38,7 +38,6 @@ from sqlglot.transforms import (
move_schema_columns_to_partitioned_by,
)
from sqlglot.helper import seq_get
-from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
# (FuncType, Multiplier)
@@ -130,7 +129,7 @@ def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str:
def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
- return f"SORT_ARRAY({self.sql(expression, 'this')})"
+ return self.func("SORT_ARRAY", expression.this)
def _property_sql(self: Hive.Generator, expression: exp.Property) -> str:
@@ -157,23 +156,18 @@ def _str_to_time_sql(self: Hive.Generator, expression: exp.StrToTime) -> str:
return f"CAST({this} AS TIMESTAMP)"
-def _time_to_str(self: Hive.Generator, expression: exp.TimeToStr) -> str:
- this = self.sql(expression, "this")
- time_format = self.format_time(expression)
- return f"DATE_FORMAT({this}, {time_format})"
-
-
def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
- this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT):
- return f"TO_DATE({this}, {time_format})"
+ return self.func("TO_DATE", expression.this, time_format)
+
if isinstance(expression.this, exp.TsOrDsToDate):
- return this
- return f"TO_DATE({this})"
+ return self.sql(expression, "this")
+
+ return self.func("TO_DATE", expression.this)
-def _parse_ignore_nulls(
+def _build_with_ignore_nulls(
exp_class: t.Type[exp.Expression],
) -> t.Callable[[t.List[exp.Expression]], exp.Expression]:
def _parse(args: t.List[exp.Expression]) -> exp.Expression:
@@ -276,7 +270,7 @@ class Hive(Dialect):
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
),
- "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
+ "DATE_FORMAT": lambda args: build_formatted_time(exp.TimeToStr, "hive")(
[
exp.TimeStrToTime(this=seq_get(args, 0)),
seq_get(args, 1),
@@ -292,14 +286,14 @@ class Hive(Dialect):
expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
- "FIRST": _parse_ignore_nulls(exp.First),
- "FIRST_VALUE": _parse_ignore_nulls(exp.FirstValue),
- "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
+ "FIRST": _build_with_ignore_nulls(exp.First),
+ "FIRST_VALUE": _build_with_ignore_nulls(exp.FirstValue),
+ "FROM_UNIXTIME": build_formatted_time(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
- "LAST": _parse_ignore_nulls(exp.Last),
- "LAST_VALUE": _parse_ignore_nulls(exp.LastValue),
+ "LAST": _build_with_ignore_nulls(exp.Last),
+ "LAST_VALUE": _build_with_ignore_nulls(exp.LastValue),
"LOCATE": locate_to_strposition,
- "MAP": parse_var_map,
+ "MAP": parser.build_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
"PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list,
@@ -313,10 +307,10 @@ class Hive(Dialect):
pair_delim=seq_get(args, 1) or exp.Literal.string(","),
key_value_delim=seq_get(args, 2) or exp.Literal.string(":"),
),
- "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
+ "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
- "UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True),
+ "UNIX_TIMESTAMP": build_formatted_time(exp.StrToUnix, "hive", True),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
@@ -487,8 +481,10 @@ class Hive(Dialect):
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
exp.IsNan: rename_func("ISNAN"),
- exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
- exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
+ exp.JSONExtract: lambda self, e: self.func("GET_JSON_OBJECT", e.this, e.expression),
+ exp.JSONExtractScalar: lambda self, e: self.func(
+ "GET_JSON_OBJECT", e.this, e.expression
+ ),
exp.JSONFormat: _json_format_sql,
exp.Left: left_to_substring_sql,
exp.Map: var_map_sql,
@@ -496,7 +492,7 @@ class Hive(Dialect):
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.Min: min_or_least,
exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
- exp.NotNullColumnConstraint: lambda self, e: (
+ exp.NotNullColumnConstraint: lambda _, e: (
"" if e.args.get("allow_null") else "NOT NULL"
),
exp.VarMap: var_map_sql,
@@ -517,8 +513,9 @@ class Hive(Dialect):
exp.SafeDivide: no_safe_divide_sql,
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
- exp.Split: lambda self,
- e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
+ exp.Split: lambda self, e: self.func(
+ "SPLIT", e.this, self.func("CONCAT", "'\\\\Q'", e.expression)
+ ),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_time_sql,
@@ -527,7 +524,7 @@ class Hive(Dialect):
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
- exp.TimeToStr: _time_to_str,
+ exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
exp.ToBase64: rename_func("BASE64"),
exp.TsOrDiToDi: lambda self,
@@ -549,9 +546,9 @@ class Hive(Dialect):
e: f"({self.expressions(e, 'this', indent=False)})",
exp.NonClusteredColumnConstraint: lambda self,
e: f"({self.expressions(e, 'this', indent=False)})",
- exp.NotForReplicationColumnConstraint: lambda self, e: "",
- exp.OnProperty: lambda self, e: "",
- exp.PrimaryKeyColumnConstraint: lambda self, e: "PRIMARY KEY",
+ exp.NotForReplicationColumnConstraint: lambda *_: "",
+ exp.OnProperty: lambda *_: "",
+ exp.PrimaryKeyColumnConstraint: lambda *_: "PRIMARY KEY",
}
PROPERTIES_LOCATION = {
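Hive's JSON extraction now passes exactly the target and the path to GET_JSON_OBJECT rather than going through rename_func. A round-trip sketch (identifiers illustrative):

    import sqlglot

    # GET_JSON_OBJECT round-trips through exp.JSONExtractScalar with the
    # path kept as the second argument.
    print(sqlglot.transpile("SELECT GET_JSON_OBJECT(col, '$.name')", read="hive", write="hive")[0])
    # roughly: SELECT GET_JSON_OBJECT(col, '$.name')
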
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 97c891d..e549f62 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -9,7 +9,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_sql,
date_add_interval_sql,
datestrtodate_sql,
- format_time_lambda,
+ build_formatted_time,
isnull_to_is_null,
locate_to_strposition,
max_or_greatest,
@@ -19,8 +19,8 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
- parse_date_delta,
- parse_date_delta_with_interval,
+ build_date_delta,
+ build_date_delta_with_interval,
rename_func,
strposition_to_locate_sql,
)
@@ -39,9 +39,6 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
expr = self.sql(expression, "this")
unit = expression.text("unit").upper()
- if unit == "DAY":
- return f"DATE({expr})"
-
if unit == "WEEK":
concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
date_format = "%Y %u %w"
@@ -55,10 +52,11 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e"
else:
- self.unsupported(f"Unexpected interval unit: {unit}")
- return f"DATE({expr})"
+ if unit != "DAY":
+ self.unsupported(f"Unexpected interval unit: {unit}")
+ return self.func("DATE", expr)
- return f"STR_TO_DATE({concat}, '{date_format}')"
+ return self.func("STR_TO_DATE", concat, f"'{date_format}'")
# All specifiers for time parts (as opposed to date parts)
@@ -93,8 +91,7 @@ def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime:
def _str_to_date_sql(
self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
) -> str:
- date_format = self.format_time(expression)
- return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
+ return self.func("STR_TO_DATE", expression.this, self.format_time(expression))
def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
@@ -127,9 +124,7 @@ def _date_add_sql(
def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = expression.args.get("format")
- if time_format:
- return _str_to_date_sql(self, expression)
- return f"DATE({self.sql(expression, 'this')})"
+ return _str_to_date_sql(self, expression) if time_format else self.func("DATE", expression.this)
def _remove_ts_or_ds_to_date(
@@ -289,9 +284,9 @@ class MySQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)),
- "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
- "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
- "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
+ "DATE_ADD": build_date_delta_with_interval(exp.DateAdd),
+ "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "mysql"),
+ "DATE_SUB": build_date_delta_with_interval(exp.DateSub),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
@@ -306,7 +301,7 @@ class MySQL(Dialect):
format=exp.Literal.string("%B"),
),
"STR_TO_DATE": _str_to_date,
- "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff),
+ "TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff),
"TO_DAYS": lambda args: exp.paren(
exp.DateDiff(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
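MySQL's _date_trunc_sql is restructured so DAY lands in the fallback branch: DATE() for day granularity, STR_TO_DATE over a rebuilt date string for coarser units. A sketch (identifier illustrative):

    import sqlglot

    # Day truncation is just DATE(); YEAR rebuilds a '<year> 1 1' string and
    # parses it back with STR_TO_DATE.
    print(sqlglot.transpile("SELECT DATE_TRUNC('DAY', x)", read="postgres", write="mysql")[0])
    # roughly: SELECT DATE(x)
    print(sqlglot.transpile("SELECT DATE_TRUNC('YEAR', x)", read="postgres", write="mysql")[0])
    # roughly: SELECT STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')
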
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index de693b9..fcb3aab 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -6,7 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
- format_time_lambda,
+ build_formatted_time,
no_ilike_sql,
rename_func,
trim_sql,
@@ -18,26 +18,7 @@ if t.TYPE_CHECKING:
from sqlglot._typing import E
-def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable:
- this = self._parse_string()
-
- passing = None
- columns = None
-
- if self._match_text_seq("PASSING"):
- # The BY VALUE keywords are optional and are provided for semantic clarity
- self._match_text_seq("BY", "VALUE")
- passing = self._parse_csv(self._parse_column)
-
- by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
-
- if self._match_text_seq("COLUMNS"):
- columns = self._parse_csv(self._parse_field_def)
-
- return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref)
-
-
-def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar:
+def _build_timetostr_or_tochar(args: t.List) -> exp.TimeToStr | exp.ToChar:
this = seq_get(args, 0)
if this and not this.type:
@@ -45,7 +26,7 @@ def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar:
annotate_types(this)
if this.is_type(*exp.DataType.TEMPORAL_TYPES):
- return format_time_lambda(exp.TimeToStr, "oracle", default=True)(args)
+ return build_formatted_time(exp.TimeToStr, "oracle", default=True)(args)
return exp.ToChar.from_arg_list(args)
@@ -93,9 +74,9 @@ class Oracle(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
- "TO_CHAR": to_char,
- "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "oracle"),
- "TO_DATE": format_time_lambda(exp.StrToDate, "oracle"),
+ "TO_CHAR": _build_timetostr_or_tochar,
+ "TO_TIMESTAMP": build_formatted_time(exp.StrToTime, "oracle"),
+ "TO_DATE": build_formatted_time(exp.StrToDate, "oracle"),
}
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@@ -109,7 +90,7 @@ class Oracle(Dialect):
this=self._parse_format_json(self._parse_bitwise()),
order=self._parse_order(),
),
- "XMLTABLE": _parse_xml_table,
+ "XMLTABLE": lambda self: self._parse_xml_table(),
}
QUERY_MODIFIER_PARSERS = {
@@ -127,6 +108,26 @@ class Oracle(Dialect):
# Reference: https://stackoverflow.com/a/336455
DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE}
+ def _parse_xml_table(self) -> exp.XMLTable:
+ this = self._parse_string()
+
+ passing = None
+ columns = None
+
+ if self._match_text_seq("PASSING"):
+ # The BY VALUE keywords are optional and are provided for semantic clarity
+ self._match_text_seq("BY", "VALUE")
+ passing = self._parse_csv(self._parse_column)
+
+ by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF")
+
+ if self._match_text_seq("COLUMNS"):
+ columns = self._parse_csv(self._parse_field_def)
+
+ return self.expression(
+ exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref
+ )
+
def _parse_json_array(self, expr_type: t.Type[E], **kwargs) -> E:
return self.expression(
expr_type,
@@ -200,18 +201,17 @@ class Oracle(Dialect):
transforms.eliminate_qualify,
]
),
- exp.StrToTime: lambda self,
- e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
+ exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "),
- exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self,
- e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
+ e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
}
PROPERTIES_LOCATION = {
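Moving `_parse_xml_table` onto `Oracle.Parser` changes only the dispatch (`FUNCTION_PARSERS` now calls the method through a lambda); the grammar it accepts — a string argument, optional `PASSING [BY VALUE]`, optional `RETURNING SEQUENCE BY REF`, optional `COLUMNS` — is identical. A hedged round-trip sketch (table and column names are illustrative):

```python
# Sketch: XMLTABLE parses and regenerates the same way after the refactor.
import sqlglot

sql = """
SELECT *
FROM XMLTABLE(
  '/items/item'
  PASSING BY VALUE xml_doc
  COLUMNS name VARCHAR2(50)
)
"""
print(sqlglot.transpile(sql, read="oracle", write="oracle")[0])
```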
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 126261e..c78f8a3 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -6,10 +6,12 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
DATE_ADD_OR_SUB,
Dialect,
+ JSON_EXTRACT_TYPE,
any_value_to_max_sql,
bool_xor_sql,
datestrtodate_sql,
- format_time_lambda,
+ build_formatted_time,
+ filter_array_using_unnest,
json_extract_segments,
json_path_key_only_name,
max_or_greatest,
@@ -20,8 +22,8 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_pivot_sql,
no_trycast_sql,
- parse_json_extract_path,
- parse_timestamp_trunc,
+ build_json_extract_path,
+ build_timestamp_trunc,
rename_func,
str_position_sql,
struct_extract_sql,
@@ -163,7 +165,7 @@ def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
return expression
-def _generate_series(args: t.List) -> exp.Expression:
+def _build_generate_series(args: t.List) -> exp.GenerateSeries:
# The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
step = seq_get(args, 2)
@@ -179,14 +181,25 @@ def _generate_series(args: t.List) -> exp.Expression:
return exp.GenerateSeries.from_arg_list(args)
-def _to_timestamp(args: t.List) -> exp.Expression:
+def _build_to_timestamp(args: t.List) -> exp.UnixToTime | exp.StrToTime:
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args)
# https://www.postgresql.org/docs/current/functions-formatting.html
- return format_time_lambda(exp.StrToTime, "postgres")(args)
+ return build_formatted_time(exp.StrToTime, "postgres")(args)
+
+
+def _json_extract_sql(
+ name: str, op: str
+) -> t.Callable[[Postgres.Generator, JSON_EXTRACT_TYPE], str]:
+ def _generate(self: Postgres.Generator, expression: JSON_EXTRACT_TYPE) -> str:
+ if expression.args.get("only_json_types"):
+ return json_extract_segments(name, quoted_index=False, op=op)(self, expression)
+ return json_extract_segments(name)(self, expression)
+
+ return _generate
class Postgres(Dialect):
@@ -292,19 +305,19 @@ class Postgres(Dialect):
**parser.Parser.PROPERTY_PARSERS,
"SET": lambda self: self.expression(exp.SetConfigProperty, this=self._parse_set()),
}
- PROPERTY_PARSERS.pop("INPUT", None)
+ PROPERTY_PARSERS.pop("INPUT")
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
- "DATE_TRUNC": parse_timestamp_trunc,
- "GENERATE_SERIES": _generate_series,
- "JSON_EXTRACT_PATH": parse_json_extract_path(exp.JSONExtract),
- "JSON_EXTRACT_PATH_TEXT": parse_json_extract_path(exp.JSONExtractScalar),
+ "DATE_TRUNC": build_timestamp_trunc,
+ "GENERATE_SERIES": _build_generate_series,
+ "JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
+ "JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar),
"MAKE_TIME": exp.TimeFromParts.from_arg_list,
"MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
"NOW": exp.CurrentTimestamp.from_arg_list,
- "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
- "TO_TIMESTAMP": _to_timestamp,
+ "TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"),
+ "TO_TIMESTAMP": _build_to_timestamp,
"UNNEST": exp.Explode.from_arg_list,
}
@@ -338,6 +351,8 @@ class Postgres(Dialect):
TokenType.END: lambda self: self._parse_commit_or_rollback(),
}
+ JSON_ARROWS_REQUIRE_JSON_TYPE = True
+
def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
while True:
if not self._match(TokenType.L_PAREN):
@@ -387,6 +402,7 @@ class Postgres(Dialect):
SUPPORTS_UNLOGGED_TABLES = True
LIKE_PROPERTY_INSIDE_SCHEMA = True
MULTI_ARG_DISTINCT = False
+ CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
@@ -416,6 +432,8 @@ class Postgres(Dialect):
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
+ exp.ArrayFilter: filter_array_using_unnest,
+ exp.ArraySize: lambda self, e: self.func("ARRAY_LENGTH", e.this, e.expression or "1"),
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.CurrentDate: no_paren_current_date_sql,
@@ -428,8 +446,8 @@ class Postgres(Dialect):
exp.DateSub: _date_add_sql("-"),
exp.Explode: rename_func("UNNEST"),
exp.GroupConcat: _string_agg_sql,
- exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH"),
- exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
+ exp.JSONExtract: _json_extract_sql("JSON_EXTRACT_PATH", "->"),
+ exp.JSONExtractScalar: _json_extract_sql("JSON_EXTRACT_PATH_TEXT", "->>"),
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
@@ -462,21 +480,20 @@ class Postgres(Dialect):
]
),
exp.StrPosition: str_position_sql,
- exp.StrToTime: lambda self,
- e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,
exp.TimeFromParts: rename_func("MAKE_TIME"),
exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
- exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: _date_add_sql("+"),
exp.TsOrDsDiff: _date_diff_sql,
- exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
+ exp.UnixToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this),
exp.VariancePop: rename_func("VAR_POP"),
exp.Variance: rename_func("VAR_SAMP"),
exp.Xor: bool_xor_sql,
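`_build_to_timestamp` keeps the dual behavior called out in its comments: a single numeric argument is treated as a Unix epoch, while the `(text, text)` form is a parse-with-format. A small sketch, assuming a build with these renames applied:

```python
# One argument -> UnixToTime; (text, text) -> StrToTime with a mapped format.
import sqlglot
from sqlglot import exp

epoch = sqlglot.parse_one("TO_TIMESTAMP(1704067200)", read="postgres")
formatted = sqlglot.parse_one("TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')", read="postgres")

assert isinstance(epoch, exp.UnixToTime)
assert isinstance(formatted, exp.StrToTime)
```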
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 1e0e7e9..8429547 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -11,7 +11,7 @@ from sqlglot.dialects.dialect import (
date_trunc_to_time,
datestrtodate_sql,
encode_decode_sql,
- format_time_lambda,
+ build_formatted_time,
if_sql,
left_to_substring_sql,
no_ilike_sql,
@@ -31,12 +31,6 @@ from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
-def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) -> str:
- accuracy = expression.args.get("accuracy")
- accuracy = ", " + self.sql(accuracy) if accuracy else ""
- return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
-
-
def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str:
if isinstance(expression.this, exp.Explode):
return self.sql(
@@ -81,20 +75,20 @@ def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str:
def _quantile_sql(self: Presto.Generator, expression: exp.Quantile) -> str:
self.unsupported("Presto does not support exact quantiles")
- return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
+ return self.func("APPROX_PERCENTILE", expression.this, expression.args.get("quantile"))
def _str_to_time_sql(
self: Presto.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
) -> str:
- return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
+ return self.func("DATE_PARSE", expression.this, self.format_time(expression))
def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
- return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto")
- return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto")
+ return self.sql(exp.cast(_str_to_time_sql(self, expression), "DATE"))
+ return self.sql(exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE"))
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
@@ -110,7 +104,7 @@ def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> st
return self.func("DATE_DIFF", unit, expr, this)
-def _approx_percentile(args: t.List) -> exp.Expression:
+def _build_approx_percentile(args: t.List) -> exp.Expression:
if len(args) == 4:
return exp.ApproxQuantile(
this=seq_get(args, 0),
@@ -125,7 +119,7 @@ def _approx_percentile(args: t.List) -> exp.Expression:
return exp.ApproxQuantile.from_arg_list(args)
-def _from_unixtime(args: t.List) -> exp.Expression:
+def _build_from_unixtime(args: t.List) -> exp.Expression:
if len(args) == 3:
return exp.UnixToTime(
this=seq_get(args, 0),
@@ -182,7 +176,7 @@ def _to_int(expression: exp.Expression) -> exp.Expression:
return expression
-def _parse_to_char(args: t.List) -> exp.TimeToStr:
+def _build_to_char(args: t.List) -> exp.TimeToStr:
fmt = seq_get(args, 1)
if isinstance(fmt, exp.Literal):
# We uppercase this to match Teradata's format mapping keys
@@ -190,7 +184,7 @@ def _parse_to_char(args: t.List) -> exp.TimeToStr:
# We use "teradata" on purpose here, because the time formats are different in Presto.
# See https://prestodb.io/docs/current/functions/teradata.html?highlight=to_char#to_char
- return format_time_lambda(exp.TimeToStr, "teradata")(args)
+ return build_formatted_time(exp.TimeToStr, "teradata")(args)
class Presto(Dialect):
@@ -231,7 +225,7 @@ class Presto(Dialect):
**parser.Parser.FUNCTIONS,
"ARBITRARY": exp.AnyValue.from_arg_list,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
- "APPROX_PERCENTILE": _approx_percentile,
+ "APPROX_PERCENTILE": _build_approx_percentile,
"BITWISE_AND": binary_from_function(exp.BitwiseAnd),
"BITWISE_NOT": lambda args: exp.BitwiseNot(this=seq_get(args, 0)),
"BITWISE_OR": binary_from_function(exp.BitwiseOr),
@@ -244,14 +238,14 @@ class Presto(Dialect):
"DATE_DIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
- "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
- "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
+ "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "presto"),
+ "DATE_PARSE": build_formatted_time(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
"ELEMENT_AT": lambda args: exp.Bracket(
this=seq_get(args, 0), expressions=[seq_get(args, 1)], offset=1, safe=True
),
"FROM_HEX": exp.Unhex.from_arg_list,
- "FROM_UNIXTIME": _from_unixtime,
+ "FROM_UNIXTIME": _build_from_unixtime,
"FROM_UTF8": lambda args: exp.Decode(
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
),
@@ -271,7 +265,7 @@ class Presto(Dialect):
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
),
- "TO_CHAR": _parse_to_char,
+ "TO_CHAR": _build_to_char,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
@@ -318,35 +312,35 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: rename_func("ARBITRARY"),
- exp.ApproxDistinct: _approx_distinct_sql,
+ exp.ApproxDistinct: lambda self, e: self.func(
+ "APPROX_DISTINCT", e.this, e.args.get("accuracy")
+ ),
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
+ exp.ArrayAny: rename_func("ANY_MATCH"),
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
exp.ArrayUniqueAgg: rename_func("SET_AGG"),
exp.AtTimeZone: rename_func("AT_TIMEZONE"),
- exp.BitwiseAnd: lambda self,
- e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
- exp.BitwiseLeftShift: lambda self,
- e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
- exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})",
- exp.BitwiseOr: lambda self,
- e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
- exp.BitwiseRightShift: lambda self,
- e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
- exp.BitwiseXor: lambda self,
- e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.BitwiseAnd: lambda self, e: self.func("BITWISE_AND", e.this, e.expression),
+ exp.BitwiseLeftShift: lambda self, e: self.func(
+ "BITWISE_ARITHMETIC_SHIFT_LEFT", e.this, e.expression
+ ),
+ exp.BitwiseNot: lambda self, e: self.func("BITWISE_NOT", e.this),
+ exp.BitwiseOr: lambda self, e: self.func("BITWISE_OR", e.this, e.expression),
+ exp.BitwiseRightShift: lambda self, e: self.func(
+ "BITWISE_ARITHMETIC_SHIFT_RIGHT", e.this, e.expression
+ ),
+ exp.BitwiseXor: lambda self, e: self.func("BITWISE_XOR", e.this, e.expression),
exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD",
exp.Literal.string(e.text("unit") or "DAY"),
- _to_int(
- e.expression,
- ),
+ _to_int(e.expression),
e.this,
),
exp.DateDiff: lambda self, e: self.func(
@@ -407,21 +401,21 @@ class Presto(Dialect):
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToMap: rename_func("SPLIT_TO_MAP"),
exp.StrToTime: _str_to_time_sql,
- exp.StrToUnix: lambda self,
- e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
+ exp.StrToUnix: lambda self, e: self.func(
+ "TO_UNIXTIME", self.func("DATE_PARSE", e.this, self.format_time(e))
+ ),
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.Timestamp: no_timestamp_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
- exp.TimeStrToUnix: lambda self,
- e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))",
- exp.TimeToStr: lambda self,
- e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeStrToUnix: lambda self, e: self.func(
+ "TO_UNIXTIME", self.func("DATE_PARSE", e.this, Presto.TIME_FORMAT)
+ ),
+ exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.TimeToUnix: rename_func("TO_UNIXTIME"),
- exp.ToChar: lambda self,
- e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.ToChar: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 135ffc6..2201c78 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -21,15 +21,15 @@ if t.TYPE_CHECKING:
from sqlglot._typing import E
-def _parse_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
- def _parse_delta(args: t.List) -> E:
+def _build_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
+ def _builder(args: t.List) -> E:
expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
if expr_type is exp.TsOrDsAdd:
expr.set("return_type", exp.DataType.build("TIMESTAMP"))
return expr
- return _parse_delta
+ return _builder
class Redshift(Postgres):
@@ -55,10 +55,10 @@ class Redshift(Postgres):
unit=exp.var("month"),
return_type=exp.DataType.build("TIMESTAMP"),
),
- "DATEADD": _parse_date_delta(exp.TsOrDsAdd),
- "DATE_ADD": _parse_date_delta(exp.TsOrDsAdd),
- "DATEDIFF": _parse_date_delta(exp.TsOrDsDiff),
- "DATE_DIFF": _parse_date_delta(exp.TsOrDsDiff),
+ "DATEADD": _build_date_delta(exp.TsOrDsAdd),
+ "DATE_ADD": _build_date_delta(exp.TsOrDsAdd),
+ "DATEDIFF": _build_date_delta(exp.TsOrDsDiff),
+ "DATE_DIFF": _build_date_delta(exp.TsOrDsDiff),
"GETDATE": exp.CurrentTimestamp.from_arg_list,
"LISTAGG": exp.GroupConcat.from_arg_list,
"STRTOL": exp.FromBase.from_arg_list,
@@ -171,6 +171,7 @@ class Redshift(Postgres):
TZ_TO_WITH_TIME_ZONE = True
NVL2_SUPPORTED = True
LAST_DAY_SUPPORTS_DATE_PART = False
+ CAN_IMPLEMENT_ARRAY_ANY = False
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@@ -192,11 +193,12 @@ class Redshift(Postgres):
),
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
- exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
+ exp.DistKeyProperty: lambda self, e: self.func("DISTKEY", e.this),
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.FromBase: rename_func("STRTOL"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
+ exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
exp.GroupConcat: rename_func("LISTAGG"),
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Select: transforms.preprocess(
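Redshift's renamed `_build_date_delta` keeps the argument reordering — `DATEADD(unit, delta, date)` maps positions 0/1/2 to `unit`/`expression`/`this` — and still pins `TsOrDsAdd` to a `TIMESTAMP` return type. A sketch under those assumptions (column name illustrative):

```python
import sqlglot
from sqlglot import exp

node = sqlglot.parse_one("DATEADD(day, 1, created_at)", read="redshift")
assert isinstance(node, exp.TsOrDsAdd)
assert node.args["return_type"].is_type("timestamp")
```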
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index b4275ea..c773e50 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import (
date_delta_sql,
date_trunc_to_time,
datestrtodate_sql,
- format_time_lambda,
+ build_formatted_time,
if_sql,
inline_array_sql,
max_or_greatest,
@@ -29,12 +29,12 @@ if t.TYPE_CHECKING:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
+def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
# case: <string_expr> [ , <format> ]
- return format_time_lambda(exp.StrToTime, "snowflake")(args)
+ return build_formatted_time(exp.StrToTime, "snowflake")(args)
return exp.UnixToTime(this=first_arg, scale=second_arg)
from sqlglot.optimizer.simplify import simplify_literals
@@ -52,14 +52,14 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
return exp.UnixToTime.from_arg_list(args)
# case: <date_expr>
- return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
+ return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args)
# case: <numeric_expr>
return exp.UnixToTime.from_arg_list(args)
-def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
- expression = parser.parse_var_map(args)
+def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
+ expression = parser.build_var_map(args)
if isinstance(expression, exp.StarMap):
return expression
@@ -71,48 +71,14 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
)
-def _parse_datediff(args: t.List) -> exp.DateDiff:
+def _build_datediff(args: t.List) -> exp.DateDiff:
return exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
)
-# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
-# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
-def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
- this = self._parse_var() or self._parse_type()
-
- if not this:
- return None
-
- self._match(TokenType.COMMA)
- expression = self._parse_bitwise()
- this = _map_date_part(this)
- name = this.name.upper()
-
- if name.startswith("EPOCH"):
- if name == "EPOCH_MILLISECOND":
- scale = 10**3
- elif name == "EPOCH_MICROSECOND":
- scale = 10**6
- elif name == "EPOCH_NANOSECOND":
- scale = 10**9
- else:
- scale = None
-
- ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
- to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
-
- if scale:
- to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
-
- return to_unix
-
- return self.expression(exp.Extract, this=this, expression=expression)
-
-
# https://docs.snowflake.com/en/sql-reference/functions/div0
-def _div0_to_if(args: t.List) -> exp.If:
+def _build_if_from_div0(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -120,13 +86,13 @@ def _div0_to_if(args: t.List) -> exp.If:
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _zeroifnull_to_if(args: t.List) -> exp.If:
+def _build_if_from_zeroifnull(args: t.List) -> exp.If:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/nullifzero
-def _nullifzero_to_if(args: t.List) -> exp.If:
+def _build_if_from_nullifzero(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
@@ -150,13 +116,13 @@ def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) ->
)
-def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
+def _build_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
if len(args) == 3:
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
-def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
+def _build_regexp_replace(args: t.List) -> exp.RegexpReplace:
regexp_replace = exp.RegexpReplace.from_arg_list(args)
if not regexp_replace.args.get("replacement"):
@@ -266,38 +232,7 @@ def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
return trunc
-def _parse_colon_get_path(
- self: parser.Parser, this: t.Optional[exp.Expression]
-) -> t.Optional[exp.Expression]:
- while True:
- path = self._parse_bitwise()
-
- # The cast :: operator has a lower precedence than the extraction operator :, so
- # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
- if isinstance(path, exp.Cast):
- target_type = path.to
- path = path.this
- else:
- target_type = None
-
- if isinstance(path, exp.Expression):
- path = exp.Literal.string(path.sql(dialect="snowflake"))
-
- # The extraction operator : is left-associative
- this = self.expression(
- exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
- )
-
- if target_type:
- this = exp.cast(this, target_type)
-
- if not self._match(TokenType.COLON):
- break
-
- return self._parse_range(this)
-
-
-def _parse_timestamp_from_parts(args: t.List) -> exp.Func:
+def _build_timestamp_from_parts(args: t.List) -> exp.Func:
if len(args) == 2:
# Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
# so we parse this into Anonymous for now instead of introducing complexity
@@ -396,15 +331,15 @@ class Snowflake(Dialect):
"BITXOR": binary_from_function(exp.BitwiseXor),
"BIT_XOR": binary_from_function(exp.BitwiseXor),
"BOOLXOR": binary_from_function(exp.Xor),
- "CONVERT_TIMEZONE": _parse_convert_timezone,
+ "CONVERT_TIMEZONE": _build_convert_timezone,
"DATE_TRUNC": _date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=_map_date_part(seq_get(args, 0)),
),
- "DATEDIFF": _parse_datediff,
- "DIV0": _div0_to_if,
+ "DATEDIFF": _build_datediff,
+ "DIV0": _build_if_from_div0,
"FLATTEN": exp.Explode.from_arg_list,
"GET_PATH": lambda args, dialect: exp.JSONExtract(
this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
@@ -414,24 +349,24 @@ class Snowflake(Dialect):
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
),
"LISTAGG": exp.GroupConcat.from_arg_list,
- "NULLIFZERO": _nullifzero_to_if,
- "OBJECT_CONSTRUCT": _parse_object_construct,
- "REGEXP_REPLACE": _parse_regexp_replace,
+ "NULLIFZERO": _build_if_from_nullifzero,
+ "OBJECT_CONSTRUCT": _build_object_construct,
+ "REGEXP_REPLACE": _build_regexp_replace,
"REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
- "TIMEDIFF": _parse_datediff,
- "TIMESTAMPDIFF": _parse_datediff,
- "TIMESTAMPFROMPARTS": _parse_timestamp_from_parts,
- "TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts,
- "TO_TIMESTAMP": _parse_to_timestamp,
+ "TIMEDIFF": _build_datediff,
+ "TIMESTAMPDIFF": _build_datediff,
+ "TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
+ "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
+ "TO_TIMESTAMP": _build_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
- "ZEROIFNULL": _zeroifnull_to_if,
+ "ZEROIFNULL": _build_if_from_zeroifnull,
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
- "DATE_PART": _parse_date_part,
+ "DATE_PART": lambda self: self._parse_date_part(),
"OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(),
}
FUNCTION_PARSERS.pop("TRIM")
@@ -442,7 +377,7 @@ class Snowflake(Dialect):
**parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
- TokenType.COLON: _parse_colon_get_path,
+ TokenType.COLON: lambda self, this: self._parse_colon_get_path(this),
}
ALTER_PARSERS = {
@@ -489,6 +424,69 @@ class Snowflake(Dialect):
FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]
+ def _parse_colon_get_path(
+ self: parser.Parser, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ while True:
+ path = self._parse_bitwise()
+
+ # The cast :: operator has a lower precedence than the extraction operator :, so
+ # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
+ if isinstance(path, exp.Cast):
+ target_type = path.to
+ path = path.this
+ else:
+ target_type = None
+
+ if isinstance(path, exp.Expression):
+ path = exp.Literal.string(path.sql(dialect="snowflake"))
+
+ # The extraction operator : is left-associative
+ this = self.expression(
+ exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
+ )
+
+ if target_type:
+ this = exp.cast(this, target_type)
+
+ if not self._match(TokenType.COLON):
+ break
+
+ return self._parse_range(this)
+
+ # https://docs.snowflake.com/en/sql-reference/functions/date_part.html
+ # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
+ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
+ this = self._parse_var() or self._parse_type()
+
+ if not this:
+ return None
+
+ self._match(TokenType.COMMA)
+ expression = self._parse_bitwise()
+ this = _map_date_part(this)
+ name = this.name.upper()
+
+ if name.startswith("EPOCH"):
+ if name == "EPOCH_MILLISECOND":
+ scale = 10**3
+ elif name == "EPOCH_MICROSECOND":
+ scale = 10**6
+ elif name == "EPOCH_NANOSECOND":
+ scale = 10**9
+ else:
+ scale = None
+
+ ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
+ to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
+
+ if scale:
+ to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
+
+ return to_unix
+
+ return self.expression(exp.Extract, this=this, expression=expression)
+
def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
if is_map:
# Keys are strings in Snowflake's objects, see also:
@@ -665,6 +663,7 @@ class Snowflake(Dialect):
"SAMPLE": TokenType.TABLE_SAMPLE,
"SQL_DOUBLE": TokenType.DOUBLE,
"SQL_VARCHAR": TokenType.VARCHAR,
+ "STORAGE INTEGRATION": TokenType.STORAGE_INTEGRATION,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@@ -724,8 +723,10 @@ class Snowflake(Dialect):
),
exp.GroupConcat: rename_func("LISTAGG"),
exp.If: if_sql(name="IFF", false_value="NULL"),
- exp.JSONExtract: rename_func("GET_PATH"),
- exp.JSONExtractScalar: rename_func("JSON_EXTRACT_PATH_TEXT"),
+ exp.JSONExtract: lambda self, e: self.func("GET_PATH", e.this, e.expression),
+ exp.JSONExtractScalar: lambda self, e: self.func(
+ "JSON_EXTRACT_PATH_TEXT", e.this, e.expression
+ ),
exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions),
exp.JSONPathRoot: lambda *_: "",
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
@@ -756,8 +757,7 @@ class Snowflake(Dialect):
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
- exp.StrToTime: lambda self,
- e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.Struct: lambda self, e: self.func(
"OBJECT_CONSTRUCT",
*(arg for expression in e.expressions for arg in expression.flatten()),
@@ -901,12 +901,12 @@ class Snowflake(Dialect):
)
def except_op(self, expression: exp.Except) -> str:
- if not expression.args.get("distinct", False):
+ if not expression.args.get("distinct"):
self.unsupported("EXCEPT with All is not supported in Snowflake")
return super().except_op(expression)
def intersect_op(self, expression: exp.Intersect) -> str:
- if not expression.args.get("distinct", False):
+ if not expression.args.get("distinct"):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)
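The relocated `_parse_colon_get_path` preserves the precedence fix described in its comments: a trailing `::` cast is peeled off the path and re-applied to the extracted value, and chained `:` segments stay left-associative. A hedged sketch of the effect (identifiers illustrative; exact output depends on the build):

```python
import sqlglot

ast = sqlglot.parse_one("SELECT payload:user.id::int FROM events", read="snowflake")
print(ast.sql("snowflake"))
# e.g. SELECT CAST(GET_PATH(payload, 'user.id') AS INT) FROM events
```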
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index c662ab5..20c0fce 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -4,7 +4,7 @@ import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
-from sqlglot.dialects.hive import _parse_ignore_nulls
+from sqlglot.dialects.hive import _build_with_ignore_nulls
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.helper import seq_get
from sqlglot.transforms import (
@@ -15,7 +15,7 @@ from sqlglot.transforms import (
)
-def _parse_datediff(args: t.List) -> exp.Expression:
+def _build_datediff(args: t.List) -> exp.Expression:
"""
Although Spark docs don't mention the "unit" argument, Spark3 added support for
it at some point. Databricks also supports this variant (see below).
@@ -61,8 +61,8 @@ class Spark(Spark2):
class Parser(Spark2.Parser):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
- "ANY_VALUE": _parse_ignore_nulls(exp.AnyValue),
- "DATEDIFF": _parse_datediff,
+ "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
+ "DATEDIFF": _build_datediff,
}
def _parse_generated_as_identity(
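Per the `_build_datediff` docstring, both `DATEDIFF` shapes are accepted: the classic two-argument form and the Spark3/Databricks variant with a leading unit. A minimal sketch, assuming a build with this rename:

```python
import sqlglot

two_arg = sqlglot.parse_one("DATEDIFF(end_date, start_date)", read="spark")
three_arg = sqlglot.parse_one("DATEDIFF(day, start_date, end_date)", read="spark")
print(type(two_arg).__name__, type(three_arg).__name__)
```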
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index fa55b51..60cf8e1 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -5,7 +5,7 @@ import typing as t
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
binary_from_function,
- format_time_lambda,
+ build_formatted_time,
is_parse_json,
pivot_column_names,
rename_func,
@@ -26,36 +26,37 @@ def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
values = expression.args.get("values")
if not keys or not values:
- return "MAP()"
+ return self.func("MAP")
- return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})"
+ return self.func("MAP_FROM_ARRAYS", keys, values)
-def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
+def _build_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]:
return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type))
def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str:
- this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Hive.DATE_FORMAT:
- return f"TO_DATE({this})"
- return f"TO_DATE({this}, {time_format})"
+ return self.func("TO_DATE", expression.this)
+ return self.func("TO_DATE", expression.this, time_format)
def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
- timestamp = self.sql(expression, "this")
+ timestamp = expression.this
+
if scale is None:
- return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)"
+ return self.sql(exp.cast(exp.func("from_unixtime", timestamp), "timestamp"))
if scale == exp.UnixToTime.SECONDS:
- return f"TIMESTAMP_SECONDS({timestamp})"
+ return self.func("TIMESTAMP_SECONDS", timestamp)
if scale == exp.UnixToTime.MILLIS:
- return f"TIMESTAMP_MILLIS({timestamp})"
+ return self.func("TIMESTAMP_MILLIS", timestamp)
if scale == exp.UnixToTime.MICROS:
- return f"TIMESTAMP_MICROS({timestamp})"
+ return self.func("TIMESTAMP_MICROS", timestamp)
- return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))"
+ unix_seconds = exp.Div(this=timestamp, expression=exp.func("POW", 10, scale))
+ return self.func("TIMESTAMP_SECONDS", unix_seconds)
def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
@@ -116,16 +117,16 @@ class Spark2(Hive):
**Hive.Parser.FUNCTIONS,
"AGGREGATE": exp.Reduce.from_arg_list,
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
- "BOOLEAN": _parse_as_cast("boolean"),
- "DATE": _parse_as_cast("date"),
+ "BOOLEAN": _build_as_cast("boolean"),
+ "DATE": _build_as_cast("date"),
"DATE_TRUNC": lambda args: exp.TimestampTrunc(
this=seq_get(args, 1), unit=exp.var(seq_get(args, 0))
),
"DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
- "DOUBLE": _parse_as_cast("double"),
- "FLOAT": _parse_as_cast("float"),
+ "DOUBLE": _build_as_cast("double"),
+ "FLOAT": _build_as_cast("float"),
"FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone(
this=exp.cast_unless(
seq_get(args, 0) or exp.Var(this=""),
@@ -134,17 +135,17 @@ class Spark2(Hive):
),
zone=seq_get(args, 1),
),
- "INT": _parse_as_cast("int"),
+ "INT": _build_as_cast("int"),
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift),
"SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
- "STRING": _parse_as_cast("string"),
- "TIMESTAMP": _parse_as_cast("timestamp"),
+ "STRING": _build_as_cast("string"),
+ "TIMESTAMP": _build_as_cast("timestamp"),
"TO_TIMESTAMP": lambda args: (
- _parse_as_cast("timestamp")(args)
+ _build_as_cast("timestamp")(args)
if len(args) == 1
- else format_time_lambda(exp.StrToTime, "spark")(args)
+ else build_formatted_time(exp.StrToTime, "spark")(args)
),
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone(
@@ -187,6 +188,7 @@ class Spark2(Hive):
class Generator(Hive.Generator):
QUERY_HINTS = True
NVL2_SUPPORTED = True
+ CAN_IMPLEMENT_ARRAY_ANY = True
PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION,
@@ -201,8 +203,9 @@ class Spark2(Hive):
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self,
e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
- exp.AtTimeZone: lambda self,
- e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+ exp.AtTimeZone: lambda self, e: self.func(
+ "FROM_UTC_TIMESTAMP", e.this, e.args.get("zone")
+ ),
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
exp.Create: preprocess(
@@ -221,8 +224,9 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
- exp.FromTimeZone: lambda self,
- e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
+ exp.FromTimeZone: lambda self, e: self.func(
+ "TO_UTC_TIMESTAMP", e.this, e.args.get("zone")
+ ),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
@@ -236,8 +240,7 @@ class Spark2(Hive):
e.args.get("position"),
),
exp.StrToDate: _str_to_date,
- exp.StrToTime: lambda self,
- e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
@@ -263,10 +266,7 @@ class Spark2(Hive):
args = []
for arg in expression.expressions:
if isinstance(arg, self.KEY_VALUE_DEFINITIONS):
- if isinstance(arg, exp.Bracket):
- args.append(exp.alias_(arg.this, arg.expressions[0].name))
- else:
- args.append(exp.alias_(arg.expression, arg.this.name))
+ args.append(exp.alias_(arg.expression, arg.this.name))
else:
args.append(arg)
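`_unix_to_time_sql` now builds its output through `self.func`/`exp.cast` instead of f-strings, with the same scale dispatch: no scale casts `FROM_UNIXTIME` to `TIMESTAMP`, the known scale constants map to `TIMESTAMP_SECONDS`/`TIMESTAMP_MILLIS`/`TIMESTAMP_MICROS`, and any other scale divides by `POW(10, scale)`. A rough sketch constructing the nodes directly (this bypasses parsing and assumes the `exp.UnixToTime` scale constants referenced above):

```python
from sqlglot import exp

for scale in (None, exp.UnixToTime.SECONDS, exp.UnixToTime.MILLIS, exp.UnixToTime.MICROS):
    node = exp.UnixToTime(this=exp.column("ts"), scale=scale)
    print(node.sql("spark2"))
```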
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 8838f34..12ac600 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -4,7 +4,7 @@ from sqlglot import exp
from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_sql,
- parse_timestamp_trunc,
+ build_timestamp_trunc,
rename_func,
)
from sqlglot.dialects.mysql import MySQL
@@ -15,7 +15,7 @@ class StarRocks(MySQL):
class Parser(MySQL.Parser):
FUNCTIONS = {
**MySQL.Parser.FUNCTIONS,
- "DATE_TRUNC": parse_timestamp_trunc,
+ "DATE_TRUNC": build_timestamp_trunc,
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
),
@@ -44,14 +44,12 @@ class StarRocks(MySQL):
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.RegexpLike: rename_func("REGEXP"),
- exp.StrToUnix: lambda self,
- e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
exp.TimestampTrunc: lambda self, e: self.func(
"DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
),
exp.TimeStrToDate: rename_func("TO_DATE"),
- exp.UnixToStr: lambda self,
- e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.UnixToStr: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index e8ff249..b736918 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -34,8 +34,8 @@ class Tableau(Dialect):
def count_sql(self, expression: exp.Count) -> str:
this = expression.this
if isinstance(this, exp.Distinct):
- return f"COUNTD({self.expressions(this, flat=True)})"
- return f"COUNT({self.sql(expression, 'this')})"
+ return self.func("COUNTD", *this.expressions)
+ return self.func("COUNT", this)
class Parser(parser.Parser):
FUNCTIONS = {
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 5b30cd4..0663a1d 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -178,6 +178,7 @@ class Teradata(Dialect):
QUERY_HINTS = False
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
+ CAN_IMPLEMENT_ARRAY_ANY = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -195,6 +196,7 @@ class Teradata(Dialect):
**generator.Generator.TRANSFORMS,
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
+ exp.ArraySize: rename_func("CARDINALITY"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Pow: lambda self, e: self.binary(e, "**"),
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 85b2e12..5955352 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import (
generatedasidentitycolumnconstraint_sql,
max_or_greatest,
min_or_least,
- parse_date_delta,
+ build_date_delta,
rename_func,
timestrtotime_sql,
trim_sql,
@@ -64,10 +64,10 @@ DEFAULT_START_DATE = datetime.date(1900, 1, 1)
BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias}
-def _format_time_lambda(
+def _build_formatted_time(
exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
) -> t.Callable[[t.List], E]:
- def _format_time(args: t.List) -> E:
+ def _builder(args: t.List) -> E:
assert len(args) == 2
return exp_class(
@@ -84,10 +84,10 @@ def _format_time_lambda(
),
)
- return _format_time
+ return _builder
-def _parse_format(args: t.List) -> exp.Expression:
+def _build_format(args: t.List) -> exp.NumberToStr | exp.TimeToStr:
this = seq_get(args, 0)
fmt = seq_get(args, 1)
culture = seq_get(args, 2)
@@ -107,7 +107,7 @@ def _parse_format(args: t.List) -> exp.Expression:
return exp.TimeToStr(this=this, format=fmt, culture=culture)
-def _parse_eomonth(args: t.List) -> exp.LastDay:
+def _build_eomonth(args: t.List) -> exp.LastDay:
date = exp.TsOrDsToDate(this=seq_get(args, 0))
month_lag = seq_get(args, 1)
@@ -120,7 +120,7 @@ def _parse_eomonth(args: t.List) -> exp.LastDay:
return exp.LastDay(this=this)
-def _parse_hashbytes(args: t.List) -> exp.Expression:
+def _build_hashbytes(args: t.List) -> exp.Expression:
kind, data = args
kind = kind.name.upper() if kind.is_string else ""
@@ -179,10 +179,10 @@ def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
return f"STRING_AGG({self.format_args(this, separator)}){order}"
-def _parse_date_delta(
+def _build_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
- def inner_func(args: t.List) -> E:
+ def _builder(args: t.List) -> E:
unit = seq_get(args, 0)
if unit and unit_mapping:
unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
@@ -204,7 +204,7 @@ def _parse_date_delta(
unit=unit,
)
- return inner_func
+ return _builder
def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
@@ -242,7 +242,7 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
# https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax
-def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts:
+def _build_datetimefromparts(args: t.List) -> exp.TimestampFromParts:
return exp.TimestampFromParts(
year=seq_get(args, 0),
month=seq_get(args, 1),
@@ -255,7 +255,7 @@ def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts:
# https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax
-def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
+def _build_timefromparts(args: t.List) -> exp.TimeFromParts:
return exp.TimeFromParts(
hour=seq_get(args, 0),
min=seq_get(args, 1),
@@ -265,7 +265,7 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
)
-def _parse_as_text(
+def _build_with_arg_as_text(
klass: t.Type[exp.Expression],
) -> t.Callable[[t.List[exp.Expression]], exp.Expression]:
def _parse(args: t.List[exp.Expression]) -> exp.Expression:
@@ -288,8 +288,8 @@ def _parse_as_text(
def _json_extract_sql(
self: TSQL.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
- json_query = rename_func("JSON_QUERY")(self, expression)
- json_value = rename_func("JSON_VALUE")(self, expression)
+ json_query = self.func("JSON_QUERY", expression.this, expression.expression)
+ json_value = self.func("JSON_VALUE", expression.this, expression.expression)
return self.func("ISNULL", json_query, json_value)
@@ -448,28 +448,28 @@ class TSQL(Dialect):
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
- "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
- "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
- "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
- "DATEPART": _format_time_lambda(exp.TimeToStr),
- "DATETIMEFROMPARTS": _parse_datetimefromparts,
- "EOMONTH": _parse_eomonth,
- "FORMAT": _parse_format,
+ "DATEADD": build_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
+ "DATEDIFF": _build_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
+ "DATENAME": _build_formatted_time(exp.TimeToStr, full_format_mapping=True),
+ "DATEPART": _build_formatted_time(exp.TimeToStr),
+ "DATETIMEFROMPARTS": _build_datetimefromparts,
+ "EOMONTH": _build_eomonth,
+ "FORMAT": _build_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
- "HASHBYTES": _parse_hashbytes,
+ "HASHBYTES": _build_hashbytes,
"ISNULL": exp.Coalesce.from_arg_list,
- "JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract),
- "JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
- "LEN": _parse_as_text(exp.Length),
- "LEFT": _parse_as_text(exp.Left),
- "RIGHT": _parse_as_text(exp.Right),
+ "JSON_QUERY": parser.build_extract_json_with_path(exp.JSONExtract),
+ "JSON_VALUE": parser.build_extract_json_with_path(exp.JSONExtractScalar),
+ "LEN": _build_with_arg_as_text(exp.Length),
+ "LEFT": _build_with_arg_as_text(exp.Left),
+ "RIGHT": _build_with_arg_as_text(exp.Right),
"REPLICATE": exp.Repeat.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
"SUSER_NAME": exp.CurrentUser.from_arg_list,
"SUSER_SNAME": exp.CurrentUser.from_arg_list,
"SYSTEM_USER": exp.CurrentUser.from_arg_list,
- "TIMEFROMPARTS": _parse_timefromparts,
+ "TIMEFROMPARTS": _build_timefromparts,
}
JOIN_HINTS = {
@@ -756,6 +756,9 @@ class TSQL(Dialect):
transforms.eliminate_qualify,
]
),
+ exp.StrPosition: lambda self, e: self.func(
+ "CHARINDEX", e.args.get("substr"), e.this, e.args.get("position")
+ ),
exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
@@ -855,7 +858,7 @@ class TSQL(Dialect):
return sql
def create_sql(self, expression: exp.Create) -> str:
- kind = self.sql(expression, "kind").upper()
+ kind = expression.kind
exists = expression.args.pop("exists", None)
sql = super().create_sql(expression)
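The reworked `_json_extract_sql` emits the same `ISNULL(JSON_QUERY(...), JSON_VALUE(...))` pair as before, just via `self.func` so the arguments are formatted consistently. A hedged sketch (identifiers illustrative; exact output depends on the build):

```python
import sqlglot

print(sqlglot.transpile("SELECT JSON_EXTRACT(doc, '$.user.name')", write="tsql")[0])
# e.g. SELECT ISNULL(JSON_QUERY(doc, '$.user.name'), JSON_VALUE(doc, '$.user.name'))
```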
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 8ef750e..1408d3c 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -67,8 +67,8 @@ class Expression(metaclass=_Expression):
Attributes:
key: a unique key for each class in the Expression hierarchy. This is useful for hashing
and representing expressions as strings.
- arg_types: determines what arguments (child nodes) are supported by an expression. It
- maps arg keys to booleans that indicate whether the corresponding args are optional.
+ arg_types: determines the arguments (child nodes) supported by an expression. It maps
+ arg keys to booleans that indicate whether the corresponding args are optional.
parent: a reference to the parent expression (or None, in case of root expressions).
arg_key: the arg key an expression is associated with, i.e. the name its parent expression
uses to refer to it.
@@ -680,7 +680,7 @@ class Expression(metaclass=_Expression):
*expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect: the dialect used to parse the input expression.
- copy: whether or not to copy the involved expressions (only applies to Expressions).
+ copy: whether to copy the involved expressions (only applies to Expressions).
opts: other options to use to parse the input expressions.
Returns:
@@ -706,7 +706,7 @@ class Expression(metaclass=_Expression):
*expressions: the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect: the dialect used to parse the input expression.
- copy: whether or not to copy the involved expressions (only applies to Expressions).
+ copy: whether to copy the involved expressions (only applies to Expressions).
opts: other options to use to parse the input expressions.
Returns:
@@ -723,7 +723,7 @@ class Expression(metaclass=_Expression):
'NOT x = 1'
Args:
- copy: whether or not to copy this object.
+ copy: whether to copy this object.
Returns:
The new Not instance.
@@ -3820,7 +3820,7 @@ class DataType(Expression):
dialect: the dialect to use for parsing `dtype`, in case it's a string.
udt: when set to True, `dtype` will be used as-is if it can't be parsed into a
DataType, thus creating a user-defined type.
- copy: whether or not to copy the data type.
+ copy: whether to copy the data type.
kwargs: additional arguments to pass in the constructor of DataType.
Returns:
@@ -4309,9 +4309,9 @@ class Func(Condition):
Attributes:
is_var_len_args (bool): if set to True the last argument defined in arg_types will be
treated as a variable length argument and the argument's value will be stored as a list.
- _sql_names (list): determines the SQL name (1st item in the list) and aliases (subsequent items)
- for this function expression. These values are used to map this node to a name during parsing
- as well as to provide the function's name during SQL string generation. By default the SQL
+ _sql_names (list): the SQL name (1st item in the list) and aliases (subsequent items) for this
+ function expression. These values are used to map this node to a name during parsing as
+ well as to provide the function's name during SQL string generation. By default the SQL
name is set to the expression's class name transformed to snake case.
"""
@@ -4449,6 +4449,7 @@ class ArrayAll(Func):
arg_types = {"this": True, "expression": True}
+# Represents Python's `any(f(x) for x in array)`, where `array` is `this` and `f` is `expression`
class ArrayAny(Func):
arg_types = {"this": True, "expression": True}
@@ -4482,6 +4483,7 @@ class ArrayOverlaps(Binary, Func):
class ArraySize(Func):
arg_types = {"this": True, "expression": False}
+ _sql_names = ["ARRAY_SIZE", "ARRAY_LENGTH"]
class ArraySort(Func):
@@ -5033,7 +5035,7 @@ class JSONBContains(Binary):
class JSONExtract(Binary, Func):
- arg_types = {"this": True, "expression": True, "expressions": False}
+ arg_types = {"this": True, "expression": True, "only_json_types": False, "expressions": False}
_sql_names = ["JSON_EXTRACT"]
is_var_len_args = True
@@ -5043,7 +5045,7 @@ class JSONExtract(Binary, Func):
class JSONExtractScalar(Binary, Func):
- arg_types = {"this": True, "expression": True, "expressions": False}
+ arg_types = {"this": True, "expression": True, "only_json_types": False, "expressions": False}
_sql_names = ["JSON_EXTRACT_SCALAR"]
is_var_len_args = True
@@ -5626,7 +5628,7 @@ def maybe_parse(
input expression is a SQL string).
prefix: a string to prefix the sql with before it gets parsed
(automatically includes a space)
- copy: whether or not to copy the expression.
+ copy: whether to copy the expression.
**opts: other options to use to parse the input expressions (again, in the case
that an input expression is a SQL string).
@@ -5897,7 +5899,7 @@ def union(
If an `Expression` instance is passed, it will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
- copy: whether or not to copy the expression.
+ copy: whether to copy the expression.
opts: other options to use to parse the input expressions.
Returns:
@@ -5931,7 +5933,7 @@ def intersect(
If an `Expression` instance is passed, it will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
- copy: whether or not to copy the expression.
+ copy: whether to copy the expression.
opts: other options to use to parse the input expressions.
Returns:
@@ -5965,7 +5967,7 @@ def except_(
If an `Expression` instance is passed, it will be used as-is.
distinct: set the DISTINCT flag if and only if this is true.
dialect: the dialect used to parse the input expression.
- copy: whether or not to copy the expression.
+ copy: whether to copy the expression.
opts: other options to use to parse the input expressions.
Returns:
@@ -6127,7 +6129,7 @@ def insert(
overwrite: whether to INSERT OVERWRITE or not.
returning: sql conditional parsed into a RETURNING statement
dialect: the dialect used to parse the input expressions.
- copy: whether or not to copy the expression.
+ copy: whether to copy the expression.
**opts: other options to use to parse the input expressions.
Returns:
@@ -6168,7 +6170,7 @@ def condition(
If an Expression instance is passed, this is used as-is.
dialect: the dialect used to parse the input expression (in the case that the
input expression is a SQL string).
- copy: Whether or not to copy `expression` (only applies to expressions).
+ copy: Whether to copy `expression` (only applies to expressions).
**opts: other options to use to parse the input expressions (again, in the case
that the input expression is a SQL string).
@@ -6198,7 +6200,7 @@ def and_(
*expressions: the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect: the dialect used to parse the input expression.
- copy: whether or not to copy `expressions` (only applies to Expressions).
+ copy: whether to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
@@ -6221,7 +6223,7 @@ def or_(
*expressions: the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect: the dialect used to parse the input expression.
- copy: whether or not to copy `expressions` (only applies to Expressions).
+ copy: whether to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.
Returns:
@@ -6296,8 +6298,8 @@ def to_identifier(name, quoted=None, copy=True):
Args:
name: The name to turn into an identifier.
- quoted: Whether or not force quote the identifier.
- copy: Whether or not to copy name if it's an Identifier.
+ quoted: Whether to force quote the identifier.
+ copy: Whether to copy name if it's an Identifier.
Returns:
The identifier ast node.
@@ -6379,7 +6381,7 @@ def to_table(
Args:
sql_path: a `[catalog].[schema].[table]` string.
dialect: the source dialect according to which the table name will be parsed.
- copy: Whether or not to copy a table if it is passed in.
+ copy: Whether to copy a table if it is passed in.
kwargs: the kwargs to instantiate the resulting `Table` expression with.
Returns:
@@ -6418,7 +6420,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column:
def alias_(
expression: ExpOrStr,
- alias: str | Identifier,
+ alias: t.Optional[str | Identifier],
table: bool | t.Sequence[str | Identifier] = False,
quoted: t.Optional[bool] = None,
dialect: DialectType = None,
@@ -6439,10 +6441,10 @@ def alias_(
If an Expression instance is passed, this is used as-is.
alias: the alias name to use. If the name has
special characters it is quoted.
- table: Whether or not to create a table alias, can also be a list of columns.
- quoted: whether or not to quote the alias
+ table: Whether to create a table alias, can also be a list of columns.
+ quoted: whether to quote the alias
dialect: the dialect used to parse the input expression.
- copy: Whether or not to copy the expression.
+ copy: Whether to copy the expression.
**opts: other options to use to parse the input expressions.
Returns:
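Note that `alias` is now Optional; when it is provided, usage is unchanged. A sketch mirroring the function's own doctests:

    >>> from sqlglot import exp
    >>> exp.alias_("foo", "bar").sql()
    'foo AS bar'
    >>> exp.alias_("(SELECT 1, 2)", "t", table=["a", "b"]).sql()
    '(SELECT 1, 2) AS t(a, b)'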
@@ -6549,7 +6551,7 @@ def column(
catalog: Catalog name.
fields: Additional fields using dots.
quoted: Whether to force quotes on the column's identifiers.
- copy: Whether or not to copy identifiers if passed in.
+ copy: Whether to copy identifiers if passed in.
Returns:
The new Column instance.
@@ -6576,7 +6578,7 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast
Args:
expression: The expression to cast.
to: The datatype to cast to.
- copy: Whether or not to copy the supplied expressions.
+ copy: Whether to copy the supplied expressions.
Returns:
The new Cast instance.
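For example, on the default dialect:

    >>> from sqlglot import exp
    >>> exp.cast("x + 1", "int").sql()
    'CAST(x + 1 AS INT)'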
@@ -6704,7 +6706,7 @@ def rename_column(
table_name: Name of the table
old_column: The old name of the column
new_column: The new name of the column
- exists: Whether or not to add the `IF EXISTS` clause
+ exists: Whether to add the `IF EXISTS` clause
Returns:
Alter table expression
@@ -6727,7 +6729,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
Args:
value: A python object.
- copy: Whether or not to copy `value` (only applies to Expressions and collections).
+ copy: Whether to copy `value` (only applies to Expressions and collections).
Returns:
Expression: the equivalent expression object.
@@ -6847,7 +6849,7 @@ def normalize_table_name(table: str | Table, dialect: DialectType = None, copy:
Args:
table: the table to normalize
dialect: the dialect to use for normalization rules
- copy: whether or not to copy the expression.
+ copy: whether to copy the expression.
Examples:
>>> normalize_table_name("`A-B`.c", dialect="bigquery")
@@ -6872,7 +6874,7 @@ def replace_tables(
expression: expression node to be transformed and replaced.
mapping: mapping of table names.
dialect: the dialect of the mapping table
- copy: whether or not to copy the expression.
+ copy: whether to copy the expression.
Examples:
>>> from sqlglot import exp, parse_one
@@ -6959,7 +6961,7 @@ def expand(
expression: The expression to expand.
sources: A dictionary of name to Subqueryables.
dialect: The dialect of the sources dict.
- copy: Whether or not to copy the expression during transformation. Defaults to True.
+ copy: Whether to copy the expression during transformation. Defaults to True.
Returns:
The transformed expression.
@@ -6993,7 +6995,7 @@ def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwa
Args:
name: the name of the function to build.
args: the args used to instantiate the function of interest.
- copy: whether or not to copy the argument expressions.
+ copy: whether to copy the argument expressions.
dialect: the source dialect.
kwargs: the kwargs used to instantiate the function of interest.
@@ -7096,7 +7098,7 @@ def array(
Args:
expressions: the expressions to add to the array.
- copy: whether or not to copy the argument expressions.
+ copy: whether to copy the argument expressions.
dialect: the source dialect.
kwargs: the kwargs used to instantiate the function of interest.
@@ -7123,7 +7125,7 @@ def tuple_(
Args:
expressions: the expressions to add to the tuple.
- copy: whether or not to copy the argument expressions.
+ copy: whether to copy the argument expressions.
dialect: the source dialect.
kwargs: the kwargs used to instantiate the function of interest.
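Both helpers accept strings or expressions; a quick sketch:

    >>> from sqlglot import exp
    >>> exp.array(1, "x").sql()
    'ARRAY(1, x)'
    >>> exp.tuple_(1, "x").sql()
    '(1, x)'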
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 4ff5a0e..4bb5005 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -38,19 +38,19 @@ class Generator(metaclass=_Generator):
Generator converts a given syntax tree to the corresponding SQL string.
Args:
- pretty: Whether or not to format the produced SQL string.
+ pretty: Whether to format the produced SQL string.
Default: False.
identify: Determines when an identifier should be quoted. Possible values are:
False (default): Never quote, except in cases where it's mandatory by the dialect.
True or 'always': Always quote.
'safe': Only quote identifiers that are case insensitive.
- normalize: Whether or not to normalize identifiers to lowercase.
+ normalize: Whether to normalize identifiers to lowercase.
Default: False.
- pad: Determines the pad size in a formatted string.
+ pad: The pad size in a formatted string.
Default: 2.
- indent: Determines the indentation size in a formatted string.
+ indent: The indentation size in a formatted string.
Default: 2.
- normalize_functions: Whether or not to normalize all function names. Possible values are:
+ normalize_functions: How to normalize function names. Possible values are:
"upper" or True (default): Convert names to uppercase.
"lower": Convert names to lowercase.
False: Disables function name normalization.
@@ -59,14 +59,14 @@ class Generator(metaclass=_Generator):
max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError.
This is only relevant if unsupported_level is ErrorLevel.RAISE.
Default: 3
- leading_comma: Determines whether or not the comma is leading or trailing in select expressions.
+ leading_comma: Whether the comma is leading or trailing in select expressions.
This is only relevant when generating in pretty mode.
Default: False
max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
The default is on the smaller end because the length only represents a segment and not the true
line length.
Default: 80
- comments: Whether or not to preserve comments in the output SQL code.
+ comments: Whether to preserve comments in the output SQL code.
Default: True
"""
@@ -97,6 +97,12 @@ class Generator(metaclass=_Generator):
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}",
exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}",
+ exp.JSONExtract: lambda self, e: self.func(
+ "JSON_EXTRACT", e.this, e.expression, *e.expressions
+ ),
+ exp.JSONExtractScalar: lambda self, e: self.func(
+ "JSON_EXTRACT_SCALAR", e.this, e.expression, *e.expressions
+ ),
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
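With these fallbacks in place, dialects without bespoke JSON generation render the canonical function form; e.g. MySQL's arrow operator transpiled to the default dialect (illustrative):

    >>> from sqlglot import parse_one
    >>> parse_one("col -> '$.a'", read="mysql").sql()
    "JSON_EXTRACT(col, '$.a')"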
@@ -134,15 +140,15 @@ class Generator(metaclass=_Generator):
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
}
- # Whether or not null ordering is supported in order by
+ # Whether null ordering is supported in order by
# True: Full Support, None: No support, False: No support in window specifications
NULL_ORDERING_SUPPORTED: t.Optional[bool] = True
- # Whether or not ignore nulls is inside the agg or outside.
+ # Whether ignore nulls is inside the agg or outside.
# FIRST(x IGNORE NULLS) OVER vs FIRST (x) IGNORE NULLS OVER
IGNORE_NULLS_IN_FUNC = False
- # Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
+ # Whether locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
LOCKING_READS_SUPPORTED = False
# Always do union distinct or union all
@@ -151,25 +157,25 @@ class Generator(metaclass=_Generator):
# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
- # Whether or not create function uses an AS before the RETURN
+ # Whether create function uses an AS before the RETURN
CREATE_FUNCTION_RETURN_AS = True
- # Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
+ # Whether MERGE ... WHEN MATCHED BY SOURCE is allowed
MATCHED_BY_SOURCE = True
- # Whether or not the INTERVAL expression works only with values like '1 day'
+ # Whether the INTERVAL expression works only with values like '1 day'
SINGLE_STRING_INTERVAL = False
- # Whether or not the plural form of date parts like day (i.e. "days") is supported in INTERVALs
+ # Whether the plural form of date parts like day (i.e. "days") is supported in INTERVALs
INTERVAL_ALLOWS_PLURAL_FORM = True
- # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
+ # Whether limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
- # Whether or not limit and fetch allows expresions or just limits
+ # Whether limit and fetch allow expressions or just literals
LIMIT_ONLY_LITERALS = False
- # Whether or not a table is allowed to be renamed with a db
+ # Whether a table is allowed to be renamed with a db
RENAME_TABLE_WITH_DB = True
# The separator for grouping sets and rollups
@@ -178,105 +184,105 @@ class Generator(metaclass=_Generator):
# The string used for creating an index on a table
INDEX_ON = "ON"
- # Whether or not join hints should be generated
+ # Whether join hints should be generated
JOIN_HINTS = True
- # Whether or not table hints should be generated
+ # Whether table hints should be generated
TABLE_HINTS = True
- # Whether or not query hints should be generated
+ # Whether query hints should be generated
QUERY_HINTS = True
# What kind of separator to use for query hints
QUERY_HINT_SEP = ", "
- # Whether or not comparing against booleans (e.g. x IS TRUE) is supported
+ # Whether comparing against booleans (e.g. x IS TRUE) is supported
IS_BOOL_ALLOWED = True
- # Whether or not to include the "SET" keyword in the "INSERT ... ON DUPLICATE KEY UPDATE" statement
+ # Whether to include the "SET" keyword in the "INSERT ... ON DUPLICATE KEY UPDATE" statement
DUPLICATE_KEY_UPDATE_WITH_SET = True
- # Whether or not to generate the limit as TOP <value> instead of LIMIT <value>
+ # Whether to generate the limit as TOP <value> instead of LIMIT <value>
LIMIT_IS_TOP = False
- # Whether or not to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ...
+ # Whether to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ...
RETURNING_END = True
- # Whether or not to generate the (+) suffix for columns used in old-style join conditions
+ # Whether to generate the (+) suffix for columns used in old-style join conditions
COLUMN_JOIN_MARKS_SUPPORTED = False
- # Whether or not to generate an unquoted value for EXTRACT's date part argument
+ # Whether to generate an unquoted value for EXTRACT's date part argument
EXTRACT_ALLOWS_QUOTES = True
- # Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
+ # Whether TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax
TZ_TO_WITH_TIME_ZONE = False
- # Whether or not the NVL2 function is supported
+ # Whether the NVL2 function is supported
NVL2_SUPPORTED = True
# https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax
SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE")
- # Whether or not VALUES statements can be used as derived tables.
+ # Whether VALUES statements can be used as derived tables.
# MySQL 5 and Redshift do not allow this, so when False, it will convert
# SELECT * VALUES into SELECT UNION
VALUES_AS_TABLE = True
- # Whether or not the word COLUMN is included when adding a column with ALTER TABLE
+ # Whether the word COLUMN is included when adding a column with ALTER TABLE
ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = True
# UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery)
UNNEST_WITH_ORDINALITY = True
- # Whether or not FILTER (WHERE cond) can be used for conditional aggregation
+ # Whether FILTER (WHERE cond) can be used for conditional aggregation
AGGREGATE_FILTER_SUPPORTED = True
- # Whether or not JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds
+ # Whether JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds
SEMI_ANTI_JOIN_WITH_SIDE = True
- # Whether or not to include the type of a computed column in the CREATE DDL
+ # Whether to include the type of a computed column in the CREATE DDL
COMPUTED_COLUMN_WITH_TYPE = True
- # Whether or not CREATE TABLE .. COPY .. is supported. False means we'll generate CLONE instead of COPY
+ # Whether CREATE TABLE .. COPY .. is supported. False means we'll generate CLONE instead of COPY
SUPPORTS_TABLE_COPY = True
- # Whether or not parentheses are required around the table sample's expression
+ # Whether parentheses are required around the table sample's expression
TABLESAMPLE_REQUIRES_PARENS = True
- # Whether or not a table sample clause's size needs to be followed by the ROWS keyword
+ # Whether a table sample clause's size needs to be followed by the ROWS keyword
TABLESAMPLE_SIZE_IS_ROWS = True
# The keyword(s) to use when generating a sample clause
TABLESAMPLE_KEYWORDS = "TABLESAMPLE"
- # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
+ # Whether the TABLESAMPLE clause supports a method name, like BERNOULLI
TABLESAMPLE_WITH_METHOD = True
# The keyword to use when specifying the seed of a sample clause
TABLESAMPLE_SEED_KEYWORD = "SEED"
- # Whether or not COLLATE is a function instead of a binary operator
+ # Whether COLLATE is a function instead of a binary operator
COLLATE_IS_FUNC = False
- # Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle)
+ # Whether data types support additional specifiers like e.g. CHAR or BYTE (oracle)
DATA_TYPE_SPECIFIERS_ALLOWED = False
- # Whether or not conditions require booleans WHERE x = 0 vs WHERE x
+ # Whether conditions require booleans WHERE x = 0 vs WHERE x
ENSURE_BOOLS = False
- # Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs
+ # Whether the "RECURSIVE" keyword is required when defining recursive CTEs
CTE_RECURSIVE_KEYWORD_REQUIRED = True
- # Whether or not CONCAT requires >1 arguments
+ # Whether CONCAT requires more than one argument
SUPPORTS_SINGLE_ARG_CONCAT = True
- # Whether or not LAST_DAY function supports a date part argument
+ # Whether LAST_DAY function supports a date part argument
LAST_DAY_SUPPORTS_DATE_PART = True
- # Whether or not named columns are allowed in table aliases
+ # Whether named columns are allowed in table aliases
SUPPORTS_TABLE_ALIAS_COLUMNS = True
- # Whether or not UNPIVOT aliases are Identifiers (False means they're Literals)
+ # Whether UNPIVOT aliases are Identifiers (False means they're Literals)
UNPIVOT_ALIASES_ARE_IDENTIFIERS = True
# What delimiter to use for separating JSON key/value pairs
@@ -285,34 +291,37 @@ class Generator(metaclass=_Generator):
# INSERT OVERWRITE TABLE x override
INSERT_OVERWRITE = " OVERWRITE TABLE"
- # Whether or not the SELECT .. INTO syntax is used instead of CTAS
+ # Whether the SELECT .. INTO syntax is used instead of CTAS
SUPPORTS_SELECT_INTO = False
- # Whether or not UNLOGGED tables can be created
+ # Whether UNLOGGED tables can be created
SUPPORTS_UNLOGGED_TABLES = False
- # Whether or not the CREATE TABLE LIKE statement is supported
+ # Whether the CREATE TABLE LIKE statement is supported
SUPPORTS_CREATE_TABLE_LIKE = True
- # Whether or not the LikeProperty needs to be specified inside of the schema clause
+ # Whether the LikeProperty needs to be specified inside the schema clause
LIKE_PROPERTY_INSIDE_SCHEMA = False
- # Whether or not DISTINCT can be followed by multiple args in an AggFunc. If not, it will be
+ # Whether DISTINCT can be followed by multiple args in an AggFunc. If not, it will be
# transpiled into a series of CASE-WHEN-ELSE, ultimately using a tuple consisting of the args
MULTI_ARG_DISTINCT = True
- # Whether or not the JSON extraction operators expect a value of type JSON
+ # Whether the JSON extraction operators expect a value of type JSON
JSON_TYPE_REQUIRED_FOR_EXTRACTION = False
- # Whether or not bracketed keys like ["foo"] are supported in JSON paths
+ # Whether bracketed keys like ["foo"] are supported in JSON paths
JSON_PATH_BRACKETED_KEY_SUPPORTED = True
- # Whether or not to escape keys using single quotes in JSON paths
+ # Whether to escape keys using single quotes in JSON paths
JSON_PATH_SINGLE_QUOTE_ESCAPE = False
# The JSONPathPart expressions supported by this dialect
SUPPORTED_JSON_PATH_PARTS = ALL_JSON_PATH_PARTS.copy()
+ # Whether any(f(x) for x in array) can be implemented by this dialect
+ CAN_IMPLEMENT_ARRAY_ANY = False
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -453,7 +462,7 @@ class Generator(metaclass=_Generator):
# Expressions that need to have all CTEs under them bubbled up to them
EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set()
- KEY_VALUE_DEFINITIONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice)
+ KEY_VALUE_DEFINITIONS = (exp.EQ, exp.PropertyEQ, exp.Slice)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
@@ -524,7 +533,7 @@ class Generator(metaclass=_Generator):
Args:
expression: The syntax tree.
- copy: Whether or not to copy the expression. The generator performs mutations so
+ copy: Whether to copy the expression. The generator performs mutations so
it is safer to copy.
Returns:
@@ -3404,6 +3413,21 @@ class Generator(metaclass=_Generator):
return self.func("LAST_DAY", expression.this)
+ def arrayany_sql(self, expression: exp.ArrayAny) -> str:
+ if self.CAN_IMPLEMENT_ARRAY_ANY:
+ filtered = exp.ArrayFilter(this=expression.this, expression=expression.expression)
+ filtered_not_empty = exp.ArraySize(this=filtered).neq(0)
+ original_is_empty = exp.ArraySize(this=expression.this).eq(0)
+ return self.sql(exp.paren(original_is_empty.or_(filtered_not_empty)))
+
+ from sqlglot.dialects import Dialect
+
+ # SQLGlot's executor supports ARRAY_ANY, so we don't want to warn for the SQLGlot dialect
+ if self.dialect.__class__ != Dialect:
+ self.unsupported("ARRAY_ANY is unsupported")
+
+ return self.function_fallback_sql(expression)
+
def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
this = expression.this
if isinstance(this, exp.JSONPathWildcard):
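A sketch of the shape this rewrite produces when CAN_IMPLEMENT_ARRAY_ANY is enabled, built with the same expression helpers the method uses (the rendered function names vary by dialect):

    >>> from sqlglot import exp
    >>> arr = exp.column("arr")
    >>> pred = exp.Lambda(this=exp.condition("x > 1"), expressions=[exp.to_identifier("x")])
    >>> filtered = exp.ArrayFilter(this=arr, expression=pred)
    >>> test = exp.ArraySize(this=arr).eq(0).or_(exp.ArraySize(this=filtered).neq(0))
    >>> exp.paren(test).sql()  # roughly:
    '(ARRAY_SIZE(arr) = 0 OR ARRAY_SIZE(FILTER(arr, x -> x > 1)) <> 0)'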
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index 6df36af..6bf877b 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -76,7 +76,7 @@ def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
Args:
expression: The expression to check if it's normalized.
- dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+ dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
"""
ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
@@ -99,7 +99,7 @@ def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int
Args:
expression: The expression to compute the normalization distance for.
- dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF).
+ dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
Returns:
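For instance (sketch):

    >>> from sqlglot import parse_one
    >>> from sqlglot.optimizer.normalize import normalized
    >>> expr = parse_one("(a AND b) OR (a AND c)")
    >>> normalized(expr)            # CNF check: ANDs nested under an OR, so not CNF
    False
    >>> normalized(expr, dnf=True)  # DNF check: no OR nested under an AND
    True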
diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py
index 8d83b47..e4f8b57 100644
--- a/sqlglot/optimizer/qualify.py
+++ b/sqlglot/optimizer/qualify.py
@@ -48,15 +48,15 @@ def qualify(
db: Default database name for tables.
catalog: Default catalog name for tables.
schema: Schema to infer column names and types.
- expand_alias_refs: Whether or not to expand references to aliases.
- expand_stars: Whether or not to expand star queries. This is a necessary step
+ expand_alias_refs: Whether to expand references to aliases.
+ expand_stars: Whether to expand star queries. This is a necessary step
for most of the optimizer's rules to work; do not set to False unless you
know what you're doing!
- infer_schema: Whether or not to infer the schema if missing.
- isolate_tables: Whether or not to isolate table selects.
- qualify_columns: Whether or not to qualify columns.
- validate_qualify_columns: Whether or not to validate columns.
- quote_identifiers: Whether or not to run the quote_identifiers step.
+ infer_schema: Whether to infer the schema if missing.
+ isolate_tables: Whether to isolate table selects.
+ qualify_columns: Whether to qualify columns.
+ validate_qualify_columns: Whether to validate columns.
+ quote_identifiers: Whether to run the quote_identifiers step.
This step is necessary to ensure correctness for case sensitive queries.
But this flag is provided in case this step is performed at a later time.
identify: If True, quote all identifiers, else only necessary ones.
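A minimal end-to-end sketch (the exact quoting in the output depends on the dialect):

    >>> from sqlglot import parse_one
    >>> from sqlglot.optimizer.qualify import qualify
    >>> qualify(parse_one("SELECT a FROM t"), schema={"t": {"a": "INT"}}).sql()
    'SELECT "t"."a" AS "a" FROM "t" AS "t"'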
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 5c27bc3..ef589c9 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -35,11 +35,11 @@ def qualify_columns(
Args:
expression: Expression to qualify.
schema: Database schema.
- expand_alias_refs: Whether or not to expand references to aliases.
- expand_stars: Whether or not to expand star queries. This is a necessary step
+ expand_alias_refs: Whether to expand references to aliases.
+ expand_stars: Whether to expand star queries. This is a necessary step
for most of the optimizer's rules to work; do not set to False unless you
know what you're doing!
- infer_schema: Whether or not to infer the schema if missing.
+ infer_schema: Whether to infer the schema if missing.
Returns:
The qualified expression.
@@ -164,12 +164,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
table = table or source_table
conditions.append(
- exp.condition(
- exp.EQ(
- this=exp.column(identifier, table=table),
- expression=exp.column(identifier, table=join_table),
- )
- )
+ exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table))
)
# Set all values in the dict to None, because we only care about the key ordering
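The fluent builder is equivalent to the removed exp.condition/exp.EQ construction:

    >>> from sqlglot import exp
    >>> exp.column("id", table="t1").eq(exp.column("id", table="t2")).sql()
    't1.id = t2.id'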
@@ -449,10 +444,9 @@ def _expand_stars(
continue
for name in columns:
+ if name in columns_to_exclude or name in coalesced_columns:
+ continue
if name in using_column_tables and table in using_column_tables[name]:
- if name in coalesced_columns:
- continue
-
coalesced_columns.add(name)
tables = using_column_tables[name]
coalesce = [exp.column(name, table=table) for table in tables]
@@ -464,7 +458,7 @@ def _expand_stars(
copy=False,
)
)
- elif name not in columns_to_exclude:
+ else:
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table=table)
new_selections.append(
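The net effect on USING joins: excluded and already-coalesced names are now skipped up front, and every other shared column still collapses into a single COALESCE projection. A hedged sketch (the exact output string may differ slightly by version):

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify import qualify

    schema = {"t1": {"id": "INT", "a": "INT"}, "t2": {"id": "INT", "b": "INT"}}
    sql = qualify(parse_one("SELECT * FROM t1 JOIN t2 USING (id)"), schema=schema).sql()
    # Expands roughly to:
    #   SELECT COALESCE("t1"."id", "t2"."id") AS "id", "t1"."a" AS "a", "t2"."b" AS "b" ...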
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 16cd548..0eae979 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -254,7 +254,7 @@ class Scope:
self._columns = []
for column in columns + external_columns:
ancestor = column.find_ancestor(
- exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
+ exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
)
if (
not ancestor
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 90357dd..9ffddb5 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -41,7 +41,7 @@ def simplify(
Args:
expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether or not the constant propagation rule should be used
+ constant_propagation: whether the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
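For example:

    >>> from sqlglot import parse_one
    >>> from sqlglot.optimizer.simplify import simplify
    >>> simplify(parse_one("TRUE AND TRUE")).sql()
    'TRUE'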
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 26f4159..b4c7475 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -248,7 +248,7 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
key.replace(exp.to_identifier("_x"))
parent_predicate = _replace(
parent_predicate,
- f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
+ f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
)
parent_select.join(
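The quoting fix matters because "_x" inside the SQL string parses as a quoted identifier, which may fail to line up with the unquoted lambda parameter _x; unquoted, the predicate round-trips cleanly (sketch):

    >>> from sqlglot import parse_one
    >>> parse_one("ARRAY_ANY(nested, _x -> _x = y)").sql()
    'ARRAY_ANY(nested, _x -> _x = y)'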
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 25c5789..4e7f870 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -18,7 +18,7 @@ if t.TYPE_CHECKING:
logger = logging.getLogger("sqlglot")
-def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
+def build_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
if len(args) == 1 and args[0].is_star:
return exp.StarMap(this=args[0])
@@ -28,13 +28,10 @@ def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
keys.append(args[i])
values.append(args[i + 1])
- return exp.VarMap(
- keys=exp.array(*keys, copy=False),
- values=exp.array(*values, copy=False),
- )
+ return exp.VarMap(keys=exp.array(*keys, copy=False), values=exp.array(*values, copy=False))
-def parse_like(args: t.List) -> exp.Escape | exp.Like:
+def build_like(args: t.List) -> exp.Escape | exp.Like:
like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0))
return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like
@@ -47,7 +44,7 @@ def binary_range_parser(
)
-def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
+def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
# Default argument order is base, expression
this = seq_get(args, 0)
expression = seq_get(args, 1)
@@ -60,8 +57,8 @@ def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)
-def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
- def _parser(args: t.List, dialect: Dialect) -> E:
+def build_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
+ def _builder(args: t.List, dialect: Dialect) -> E:
expression = expr_type(
this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
)
@@ -70,7 +67,7 @@ def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Di
return expression
- return _parser
+ return _builder
class _Parser(type):
@@ -90,8 +87,8 @@ class Parser(metaclass=_Parser):
Args:
error_level: The desired error level.
Default: ErrorLevel.IMMEDIATE
- error_message_context: Determines the amount of context to capture from a
- query string when displaying the error message (in number of characters).
+ error_message_context: The amount of context to capture from a query string when displaying
+ the error message (in number of characters).
Default: 100
max_errors: Maximum number of error messages to include in a raised ParseError.
This is only relevant if error_level is ErrorLevel.RAISE.
@@ -115,11 +112,11 @@ class Parser(metaclass=_Parser):
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)),
- "JSON_EXTRACT": parse_extract_json_with_path(exp.JSONExtract),
- "JSON_EXTRACT_SCALAR": parse_extract_json_with_path(exp.JSONExtractScalar),
- "JSON_EXTRACT_PATH_TEXT": parse_extract_json_with_path(exp.JSONExtractScalar),
- "LIKE": parse_like,
- "LOG": parse_logarithm,
+ "JSON_EXTRACT": build_extract_json_with_path(exp.JSONExtract),
+ "JSON_EXTRACT_SCALAR": build_extract_json_with_path(exp.JSONExtractScalar),
+ "JSON_EXTRACT_PATH_TEXT": build_extract_json_with_path(exp.JSONExtractScalar),
+ "LIKE": build_like,
+ "LOG": build_logarithm,
"TIME_TO_TIME_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
@@ -132,7 +129,7 @@ class Parser(metaclass=_Parser):
start=exp.Literal.number(1),
length=exp.Literal.number(10),
),
- "VAR_MAP": parse_var_map,
+ "VAR_MAP": build_var_map,
}
NO_PAREN_FUNCTIONS = {
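The renamed build_* helpers behave exactly like their parse_* predecessors; e.g. LOG keeps its base-first argument order on the default dialect:

    >>> from sqlglot import parse_one
    >>> parse_one("LOG(2, 8)").sql()
    'LOG(2, 8)'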
@@ -292,6 +289,7 @@ class Parser(metaclass=_Parser):
TokenType.VIEW,
TokenType.MODEL,
TokenType.DICTIONARY,
+ TokenType.STORAGE_INTEGRATION,
}
CREATABLES = {
@@ -550,11 +548,13 @@ class Parser(metaclass=_Parser):
exp.JSONExtract,
this=this,
expression=self.dialect.to_json_path(path),
+ only_json_types=self.JSON_ARROWS_REQUIRE_JSON_TYPE,
),
TokenType.DARROW: lambda self, this, path: self.expression(
exp.JSONExtractScalar,
this=this,
expression=self.dialect.to_json_path(path),
+ only_json_types=self.JSON_ARROWS_REQUIRE_JSON_TYPE,
),
TokenType.HASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtract,
@@ -983,28 +983,31 @@ class Parser(metaclass=_Parser):
LOG_DEFAULTS_TO_LN = False
- # Whether or not ADD is present for each column added by ALTER TABLE
+ # Whether ADD is present for each column added by ALTER TABLE
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = True
- # Whether or not the table sample clause expects CSV syntax
+ # Whether the table sample clause expects CSV syntax
TABLESAMPLE_CSV = False
- # Whether or not the SET command needs a delimiter (e.g. "=") for assignments
+ # Whether the SET command needs a delimiter (e.g. "=") for assignments
SET_REQUIRES_ASSIGNMENT_DELIMITER = True
# Whether the TRIM function expects the characters to trim as its first argument
TRIM_PATTERN_FIRST = False
- # Whether or not string aliases are supported `SELECT COUNT(*) 'count'`
+ # Whether string aliases are supported `SELECT COUNT(*) 'count'`
STRING_ALIASES = False
# Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand)
MODIFIERS_ATTACHED_TO_UNION = True
UNION_MODIFIERS = {"order", "limit", "offset"}
- # Parses no parenthesis if statements as commands
+ # Whether to parse IF statements that aren't followed by a left parenthesis as commands
NO_PAREN_IF_COMMANDS = True
+ # Whether the -> and ->> operators expect documents of type JSON (e.g. Postgres)
+ JSON_ARROWS_REQUIRE_JSON_TYPE = False
+
# Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause.
# If this is True and '(' is not found, the keyword will be treated as an identifier
VALUES_FOLLOWED_BY_PAREN = True
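Assuming Postgres sets JSON_ARROWS_REQUIRE_JSON_TYPE (as the flag's comment suggests), the value is recorded on the parsed node (sketch):

    >>> from sqlglot import parse_one
    >>> parse_one("x -> 'a'", read="postgres").args.get("only_json_types")
    True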
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index dbd0caa..36022b9 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -92,7 +92,7 @@ class Schema(abc.ABC):
normalize: t.Optional[bool] = None,
) -> bool:
"""
- Returns whether or not `column` appears in `table`'s schema.
+ Returns whether `column` appears in `table`'s schema.
Args:
table: the source table.
@@ -115,7 +115,7 @@ class Schema(abc.ABC):
@property
def empty(self) -> bool:
- """Returns whether or not the schema is empty."""
+ """Returns whether the schema is empty."""
return True
@@ -162,7 +162,7 @@ class AbstractMappingSchema:
Args:
table: the target table.
- raise_on_missing: whether or not to raise in case the schema is not found.
+ raise_on_missing: whether to raise in case the schema is not found.
Returns:
The schema of the target table.
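These behaviors hold for the concrete MappingSchema as well (sketch):

    >>> from sqlglot.schema import MappingSchema
    >>> schema = MappingSchema({"t": {"a": "INT", "b": "TEXT"}})
    >>> schema.column_names("t")
    ['a', 'b']
    >>> schema.empty
    False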
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 2cfcfa6..939ca18 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -346,6 +346,7 @@ class TokenType(AutoName):
SOME = auto()
SORT_BY = auto()
START_WITH = auto()
+ STORAGE_INTEGRATION = auto()
STRUCT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
@@ -577,7 +578,7 @@ class Tokenizer(metaclass=_Tokenizer):
STRING_ESCAPES = ["'"]
VAR_SINGLE_TOKENS: t.Set[str] = set()
- # Whether or not the heredoc tags follow the same lexical rules as unquoted identifiers
+ # Whether the heredoc tags follow the same lexical rules as unquoted identifiers
HEREDOC_TAG_IS_IDENTIFIER = False
# Token that we'll generate as a fallback if the heredoc prefix doesn't correspond to a heredoc