diff options
Diffstat (limited to 'sqlglot')
32 files changed, 705 insertions, 683 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 133979a..308b639 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -1030,7 +1030,7 @@ def posexplode_outer(col: ColumnOrName) -> Column: def get_json_object(col: ColumnOrName, path: str) -> Column: - return Column.invoke_expression_over_column(col, expression.JSONExtract, path=lit(path)) + return Column.invoke_expression_over_column(col, expression.JSONExtract, expression=lit(path)) def json_tuple(col: ColumnOrName, *fields: str) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index c0191b2..f867617 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -12,13 +12,14 @@ from sqlglot.dialects.dialect import ( binary_from_function, date_add_interval_sql, datestrtodate_sql, - format_time_lambda, + build_formatted_time, + filter_array_using_unnest, if_sql, inline_array_sql, max_or_greatest, min_or_least, no_ilike_sql, - parse_date_delta_with_interval, + build_date_delta_with_interval, regexp_replace_sql, rename_func, timestrtotime_sql, @@ -37,56 +38,33 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va if not expression.find_ancestor(exp.From, exp.Join): return self.values_sql(expression) + structs = [] alias = expression.args.get("alias") + for tup in expression.find_all(exp.Tuple): + field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions))) + expressions = [exp.alias_(fld, name) for fld, name in zip(tup.expressions, field_aliases)] + structs.append(exp.Struct(expressions=expressions)) - return self.unnest_sql( - exp.Unnest( - expressions=[ - exp.array( - *( - exp.Struct( - expressions=[ - exp.alias_(value, column_name) - for value, column_name in zip( - t.expressions, - ( - alias.columns - if alias and alias.columns - else (f"_c{i}" for i in range(len(t.expressions))) - ), - ) - ] - ) - for t in expression.find_all(exp.Tuple) - ), - copy=False, - ) - ] - ) - ) + return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)])) def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str: this = expression.this if isinstance(this, exp.Schema): - this = f"{this.this} <{self.expressions(this)}>" + this = f"{self.sql(this, 'this')} <{self.expressions(this)}>" else: this = self.sql(this) return f"RETURNS {this}" def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str: - kind = expression.args["kind"] returns = expression.find(exp.ReturnsProperty) - - if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"): + if expression.kind == "FUNCTION" and returns and returns.args.get("is_table"): expression.set("kind", "TABLE FUNCTION") if isinstance(expression.expression, (exp.Subquery, exp.Literal)): expression.set("expression", expression.expression.this) - return self.create_sql(expression) - return self.create_sql(expression) @@ -132,11 +110,10 @@ def _alias_ordered_group(expression: exp.Expression) -> exp.Expression: if isinstance(select, exp.Alias) } - for e in group.expressions: - alias = aliases.get(e) - + for grouped in group.expressions: + alias = aliases.get(grouped) if alias: - e.replace(exp.column(alias)) + grouped.replace(exp.column(alias)) return expression @@ -168,24 +145,24 @@ def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression: return expression -def _parse_parse_timestamp(args: t.List) -> exp.StrToTime: - this = format_time_lambda(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)]) +def _build_parse_timestamp(args: t.List) -> exp.StrToTime: + this = build_formatted_time(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)]) this.set("zone", seq_get(args, 2)) return this -def _parse_timestamp(args: t.List) -> exp.Timestamp: +def _build_timestamp(args: t.List) -> exp.Timestamp: timestamp = exp.Timestamp.from_arg_list(args) timestamp.set("with_tz", True) return timestamp -def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts: +def _build_date(args: t.List) -> exp.Date | exp.DateFromParts: expr_type = exp.DateFromParts if len(args) == 3 else exp.Date return expr_type.from_arg_list(args) -def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5: +def _build_to_hex(args: t.List) -> exp.Hex | exp.MD5: # TO_HEX(MD5(..)) is common in BigQuery, so it's parsed into MD5 to simplify its transpilation arg = seq_get(args, 0) return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.Hex(this=arg) @@ -214,18 +191,20 @@ def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") - timestamp = self.sql(expression, "this") + timestamp = expression.this + if scale in (None, exp.UnixToTime.SECONDS): - return f"TIMESTAMP_SECONDS({timestamp})" + return self.func("TIMESTAMP_SECONDS", timestamp) if scale == exp.UnixToTime.MILLIS: - return f"TIMESTAMP_MILLIS({timestamp})" + return self.func("TIMESTAMP_MILLIS", timestamp) if scale == exp.UnixToTime.MICROS: - return f"TIMESTAMP_MICROS({timestamp})" + return self.func("TIMESTAMP_MICROS", timestamp) - return f"TIMESTAMP_SECONDS(CAST({timestamp} / POW(10, {scale}) AS INT64))" + unix_seconds = exp.cast(exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), "int64") + return self.func("TIMESTAMP_SECONDS", unix_seconds) -def _parse_time(args: t.List) -> exp.Func: +def _build_time(args: t.List) -> exp.Func: if len(args) == 1: return exp.TsOrDsToTime(this=args[0]) if len(args) == 3: @@ -323,6 +302,7 @@ class BigQuery(Dialect): "BYTES": TokenType.BINARY, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, "DECLARE": TokenType.COMMAND, + "EXCEPTION": TokenType.COMMAND, "FLOAT64": TokenType.DOUBLE, "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, "MODEL": TokenType.MODEL, @@ -340,15 +320,15 @@ class BigQuery(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, - "DATE": _parse_date, - "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), - "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), + "DATE": _build_date, + "DATE_ADD": build_date_delta_with_interval(exp.DateAdd), + "DATE_SUB": build_date_delta_with_interval(exp.DateSub), "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(str(seq_get(args, 1))), this=seq_get(args, 0), ), - "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd), - "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub), + "DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd), + "DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub), "DIV": binary_from_function(exp.IntDiv), "FORMAT_DATE": lambda args: exp.TimeToStr( this=exp.TsOrDsToDate(this=seq_get(args, 1)), format=seq_get(args, 0) @@ -358,11 +338,11 @@ class BigQuery(Dialect): this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$") ), "MD5": exp.MD5Digest.from_arg_list, - "TO_HEX": _parse_to_hex, - "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")( + "TO_HEX": _build_to_hex, + "PARSE_DATE": lambda args: build_formatted_time(exp.StrToDate, "bigquery")( [seq_get(args, 1), seq_get(args, 0)] ), - "PARSE_TIMESTAMP": _parse_parse_timestamp, + "PARSE_TIMESTAMP": _build_parse_timestamp, "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( this=seq_get(args, 0), @@ -378,12 +358,12 @@ class BigQuery(Dialect): this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string(","), ), - "TIME": _parse_time, - "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd), - "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub), - "TIMESTAMP": _parse_timestamp, - "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd), - "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub), + "TIME": _build_time, + "TIME_ADD": build_date_delta_with_interval(exp.TimeAdd), + "TIME_SUB": build_date_delta_with_interval(exp.TimeSub), + "TIMESTAMP": _build_timestamp, + "TIMESTAMP_ADD": build_date_delta_with_interval(exp.TimestampAdd), + "TIMESTAMP_SUB": build_date_delta_with_interval(exp.TimestampSub), "TIMESTAMP_MICROS": lambda args: exp.UnixToTime( this=seq_get(args, 0), scale=exp.UnixToTime.MICROS ), @@ -424,7 +404,7 @@ class BigQuery(Dialect): } RANGE_PARSERS = parser.Parser.RANGE_PARSERS.copy() - RANGE_PARSERS.pop(TokenType.OVERLAPS, None) + RANGE_PARSERS.pop(TokenType.OVERLAPS) NULL_TOKENS = {TokenType.NULL, TokenType.UNKNOWN} @@ -551,6 +531,7 @@ class BigQuery(Dialect): NULL_ORDERING_SUPPORTED = False IGNORE_NULLS_IN_FUNC = True JSON_PATH_SINGLE_QUOTE_ESCAPE = True + CAN_IMPLEMENT_ARRAY_ANY = True TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -558,6 +539,7 @@ class BigQuery(Dialect): exp.ArgMax: arg_max_or_min_no_count("MAX_BY"), exp.ArgMin: arg_max_or_min_no_count("MIN_BY"), exp.ArrayContains: _array_contains_sql, + exp.ArrayFilter: filter_array_using_unnest, exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]), exp.CollateProperty: lambda self, e: ( @@ -565,12 +547,14 @@ class BigQuery(Dialect): if e.args.get("default") else f"COLLATE {self.sql(e, 'this')}" ), + exp.Commit: lambda *_: "COMMIT TRANSACTION", exp.CountIf: rename_func("COUNTIF"), exp.Create: _create_sql, exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), exp.DateAdd: date_add_interval_sql("DATE", "ADD"), - exp.DateDiff: lambda self, - e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", + exp.DateDiff: lambda self, e: self.func( + "DATE_DIFF", e.this, e.expression, e.unit or "DAY" + ), exp.DateFromParts: rename_func("DATE"), exp.DateStrToDate: datestrtodate_sql, exp.DateSub: date_add_interval_sql("DATE", "SUB"), @@ -602,6 +586,7 @@ class BigQuery(Dialect): exp.RegexpReplace: regexp_replace_sql, exp.RegexpLike: rename_func("REGEXP_CONTAINS"), exp.ReturnsProperty: _returnsproperty_sql, + exp.Rollback: lambda *_: "ROLLBACK TRANSACTION", exp.Select: transforms.preprocess( [ transforms.explode_to_unnest(), @@ -617,8 +602,7 @@ class BigQuery(Dialect): exp.StabilityProperty: lambda self, e: ( "DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC" ), - exp.StrToDate: lambda self, - e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})", + exp.StrToDate: lambda self, e: self.func("PARSE_DATE", self.format_time(e), e.this), exp.StrToTime: lambda self, e: self.func( "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone") ), @@ -629,6 +613,7 @@ class BigQuery(Dialect): exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"), exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, + exp.Transaction: lambda *_: "BEGIN TRANSACTION", exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsDiff: _ts_or_ds_diff_sql, @@ -778,12 +763,8 @@ class BigQuery(Dialect): } def timetostr_sql(self, expression: exp.TimeToStr) -> str: - if isinstance(expression.this, exp.TsOrDsToDate): - this: exp.Expression = expression.this - else: - this = expression - - return f"FORMAT_DATE({self.format_time(expression)}, {self.sql(this, 'this')})" + this = expression.this if isinstance(expression.this, exp.TsOrDsToDate) else expression + return self.func("FORMAT_DATE", self.format_time(expression), this.this) def struct_sql(self, expression: exp.Struct) -> str: args = [] @@ -820,11 +801,6 @@ class BigQuery(Dialect): def trycast_sql(self, expression: exp.TryCast) -> str: return self.cast_sql(expression, safe_prefix="SAFE_") - def cte_sql(self, expression: exp.CTE) -> str: - if expression.alias_column_names: - self.unsupported("Column names in CTE definition are not supported.") - return super().cte_sql(expression) - def array_sql(self, expression: exp.Array) -> str: first_arg = seq_get(expression.expressions, 0) if isinstance(first_arg, exp.Subqueryable): @@ -862,25 +838,16 @@ class BigQuery(Dialect): return f"{this}[{expressions_sql}]" - def transaction_sql(self, *_) -> str: - return "BEGIN TRANSACTION" - - def commit_sql(self, *_) -> str: - return "COMMIT TRANSACTION" - - def rollback_sql(self, *_) -> str: - return "ROLLBACK TRANSACTION" - def in_unnest_op(self, expression: exp.Unnest) -> str: return self.sql(expression) def except_op(self, expression: exp.Except) -> str: - if not expression.args.get("distinct", False): + if not expression.args.get("distinct"): self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery") return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" def intersect_op(self, expression: exp.Intersect) -> str: - if not expression.args.get("distinct", False): + if not expression.args.get("distinct"): self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery") return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index d7be64c..05d6a03 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -11,13 +11,12 @@ from sqlglot.dialects.dialect import ( json_extract_segments, json_path_key_only_name, no_pivot_sql, - parse_json_extract_path, + build_json_extract_path, rename_func, var_map_sql, ) from sqlglot.errors import ParseError from sqlglot.helper import is_int, seq_get -from sqlglot.parser import parse_var_map from sqlglot.tokens import Token, TokenType @@ -26,9 +25,9 @@ def _lower_func(sql: str) -> str: return sql[:index].lower() + sql[index:] -def _quantile_sql(self: ClickHouse.Generator, e: exp.Quantile) -> str: - quantile = e.args["quantile"] - args = f"({self.sql(e, 'this')})" +def _quantile_sql(self: ClickHouse.Generator, expression: exp.Quantile) -> str: + quantile = expression.args["quantile"] + args = f"({self.sql(expression, 'this')})" if isinstance(quantile, exp.Array): func = self.func("quantiles", *quantile) @@ -38,7 +37,7 @@ def _quantile_sql(self: ClickHouse.Generator, e: exp.Quantile) -> str: return func + args -def _parse_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc: +def _build_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc: if len(args) == 1: return exp.CountIf(this=seq_get(args, 0)) @@ -111,7 +110,7 @@ class ClickHouse(Dialect): **parser.Parser.FUNCTIONS, "ANY": exp.AnyValue.from_arg_list, "ARRAYSUM": exp.ArraySum.from_arg_list, - "COUNTIF": _parse_count_if, + "COUNTIF": _build_count_if, "DATE_ADD": lambda args: exp.DateAdd( this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), @@ -124,10 +123,10 @@ class ClickHouse(Dialect): "DATEDIFF": lambda args: exp.DateDiff( this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), - "JSONEXTRACTSTRING": parse_json_extract_path( + "JSONEXTRACTSTRING": build_json_extract_path( exp.JSONExtractScalar, zero_based_indexing=False ), - "MAP": parse_var_map, + "MAP": parser.build_var_map, "MATCH": exp.RegexpLike.from_arg_list, "RANDCANONICAL": exp.Rand.from_arg_list, "UNIQ": exp.ApproxDistinct.from_arg_list, @@ -417,9 +416,9 @@ class ClickHouse(Dialect): self, skip_join_token: bool = False, parse_bracket: bool = False ) -> t.Optional[exp.Join]: join = super()._parse_join(skip_join_token=skip_join_token, parse_bracket=True) - if join: join.set("global", join.args.pop("method", None)) + return join def _parse_function( @@ -516,6 +515,7 @@ class ClickHouse(Dialect): TABLESAMPLE_SIZE_IS_ROWS = False TABLESAMPLE_KEYWORDS = "SAMPLE" LAST_DAY_SUPPORTS_DATE_PART = False + CAN_IMPLEMENT_ARRAY_ANY = True STRING_TYPE_MAPPING = { exp.DataType.Type.CHAR: "String", @@ -576,6 +576,8 @@ class ClickHouse(Dialect): **generator.Generator.TRANSFORMS, exp.AnyValue: rename_func("any"), exp.ApproxDistinct: rename_func("uniq"), + exp.ArrayFilter: lambda self, e: self.func("arrayFilter", e.expression, e.this), + exp.ArraySize: rename_func("LENGTH"), exp.ArraySum: rename_func("arraySum"), exp.ArgMax: arg_max_or_min_no_count("argMax"), exp.ArgMin: arg_max_or_min_no_count("argMin"), @@ -597,12 +599,13 @@ class ClickHouse(Dialect): exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, - exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", + exp.RegexpLike: lambda self, e: self.func("match", e.this, e.expression), exp.Rand: rename_func("randCanonical"), exp.Select: transforms.preprocess([transforms.eliminate_qualify]), exp.StartsWith: rename_func("startsWith"), - exp.StrPosition: lambda self, - e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", + exp.StrPosition: lambda self, e: self.func( + "position", e.this, e.args.get("substr"), e.args.get("position") + ), exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions), } @@ -652,6 +655,7 @@ class ClickHouse(Dialect): this = expression.left else: return default(expression) + return prefix + self.func("has", arr.this.unnest(), this) def eq_sql(self, expression: exp.EQ) -> str: @@ -663,7 +667,7 @@ class ClickHouse(Dialect): def regexpilike_sql(self, expression: exp.RegexpILike) -> str: # Manually add a flag to make the search case-insensitive regex = self.func("CONCAT", "'(?i)'", expression.expression) - return f"match({self.format_args(expression.this, regex)})" + return self.func("match", expression.this, regex) def datatype_sql(self, expression: exp.DataType) -> str: # String is the standard ClickHouse type, every other variant is just an alias. @@ -717,8 +721,9 @@ class ClickHouse(Dialect): return f"ON CLUSTER {self.sql(expression, 'this')}" def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: - kind = self.sql(expression, "kind").upper() - if kind in self.ON_CLUSTER_TARGETS and locations.get(exp.Properties.Location.POST_NAME): + if expression.kind in self.ON_CLUSTER_TARGETS and locations.get( + exp.Properties.Location.POST_NAME + ): this_name = self.sql(expression.this, "this") this_properties = " ".join( [self.sql(prop) for prop in locations[exp.Properties.Location.POST_NAME]] diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 20907db..96eff18 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -3,13 +3,19 @@ from __future__ import annotations from sqlglot import exp, transforms from sqlglot.dialects.dialect import ( date_delta_sql, - parse_date_delta, + build_date_delta, timestamptrunc_sql, ) from sqlglot.dialects.spark import Spark from sqlglot.tokens import TokenType +def _timestamp_diff( + self: Databricks.Generator, expression: exp.DatetimeDiff | exp.TimestampDiff +) -> str: + return self.func("TIMESTAMPDIFF", expression.unit, expression.expression, expression.this) + + class Databricks(Spark): SAFE_DIVISION = False @@ -19,10 +25,10 @@ class Databricks(Spark): FUNCTIONS = { **Spark.Parser.FUNCTIONS, - "DATEADD": parse_date_delta(exp.DateAdd), - "DATE_ADD": parse_date_delta(exp.DateAdd), - "DATEDIFF": parse_date_delta(exp.DateDiff), - "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff), + "DATEADD": build_date_delta(exp.DateAdd), + "DATE_ADD": build_date_delta(exp.DateAdd), + "DATEDIFF": build_date_delta(exp.DateDiff), + "TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff), } FACTOR = { @@ -38,20 +44,16 @@ class Databricks(Spark): exp.DateAdd: date_delta_sql("DATEADD"), exp.DateDiff: date_delta_sql("DATEDIFF"), exp.DatetimeAdd: lambda self, e: self.func( - "TIMESTAMPADD", e.text("unit"), e.expression, e.this + "TIMESTAMPADD", e.unit, e.expression, e.this ), exp.DatetimeSub: lambda self, e: self.func( "TIMESTAMPADD", - e.text("unit"), + e.unit, exp.Mul(this=e.expression, expression=exp.Literal.number(-1)), e.this, ), - exp.DatetimeDiff: lambda self, e: self.func( - "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this - ), - exp.TimestampDiff: lambda self, e: self.func( - "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this - ), + exp.DatetimeDiff: _timestamp_diff, + exp.TimestampDiff: _timestamp_diff, exp.DatetimeTrunc: timestamptrunc_sql, exp.JSONExtract: lambda self, e: self.binary(e, ":"), exp.Select: transforms.preprocess( @@ -75,6 +77,7 @@ class Databricks(Spark): ): # only BIGINT generated identity constraints are supported expression.set("kind", exp.DataType.build("bigint")) + return super().columndef_sql(expression, sep) def generatedasidentitycolumnconstraint_sql( diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 0440a99..b0a78d2 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -17,12 +17,12 @@ from sqlglot.trie import new_trie DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] +JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] + if t.TYPE_CHECKING: from sqlglot._typing import B, E, F - JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] - logger = logging.getLogger("sqlglot") @@ -148,47 +148,53 @@ class _Dialect(type): class Dialect(metaclass=_Dialect): INDEX_OFFSET = 0 - """Determines the base index offset for arrays.""" + """The base index offset for arrays.""" WEEK_OFFSET = 0 - """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" + """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" UNNEST_COLUMN_ONLY = False - """Determines whether or not `UNNEST` table aliases are treated as column aliases.""" + """Whether `UNNEST` table aliases are treated as column aliases.""" ALIAS_POST_TABLESAMPLE = False - """Determines whether or not the table alias comes after tablesample.""" + """Whether the table alias comes after tablesample.""" TABLESAMPLE_SIZE_IS_PERCENT = False - """Determines whether or not a size in the table sample clause represents percentage.""" + """Whether a size in the table sample clause represents percentage.""" NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE """Specifies the strategy according to which identifiers should be normalized.""" IDENTIFIERS_CAN_START_WITH_DIGIT = False - """Determines whether or not an unquoted identifier can start with a digit.""" + """Whether an unquoted identifier can start with a digit.""" DPIPE_IS_STRING_CONCAT = True - """Determines whether or not the DPIPE token (`||`) is a string concatenation operator.""" + """Whether the DPIPE token (`||`) is a string concatenation operator.""" STRICT_STRING_CONCAT = False - """Determines whether or not `CONCAT`'s arguments must be strings.""" + """Whether `CONCAT`'s arguments must be strings.""" SUPPORTS_USER_DEFINED_TYPES = True - """Determines whether or not user-defined data types are supported.""" + """Whether user-defined data types are supported.""" SUPPORTS_SEMI_ANTI_JOIN = True - """Determines whether or not `SEMI` or `ANTI` joins are supported.""" + """Whether `SEMI` or `ANTI` joins are supported.""" NORMALIZE_FUNCTIONS: bool | str = "upper" - """Determines how function names are going to be normalized.""" + """ + Determines how function names are going to be normalized. + Possible values: + "upper" or True: Convert names to uppercase. + "lower": Convert names to lowercase. + False: Disables function name normalization. + """ LOG_BASE_FIRST = True - """Determines whether the base comes first in the `LOG` function.""" + """Whether the base comes first in the `LOG` function.""" NULL_ORDERING = "nulls_are_small" """ - Indicates the default `NULL` ordering method to use if not explicitly set. + Default `NULL` ordering method to use if not explicitly set. Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` """ @@ -200,7 +206,7 @@ class Dialect(metaclass=_Dialect): """ SAFE_DIVISION = False - """Determines whether division by zero throws an error (`False`) or returns NULL (`True`).""" + """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" CONCAT_COALESCE = False """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" @@ -210,7 +216,7 @@ class Dialect(metaclass=_Dialect): TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" TIME_MAPPING: t.Dict[str, str] = {} - """Associates this dialect's time formats with their equivalent Python `strftime` format.""" + """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE @@ -418,7 +424,7 @@ class Dialect(metaclass=_Dialect): `"safe"`: Only returns `True` if the identifier is case-insensitive. Returns: - Whether or not the given text can be identified. + Whether the given text can be identified. """ if identify is True or identify == "always": return True @@ -614,7 +620,7 @@ def var_map_sql( return self.func(map_func_name, *args) -def format_time_lambda( +def build_formatted_time( exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None ) -> t.Callable[[t.List], E]: """Helper used for time expressions. @@ -628,7 +634,7 @@ def format_time_lambda( A callable that can be used to return the appropriately formatted time expression. """ - def _format_time(args: t.List): + def _builder(args: t.List): return exp_class( this=seq_get(args, 0), format=Dialect[dialect].format_time( @@ -637,7 +643,7 @@ def format_time_lambda( ), ) - return _format_time + return _builder def time_format( @@ -654,23 +660,23 @@ def time_format( return _time_format -def parse_date_delta( +def build_date_delta( exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None ) -> t.Callable[[t.List], E]: - def inner_func(args: t.List) -> E: + def _builder(args: t.List) -> E: unit_based = len(args) == 3 this = args[2] if unit_based else seq_get(args, 0) unit = args[0] if unit_based else exp.Literal.string("DAY") unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit return exp_class(this=this, expression=seq_get(args, 1), unit=unit) - return inner_func + return _builder -def parse_date_delta_with_interval( +def build_date_delta_with_interval( expression_class: t.Type[E], ) -> t.Callable[[t.List], t.Optional[E]]: - def func(args: t.List) -> t.Optional[E]: + def _builder(args: t.List) -> t.Optional[E]: if len(args) < 2: return None @@ -687,7 +693,7 @@ def parse_date_delta_with_interval( this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) ) - return func + return _builder def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: @@ -888,7 +894,7 @@ def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: # Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects -def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: +def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) @@ -991,10 +997,10 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: return self.merge_sql(expression) -def parse_json_extract_path( +def build_json_extract_path( expr_type: t.Type[F], zero_based_indexing: bool = True ) -> t.Callable[[t.List], F]: - def _parse_json_extract_path(args: t.List) -> F: + def _builder(args: t.List) -> F: segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] for arg in args[1:]: if not isinstance(arg, exp.Literal): @@ -1014,11 +1020,11 @@ def parse_json_extract_path( del args[2:] return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments)) - return _parse_json_extract_path + return _builder def json_extract_segments( - name: str, quoted_index: bool = True + name: str, quoted_index: bool = True, op: t.Optional[str] = None ) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: path = expression.expression @@ -1036,6 +1042,8 @@ def json_extract_segments( segments.append(path) + if op: + return f" {op} ".join([self.sql(expression.this), *segments]) return self.func(name, expression.this, *segments) return _json_extract_segments @@ -1046,3 +1054,19 @@ def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str self.unsupported("Unsupported wildcard in JSONPathKey expression") return expression.name + + +def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: + cond = expression.expression + if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: + alias = cond.expressions[0] + cond = cond.this + elif isinstance(cond, exp.Predicate): + alias = "_u" + else: + self.unsupported("Unsupported filter condition") + return "" + + unnest = exp.Unnest(expressions=[expression.this]) + filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) + return self.sql(exp.Array(expressions=[filtered])) diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py index 7a18e8e..067a045 100644 --- a/sqlglot/dialects/doris.py +++ b/sqlglot/dialects/doris.py @@ -4,7 +4,7 @@ from sqlglot import exp from sqlglot.dialects.dialect import ( approx_count_distinct_sql, arrow_json_extract_sql, - parse_timestamp_trunc, + build_timestamp_trunc, rename_func, time_format, ) @@ -20,7 +20,7 @@ class Doris(MySQL): FUNCTIONS = { **MySQL.Parser.FUNCTIONS, "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list, - "DATE_TRUNC": parse_timestamp_trunc, + "DATE_TRUNC": build_timestamp_trunc, "REGEXP": exp.RegexpLike.from_arg_list, "TO_DATE": exp.TsOrDsToDate.from_arg_list, } @@ -46,7 +46,7 @@ class Doris(MySQL): exp.ArgMin: rename_func("MIN_BY"), exp.ArrayAgg: rename_func("COLLECT_LIST"), exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), - exp.CurrentTimestamp: lambda *_: "NOW()", + exp.CurrentTimestamp: lambda self, _: self.func("NOW"), exp.DateTrunc: lambda self, e: self.func( "DATE_TRUNC", e.this, "'" + e.text("unit") + "'" ), @@ -55,14 +55,11 @@ class Doris(MySQL): exp.Map: rename_func("ARRAY_MAP"), exp.RegexpLike: rename_func("REGEXP"), exp.RegexpSplit: rename_func("SPLIT_BY_STRING"), - exp.StrToUnix: lambda self, - e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)), exp.Split: rename_func("SPLIT_BY_STRING"), exp.TimeStrToDate: rename_func("TO_DATE"), - exp.ToChar: lambda self, - e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", - exp.TsOrDsAdd: lambda self, - e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level + exp.ToChar: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), + exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression), exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this), exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimestampTrunc: lambda self, e: self.func( diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 409e260..4e699f5 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -6,7 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, datestrtodate_sql, - format_time_lambda, + build_formatted_time, no_trycast_sql, rename_func, str_position_sql, @@ -19,9 +19,7 @@ def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.D def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = exp.var(expression.text("unit").upper() or "DAY") - return ( - f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" - ) + return self.func(f"DATE_{kind}", this, exp.Interval(this=expression.expression, unit=unit)) return func @@ -30,8 +28,8 @@ def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format == Drill.DATE_FORMAT: - return f"CAST({this} AS DATE)" - return f"TO_DATE({this}, {time_format})" + return self.sql(exp.cast(this, "date")) + return self.func("TO_DATE", this, time_format) class Drill(Dialect): @@ -86,9 +84,9 @@ class Drill(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, - "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"), + "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "drill"), "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, - "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), + "TO_CHAR": build_formatted_time(exp.TimeToStr, "drill"), } LOG_DEFAULTS_TO_LN = True @@ -135,8 +133,7 @@ class Drill(Dialect): e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})", exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})", - exp.ILike: lambda self, - e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", + exp.ILike: lambda self, e: self.binary(e, "`ILIKE`"), exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.RegexpLike: rename_func("REGEXP_MATCHES"), @@ -146,12 +143,11 @@ class Drill(Dialect): exp.Select: transforms.preprocess( [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] ), - exp.StrToTime: lambda self, - e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), + exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")), exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), - exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)), exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.TryCast: no_trycast_sql, diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index e61ac4f..925c5ae 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -14,7 +14,7 @@ from sqlglot.dialects.dialect import ( date_trunc_to_time, datestrtodate_sql, encode_decode_sql, - format_time_lambda, + build_formatted_time, inline_array_sql, no_comment_column_constraint_sql, no_safe_divide_sql, @@ -62,26 +62,24 @@ def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str: def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str: if expression.expression: - self.unsupported("DUCKDB ARRAY_SORT does not support a comparator") - return f"ARRAY_SORT({self.sql(expression, 'this')})" + self.unsupported("DuckDB ARRAY_SORT does not support a comparator") + return self.func("ARRAY_SORT", expression.this) def _sort_array_sql(self: DuckDB.Generator, expression: exp.SortArray) -> str: - this = self.sql(expression, "this") - if expression.args.get("asc") == exp.false(): - return f"ARRAY_REVERSE_SORT({this})" - return f"ARRAY_SORT({this})" + name = "ARRAY_REVERSE_SORT" if expression.args.get("asc") == exp.false() else "ARRAY_SORT" + return self.func(name, expression.this) -def _sort_array_reverse(args: t.List) -> exp.Expression: +def _build_sort_array_desc(args: t.List) -> exp.Expression: return exp.SortArray(this=seq_get(args, 0), asc=exp.false()) -def _parse_date_diff(args: t.List) -> exp.Expression: +def _build_date_diff(args: t.List) -> exp.Expression: return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) -def _parse_make_timestamp(args: t.List) -> exp.Expression: +def _build_make_timestamp(args: t.List) -> exp.Expression: if len(args) == 1: return exp.UnixToTime(this=seq_get(args, 0), scale=exp.UnixToTime.MICROS) @@ -103,10 +101,7 @@ def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str: value = expr.this else: key = expr.name or expr.this.name - if isinstance(expr, exp.Bracket): - value = expr.expressions[0] - else: - value = expr.expression + value = expr.expression args.append(f"{self.sql(exp.Literal.string(key))}: {self.sql(value)}") @@ -131,15 +126,16 @@ def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str: def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") - timestamp = self.sql(expression, "this") + timestamp = expression.this + if scale in (None, exp.UnixToTime.SECONDS): - return f"TO_TIMESTAMP({timestamp})" + return self.func("TO_TIMESTAMP", timestamp) if scale == exp.UnixToTime.MILLIS: - return f"EPOCH_MS({timestamp})" + return self.func("EPOCH_MS", timestamp) if scale == exp.UnixToTime.MICROS: - return f"MAKE_TIMESTAMP({timestamp})" + return self.func("MAKE_TIMESTAMP", timestamp) - return f"TO_TIMESTAMP({timestamp} / POW(10, {scale}))" + return self.func("TO_TIMESTAMP", exp.Div(this=timestamp, expression=exp.func("POW", 10, scale))) def _rename_unless_within_group( @@ -152,7 +148,7 @@ def _rename_unless_within_group( ) -def _parse_struct_pack(args: t.List) -> exp.Struct: +def _build_struct_pack(args: t.List) -> exp.Struct: args_with_columns_as_identifiers = [ exp.PropertyEQ(this=arg.this.this, expression=arg.expression) for arg in args ] @@ -220,11 +216,10 @@ class DuckDB(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAY_HAS": exp.ArrayContains.from_arg_list, - "ARRAY_LENGTH": exp.ArraySize.from_arg_list, "ARRAY_SORT": exp.SortArray.from_arg_list, - "ARRAY_REVERSE_SORT": _sort_array_reverse, - "DATEDIFF": _parse_date_diff, - "DATE_DIFF": _parse_date_diff, + "ARRAY_REVERSE_SORT": _build_sort_array_desc, + "DATEDIFF": _build_date_diff, + "DATE_DIFF": _build_date_diff, "DATE_TRUNC": date_trunc_to_time, "DATETRUNC": date_trunc_to_time, "DECODE": lambda args: exp.Decode( @@ -238,14 +233,14 @@ class DuckDB(Dialect): this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS ), "JSON": exp.ParseJSON.from_arg_list, - "JSON_EXTRACT_PATH": parser.parse_extract_json_with_path(exp.JSONExtract), - "JSON_EXTRACT_STRING": parser.parse_extract_json_with_path(exp.JSONExtractScalar), + "JSON_EXTRACT_PATH": parser.build_extract_json_with_path(exp.JSONExtract), + "JSON_EXTRACT_STRING": parser.build_extract_json_with_path(exp.JSONExtractScalar), "LIST_HAS": exp.ArrayContains.from_arg_list, - "LIST_REVERSE_SORT": _sort_array_reverse, + "LIST_REVERSE_SORT": _build_sort_array_desc, "LIST_SORT": exp.SortArray.from_arg_list, "LIST_VALUE": exp.Array.from_arg_list, "MAKE_TIME": exp.TimeFromParts.from_arg_list, - "MAKE_TIMESTAMP": _parse_make_timestamp, + "MAKE_TIMESTAMP": _build_make_timestamp, "MEDIAN": lambda args: exp.PercentileCont( this=seq_get(args, 0), expression=exp.Literal.number(0.5) ), @@ -261,12 +256,12 @@ class DuckDB(Dialect): replacement=seq_get(args, 2), modifiers=seq_get(args, 3), ), - "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"), + "STRFTIME": build_formatted_time(exp.TimeToStr, "duckdb"), "STRING_SPLIT": exp.Split.from_arg_list, "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "STRING_TO_ARRAY": exp.Split.from_arg_list, - "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"), - "STRUCT_PACK": _parse_struct_pack, + "STRPTIME": build_formatted_time(exp.StrToTime, "duckdb"), + "STRUCT_PACK": _build_struct_pack, "STR_SPLIT": exp.Split.from_arg_list, "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "TO_TIMESTAMP": exp.UnixToTime.from_arg_list, @@ -275,7 +270,7 @@ class DuckDB(Dialect): } FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() - FUNCTION_PARSERS.pop("DECODE", None) + FUNCTION_PARSERS.pop("DECODE") TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { TokenType.SEMI, @@ -334,6 +329,7 @@ class DuckDB(Dialect): JSON_PATH_BRACKETED_KEY_SUPPORTED = False SUPPORTS_CREATE_TABLE_LIKE = False MULTI_ARG_DISTINCT = False + CAN_IMPLEMENT_ARRAY_ANY = True TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -343,6 +339,7 @@ class DuckDB(Dialect): if e.expressions and e.expressions[0].find(exp.Select) else inline_array_sql(self, e) ), + exp.ArrayFilter: rename_func("LIST_FILTER"), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"), exp.ArgMin: arg_max_or_min_no_count("ARG_MIN"), @@ -350,9 +347,9 @@ class DuckDB(Dialect): exp.ArraySum: rename_func("LIST_SUM"), exp.BitwiseXor: rename_func("XOR"), exp.CommentColumnConstraint: no_comment_column_constraint_sql, - exp.CurrentDate: lambda self, e: "CURRENT_DATE", - exp.CurrentTime: lambda self, e: "CURRENT_TIME", - exp.CurrentTimestamp: lambda self, e: "CURRENT_TIMESTAMP", + exp.CurrentDate: lambda *_: "CURRENT_DATE", + exp.CurrentTime: lambda *_: "CURRENT_TIME", + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DayOfMonth: rename_func("DAYOFMONTH"), exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), @@ -409,19 +406,19 @@ class DuckDB(Dialect): exp.StrPosition: str_position_sql, exp.StrToDate: lambda self, e: f"CAST({str_to_time_sql(self, e)} AS DATE)", exp.StrToTime: str_to_time_sql, - exp.StrToUnix: lambda self, - e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.StrToUnix: lambda self, e: self.func( + "EPOCH", self.func("STRPTIME", e.this, self.format_time(e)) + ), exp.Struct: _struct_sql, exp.Timestamp: no_timestamp_sql, exp.TimestampDiff: lambda self, e: self.func( "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this ), exp.TimestampTrunc: timestamptrunc_sql, - exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")), exp.TimeStrToTime: timestrtotime_sql, - exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", - exp.TimeToStr: lambda self, - e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeStrToUnix: lambda self, e: self.func("EPOCH", exp.cast(e.this, "timestamp")), + exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.this, self.format_time(e)), exp.TimeToUnix: rename_func("EPOCH"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", @@ -432,8 +429,9 @@ class DuckDB(Dialect): exp.cast(e.expression, "TIMESTAMP"), exp.cast(e.this, "TIMESTAMP"), ), - exp.UnixToStr: lambda self, - e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", + exp.UnixToStr: lambda self, e: self.func( + "STRFTIME", self.func("TO_TIMESTAMP", e.this), self.format_time(e) + ), exp.UnixToTime: _unix_to_time_sql, exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", exp.VariancePop: rename_func("VAR_POP"), diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index b1540bb..43211dc 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import ( approx_count_distinct_sql, arg_max_or_min_no_count, datestrtodate_sql, - format_time_lambda, + build_formatted_time, if_sql, is_parse_json, left_to_substring_sql, @@ -38,7 +38,6 @@ from sqlglot.transforms import ( move_schema_columns_to_partitioned_by, ) from sqlglot.helper import seq_get -from sqlglot.parser import parse_var_map from sqlglot.tokens import TokenType # (FuncType, Multiplier) @@ -130,7 +129,7 @@ def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str: def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("Hive SORT_ARRAY does not support a comparator") - return f"SORT_ARRAY({self.sql(expression, 'this')})" + return self.func("SORT_ARRAY", expression.this) def _property_sql(self: Hive.Generator, expression: exp.Property) -> str: @@ -157,23 +156,18 @@ def _str_to_time_sql(self: Hive.Generator, expression: exp.StrToTime) -> str: return f"CAST({this} AS TIMESTAMP)" -def _time_to_str(self: Hive.Generator, expression: exp.TimeToStr) -> str: - this = self.sql(expression, "this") - time_format = self.format_time(expression) - return f"DATE_FORMAT({this}, {time_format})" - - def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str: - this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): - return f"TO_DATE({this}, {time_format})" + return self.func("TO_DATE", expression.this, time_format) + if isinstance(expression.this, exp.TsOrDsToDate): - return this - return f"TO_DATE({this})" + return self.sql(expression, "this") + + return self.func("TO_DATE", expression.this) -def _parse_ignore_nulls( +def _build_with_ignore_nulls( exp_class: t.Type[exp.Expression], ) -> t.Callable[[t.List[exp.Expression]], exp.Expression]: def _parse(args: t.List[exp.Expression]) -> exp.Expression: @@ -276,7 +270,7 @@ class Hive(Dialect): "DATE_ADD": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") ), - "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")( + "DATE_FORMAT": lambda args: build_formatted_time(exp.TimeToStr, "hive")( [ exp.TimeStrToTime(this=seq_get(args, 0)), seq_get(args, 1), @@ -292,14 +286,14 @@ class Hive(Dialect): expression=exp.TsOrDsToDate(this=seq_get(args, 1)), ), "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "FIRST": _parse_ignore_nulls(exp.First), - "FIRST_VALUE": _parse_ignore_nulls(exp.FirstValue), - "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), + "FIRST": _build_with_ignore_nulls(exp.First), + "FIRST_VALUE": _build_with_ignore_nulls(exp.FirstValue), + "FROM_UNIXTIME": build_formatted_time(exp.UnixToStr, "hive", True), "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, - "LAST": _parse_ignore_nulls(exp.Last), - "LAST_VALUE": _parse_ignore_nulls(exp.LastValue), + "LAST": _build_with_ignore_nulls(exp.Last), + "LAST_VALUE": _build_with_ignore_nulls(exp.LastValue), "LOCATE": locate_to_strposition, - "MAP": parse_var_map, + "MAP": parser.build_var_map, "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list, @@ -313,10 +307,10 @@ class Hive(Dialect): pair_delim=seq_get(args, 1) or exp.Literal.string(","), key_value_delim=seq_get(args, 2) or exp.Literal.string(":"), ), - "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), + "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "hive"), "TO_JSON": exp.JSONFormat.from_arg_list, "UNBASE64": exp.FromBase64.from_arg_list, - "UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True), + "UNIX_TIMESTAMP": build_formatted_time(exp.StrToUnix, "hive", True), "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), } @@ -487,8 +481,10 @@ class Hive(Dialect): exp.If: if_sql(), exp.ILike: no_ilike_sql, exp.IsNan: rename_func("ISNAN"), - exp.JSONExtract: rename_func("GET_JSON_OBJECT"), - exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), + exp.JSONExtract: lambda self, e: self.func("GET_JSON_OBJECT", e.this, e.expression), + exp.JSONExtractScalar: lambda self, e: self.func( + "GET_JSON_OBJECT", e.this, e.expression + ), exp.JSONFormat: _json_format_sql, exp.Left: left_to_substring_sql, exp.Map: var_map_sql, @@ -496,7 +492,7 @@ class Hive(Dialect): exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)), exp.Min: min_or_least, exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression), - exp.NotNullColumnConstraint: lambda self, e: ( + exp.NotNullColumnConstraint: lambda _, e: ( "" if e.args.get("allow_null") else "NOT NULL" ), exp.VarMap: var_map_sql, @@ -517,8 +513,9 @@ class Hive(Dialect): exp.SafeDivide: no_safe_divide_sql, exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), - exp.Split: lambda self, - e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", + exp.Split: lambda self, e: self.func( + "SPLIT", e.this, self.func("CONCAT", "'\\\\Q'", e.expression) + ), exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_time_sql, @@ -527,7 +524,7 @@ class Hive(Dialect): exp.TimeStrToDate: rename_func("TO_DATE"), exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), - exp.TimeToStr: _time_to_str, + exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.ToBase64: rename_func("BASE64"), exp.TsOrDiToDi: lambda self, @@ -549,9 +546,9 @@ class Hive(Dialect): e: f"({self.expressions(e, 'this', indent=False)})", exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})", - exp.NotForReplicationColumnConstraint: lambda self, e: "", - exp.OnProperty: lambda self, e: "", - exp.PrimaryKeyColumnConstraint: lambda self, e: "PRIMARY KEY", + exp.NotForReplicationColumnConstraint: lambda *_: "", + exp.OnProperty: lambda *_: "", + exp.PrimaryKeyColumnConstraint: lambda *_: "PRIMARY KEY", } PROPERTIES_LOCATION = { diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 97c891d..e549f62 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -9,7 +9,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_sql, date_add_interval_sql, datestrtodate_sql, - format_time_lambda, + build_formatted_time, isnull_to_is_null, locate_to_strposition, max_or_greatest, @@ -19,8 +19,8 @@ from sqlglot.dialects.dialect import ( no_pivot_sql, no_tablesample_sql, no_trycast_sql, - parse_date_delta, - parse_date_delta_with_interval, + build_date_delta, + build_date_delta_with_interval, rename_func, strposition_to_locate_sql, ) @@ -39,9 +39,6 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: expr = self.sql(expression, "this") unit = expression.text("unit").upper() - if unit == "DAY": - return f"DATE({expr})" - if unit == "WEEK": concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')" date_format = "%Y %u %w" @@ -55,10 +52,11 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: concat = f"CONCAT(YEAR({expr}), ' 1 1')" date_format = "%Y %c %e" else: - self.unsupported(f"Unexpected interval unit: {unit}") - return f"DATE({expr})" + if unit != "DAY": + self.unsupported(f"Unexpected interval unit: {unit}") + return self.func("DATE", expr) - return f"STR_TO_DATE({concat}, '{date_format}')" + return self.func("STR_TO_DATE", concat, f"'{date_format}'") # All specifiers for time parts (as opposed to date parts) @@ -93,8 +91,7 @@ def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime: def _str_to_date_sql( self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate ) -> str: - date_format = self.format_time(expression) - return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" + return self.func("STR_TO_DATE", expression.this, self.format_time(expression)) def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str: @@ -127,9 +124,7 @@ def _date_add_sql( def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str: time_format = expression.args.get("format") - if time_format: - return _str_to_date_sql(self, expression) - return f"DATE({self.sql(expression, 'this')})" + return _str_to_date_sql(self, expression) if time_format else self.func("DATE", expression.this) def _remove_ts_or_ds_to_date( @@ -289,9 +284,9 @@ class MySQL(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)), - "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), - "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), - "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), + "DATE_ADD": build_date_delta_with_interval(exp.DateAdd), + "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "mysql"), + "DATE_SUB": build_date_delta_with_interval(exp.DateSub), "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), @@ -306,7 +301,7 @@ class MySQL(Dialect): format=exp.Literal.string("%B"), ), "STR_TO_DATE": _str_to_date, - "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff), + "TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff), "TO_DAYS": lambda args: exp.paren( exp.DateDiff( this=exp.TsOrDsToDate(this=seq_get(args, 0)), diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index de693b9..fcb3aab 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -6,7 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, NormalizationStrategy, - format_time_lambda, + build_formatted_time, no_ilike_sql, rename_func, trim_sql, @@ -18,26 +18,7 @@ if t.TYPE_CHECKING: from sqlglot._typing import E -def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable: - this = self._parse_string() - - passing = None - columns = None - - if self._match_text_seq("PASSING"): - # The BY VALUE keywords are optional and are provided for semantic clarity - self._match_text_seq("BY", "VALUE") - passing = self._parse_csv(self._parse_column) - - by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF") - - if self._match_text_seq("COLUMNS"): - columns = self._parse_csv(self._parse_field_def) - - return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref) - - -def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar: +def _build_timetostr_or_tochar(args: t.List) -> exp.TimeToStr | exp.ToChar: this = seq_get(args, 0) if this and not this.type: @@ -45,7 +26,7 @@ def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar: annotate_types(this) if this.is_type(*exp.DataType.TEMPORAL_TYPES): - return format_time_lambda(exp.TimeToStr, "oracle", default=True)(args) + return build_formatted_time(exp.TimeToStr, "oracle", default=True)(args) return exp.ToChar.from_arg_list(args) @@ -93,9 +74,9 @@ class Oracle(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), - "TO_CHAR": to_char, - "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "oracle"), - "TO_DATE": format_time_lambda(exp.StrToDate, "oracle"), + "TO_CHAR": _build_timetostr_or_tochar, + "TO_TIMESTAMP": build_formatted_time(exp.StrToTime, "oracle"), + "TO_DATE": build_formatted_time(exp.StrToDate, "oracle"), } FUNCTION_PARSERS: t.Dict[str, t.Callable] = { @@ -109,7 +90,7 @@ class Oracle(Dialect): this=self._parse_format_json(self._parse_bitwise()), order=self._parse_order(), ), - "XMLTABLE": _parse_xml_table, + "XMLTABLE": lambda self: self._parse_xml_table(), } QUERY_MODIFIER_PARSERS = { @@ -127,6 +108,26 @@ class Oracle(Dialect): # Reference: https://stackoverflow.com/a/336455 DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE} + def _parse_xml_table(self) -> exp.XMLTable: + this = self._parse_string() + + passing = None + columns = None + + if self._match_text_seq("PASSING"): + # The BY VALUE keywords are optional and are provided for semantic clarity + self._match_text_seq("BY", "VALUE") + passing = self._parse_csv(self._parse_column) + + by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF") + + if self._match_text_seq("COLUMNS"): + columns = self._parse_csv(self._parse_field_def) + + return self.expression( + exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref + ) + def _parse_json_array(self, expr_type: t.Type[E], **kwargs) -> E: return self.expression( expr_type, @@ -200,18 +201,17 @@ class Oracle(Dialect): transforms.eliminate_qualify, ] ), - exp.StrToTime: lambda self, - e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), + exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)), exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), exp.Substring: rename_func("SUBSTR"), exp.Table: lambda self, e: self.table_sql(e, sep=" "), exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "), - exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)), exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, exp.UnixToTime: lambda self, - e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", + e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", } PROPERTIES_LOCATION = { diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 126261e..c78f8a3 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -6,10 +6,12 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( DATE_ADD_OR_SUB, Dialect, + JSON_EXTRACT_TYPE, any_value_to_max_sql, bool_xor_sql, datestrtodate_sql, - format_time_lambda, + build_formatted_time, + filter_array_using_unnest, json_extract_segments, json_path_key_only_name, max_or_greatest, @@ -20,8 +22,8 @@ from sqlglot.dialects.dialect import ( no_paren_current_date_sql, no_pivot_sql, no_trycast_sql, - parse_json_extract_path, - parse_timestamp_trunc, + build_json_extract_path, + build_timestamp_trunc, rename_func, str_position_sql, struct_extract_sql, @@ -163,7 +165,7 @@ def _serial_to_generated(expression: exp.Expression) -> exp.Expression: return expression -def _generate_series(args: t.List) -> exp.Expression: +def _build_generate_series(args: t.List) -> exp.GenerateSeries: # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day step = seq_get(args, 2) @@ -179,14 +181,25 @@ def _generate_series(args: t.List) -> exp.Expression: return exp.GenerateSeries.from_arg_list(args) -def _to_timestamp(args: t.List) -> exp.Expression: +def _build_to_timestamp(args: t.List) -> exp.UnixToTime | exp.StrToTime: # TO_TIMESTAMP accepts either a single double argument or (text, text) if len(args) == 1: # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE return exp.UnixToTime.from_arg_list(args) # https://www.postgresql.org/docs/current/functions-formatting.html - return format_time_lambda(exp.StrToTime, "postgres")(args) + return build_formatted_time(exp.StrToTime, "postgres")(args) + + +def _json_extract_sql( + name: str, op: str +) -> t.Callable[[Postgres.Generator, JSON_EXTRACT_TYPE], str]: + def _generate(self: Postgres.Generator, expression: JSON_EXTRACT_TYPE) -> str: + if expression.args.get("only_json_types"): + return json_extract_segments(name, quoted_index=False, op=op)(self, expression) + return json_extract_segments(name)(self, expression) + + return _generate class Postgres(Dialect): @@ -292,19 +305,19 @@ class Postgres(Dialect): **parser.Parser.PROPERTY_PARSERS, "SET": lambda self: self.expression(exp.SetConfigProperty, this=self._parse_set()), } - PROPERTY_PARSERS.pop("INPUT", None) + PROPERTY_PARSERS.pop("INPUT") FUNCTIONS = { **parser.Parser.FUNCTIONS, - "DATE_TRUNC": parse_timestamp_trunc, - "GENERATE_SERIES": _generate_series, - "JSON_EXTRACT_PATH": parse_json_extract_path(exp.JSONExtract), - "JSON_EXTRACT_PATH_TEXT": parse_json_extract_path(exp.JSONExtractScalar), + "DATE_TRUNC": build_timestamp_trunc, + "GENERATE_SERIES": _build_generate_series, + "JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract), + "JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar), "MAKE_TIME": exp.TimeFromParts.from_arg_list, "MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list, "NOW": exp.CurrentTimestamp.from_arg_list, - "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), - "TO_TIMESTAMP": _to_timestamp, + "TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"), + "TO_TIMESTAMP": _build_to_timestamp, "UNNEST": exp.Explode.from_arg_list, } @@ -338,6 +351,8 @@ class Postgres(Dialect): TokenType.END: lambda self: self._parse_commit_or_rollback(), } + JSON_ARROWS_REQUIRE_JSON_TYPE = True + def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: while True: if not self._match(TokenType.L_PAREN): @@ -387,6 +402,7 @@ class Postgres(Dialect): SUPPORTS_UNLOGGED_TABLES = True LIKE_PROPERTY_INSIDE_SCHEMA = True MULTI_ARG_DISTINCT = False + CAN_IMPLEMENT_ARRAY_ANY = True SUPPORTED_JSON_PATH_PARTS = { exp.JSONPathKey, @@ -416,6 +432,8 @@ class Postgres(Dialect): exp.ArrayContained: lambda self, e: self.binary(e, "<@"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), + exp.ArrayFilter: filter_array_using_unnest, + exp.ArraySize: lambda self, e: self.func("ARRAY_LENGTH", e.this, e.expression or "1"), exp.BitwiseXor: lambda self, e: self.binary(e, "#"), exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), exp.CurrentDate: no_paren_current_date_sql, @@ -428,8 +446,8 @@ class Postgres(Dialect): exp.DateSub: _date_add_sql("-"), exp.Explode: rename_func("UNNEST"), exp.GroupConcat: _string_agg_sql, - exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH"), - exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"), + exp.JSONExtract: _json_extract_sql("JSON_EXTRACT_PATH", "->"), + exp.JSONExtractScalar: _json_extract_sql("JSON_EXTRACT_PATH_TEXT", "->>"), exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"), exp.JSONBContains: lambda self, e: self.binary(e, "?"), @@ -462,21 +480,20 @@ class Postgres(Dialect): ] ), exp.StrPosition: str_position_sql, - exp.StrToTime: lambda self, - e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), exp.StructExtract: struct_extract_sql, exp.Substring: _substring_sql, exp.TimeFromParts: rename_func("MAKE_TIME"), exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, - exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)), exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, exp.TryCast: no_trycast_sql, exp.TsOrDsAdd: _date_add_sql("+"), exp.TsOrDsDiff: _date_diff_sql, - exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", + exp.UnixToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this), exp.VariancePop: rename_func("VAR_POP"), exp.Variance: rename_func("VAR_SAMP"), exp.Xor: bool_xor_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 1e0e7e9..8429547 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -11,7 +11,7 @@ from sqlglot.dialects.dialect import ( date_trunc_to_time, datestrtodate_sql, encode_decode_sql, - format_time_lambda, + build_formatted_time, if_sql, left_to_substring_sql, no_ilike_sql, @@ -31,12 +31,6 @@ from sqlglot.helper import apply_index_offset, seq_get from sqlglot.tokens import TokenType -def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) -> str: - accuracy = expression.args.get("accuracy") - accuracy = ", " + self.sql(accuracy) if accuracy else "" - return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" - - def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, exp.Explode): return self.sql( @@ -81,20 +75,20 @@ def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str: def _quantile_sql(self: Presto.Generator, expression: exp.Quantile) -> str: self.unsupported("Presto does not support exact quantiles") - return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})" + return self.func("APPROX_PERCENTILE", expression.this, expression.args.get("quantile")) def _str_to_time_sql( self: Presto.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate ) -> str: - return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})" + return self.func("DATE_PARSE", expression.this, self.format_time(expression)) def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT): - return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto") - return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto") + return self.sql(exp.cast(_str_to_time_sql(self, expression), "DATE")) + return self.sql(exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE")) def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: @@ -110,7 +104,7 @@ def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> st return self.func("DATE_DIFF", unit, expr, this) -def _approx_percentile(args: t.List) -> exp.Expression: +def _build_approx_percentile(args: t.List) -> exp.Expression: if len(args) == 4: return exp.ApproxQuantile( this=seq_get(args, 0), @@ -125,7 +119,7 @@ def _approx_percentile(args: t.List) -> exp.Expression: return exp.ApproxQuantile.from_arg_list(args) -def _from_unixtime(args: t.List) -> exp.Expression: +def _build_from_unixtime(args: t.List) -> exp.Expression: if len(args) == 3: return exp.UnixToTime( this=seq_get(args, 0), @@ -182,7 +176,7 @@ def _to_int(expression: exp.Expression) -> exp.Expression: return expression -def _parse_to_char(args: t.List) -> exp.TimeToStr: +def _build_to_char(args: t.List) -> exp.TimeToStr: fmt = seq_get(args, 1) if isinstance(fmt, exp.Literal): # We uppercase this to match Teradata's format mapping keys @@ -190,7 +184,7 @@ def _parse_to_char(args: t.List) -> exp.TimeToStr: # We use "teradata" on purpose here, because the time formats are different in Presto. # See https://prestodb.io/docs/current/functions/teradata.html?highlight=to_char#to_char - return format_time_lambda(exp.TimeToStr, "teradata")(args) + return build_formatted_time(exp.TimeToStr, "teradata")(args) class Presto(Dialect): @@ -231,7 +225,7 @@ class Presto(Dialect): **parser.Parser.FUNCTIONS, "ARBITRARY": exp.AnyValue.from_arg_list, "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, - "APPROX_PERCENTILE": _approx_percentile, + "APPROX_PERCENTILE": _build_approx_percentile, "BITWISE_AND": binary_from_function(exp.BitwiseAnd), "BITWISE_NOT": lambda args: exp.BitwiseNot(this=seq_get(args, 0)), "BITWISE_OR": binary_from_function(exp.BitwiseOr), @@ -244,14 +238,14 @@ class Presto(Dialect): "DATE_DIFF": lambda args: exp.DateDiff( this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), - "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), - "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), + "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "presto"), + "DATE_PARSE": build_formatted_time(exp.StrToTime, "presto"), "DATE_TRUNC": date_trunc_to_time, "ELEMENT_AT": lambda args: exp.Bracket( this=seq_get(args, 0), expressions=[seq_get(args, 1)], offset=1, safe=True ), "FROM_HEX": exp.Unhex.from_arg_list, - "FROM_UNIXTIME": _from_unixtime, + "FROM_UNIXTIME": _build_from_unixtime, "FROM_UTF8": lambda args: exp.Decode( this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8") ), @@ -271,7 +265,7 @@ class Presto(Dialect): "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2) ), - "TO_CHAR": _parse_to_char, + "TO_CHAR": _build_to_char, "TO_HEX": exp.Hex.from_arg_list, "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, "TO_UTF8": lambda args: exp.Encode( @@ -318,35 +312,35 @@ class Presto(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.AnyValue: rename_func("ARBITRARY"), - exp.ApproxDistinct: _approx_distinct_sql, + exp.ApproxDistinct: lambda self, e: self.func( + "APPROX_DISTINCT", e.this, e.args.get("accuracy") + ), exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.ArgMax: rename_func("MAX_BY"), exp.ArgMin: rename_func("MIN_BY"), exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", + exp.ArrayAny: rename_func("ANY_MATCH"), exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayContains: rename_func("CONTAINS"), exp.ArraySize: rename_func("CARDINALITY"), exp.ArrayUniqueAgg: rename_func("SET_AGG"), exp.AtTimeZone: rename_func("AT_TIMEZONE"), - exp.BitwiseAnd: lambda self, - e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.BitwiseLeftShift: lambda self, - e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})", - exp.BitwiseOr: lambda self, - e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.BitwiseRightShift: lambda self, - e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.BitwiseXor: lambda self, - e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseAnd: lambda self, e: self.func("BITWISE_AND", e.this, e.expression), + exp.BitwiseLeftShift: lambda self, e: self.func( + "BITWISE_ARITHMETIC_SHIFT_LEFT", e.this, e.expression + ), + exp.BitwiseNot: lambda self, e: self.func("BITWISE_NOT", e.this), + exp.BitwiseOr: lambda self, e: self.func("BITWISE_OR", e.this, e.expression), + exp.BitwiseRightShift: lambda self, e: self.func( + "BITWISE_ARITHMETIC_SHIFT_RIGHT", e.this, e.expression + ), + exp.BitwiseXor: lambda self, e: self.func("BITWISE_XOR", e.this, e.expression), exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]), exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: lambda self, e: self.func( "DATE_ADD", exp.Literal.string(e.text("unit") or "DAY"), - _to_int( - e.expression, - ), + _to_int(e.expression), e.this, ), exp.DateDiff: lambda self, e: self.func( @@ -407,21 +401,21 @@ class Presto(Dialect): exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", exp.StrToMap: rename_func("SPLIT_TO_MAP"), exp.StrToTime: _str_to_time_sql, - exp.StrToUnix: lambda self, - e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.StrToUnix: lambda self, e: self.func( + "TO_UNIXTIME", self.func("DATE_PARSE", e.this, self.format_time(e)) + ), exp.StructExtract: struct_extract_sql, exp.Table: transforms.preprocess([_unnest_sequence]), exp.Timestamp: no_timestamp_sql, exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: timestrtotime_sql, exp.TimeStrToTime: timestrtotime_sql, - exp.TimeStrToUnix: lambda self, - e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))", - exp.TimeToStr: lambda self, - e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeStrToUnix: lambda self, e: self.func( + "TO_UNIXTIME", self.func("DATE_PARSE", e.this, Presto.TIME_FORMAT) + ), + exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), exp.TimeToUnix: rename_func("TO_UNIXTIME"), - exp.ToChar: lambda self, - e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", + exp.ToChar: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 135ffc6..2201c78 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -21,15 +21,15 @@ if t.TYPE_CHECKING: from sqlglot._typing import E -def _parse_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]: - def _parse_delta(args: t.List) -> E: +def _build_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]: + def _builder(args: t.List) -> E: expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) if expr_type is exp.TsOrDsAdd: expr.set("return_type", exp.DataType.build("TIMESTAMP")) return expr - return _parse_delta + return _builder class Redshift(Postgres): @@ -55,10 +55,10 @@ class Redshift(Postgres): unit=exp.var("month"), return_type=exp.DataType.build("TIMESTAMP"), ), - "DATEADD": _parse_date_delta(exp.TsOrDsAdd), - "DATE_ADD": _parse_date_delta(exp.TsOrDsAdd), - "DATEDIFF": _parse_date_delta(exp.TsOrDsDiff), - "DATE_DIFF": _parse_date_delta(exp.TsOrDsDiff), + "DATEADD": _build_date_delta(exp.TsOrDsAdd), + "DATE_ADD": _build_date_delta(exp.TsOrDsAdd), + "DATEDIFF": _build_date_delta(exp.TsOrDsDiff), + "DATE_DIFF": _build_date_delta(exp.TsOrDsDiff), "GETDATE": exp.CurrentTimestamp.from_arg_list, "LISTAGG": exp.GroupConcat.from_arg_list, "STRTOL": exp.FromBase.from_arg_list, @@ -171,6 +171,7 @@ class Redshift(Postgres): TZ_TO_WITH_TIME_ZONE = True NVL2_SUPPORTED = True LAST_DAY_SUPPORTS_DATE_PART = False + CAN_IMPLEMENT_ARRAY_ANY = False TYPE_MAPPING = { **Postgres.Generator.TYPE_MAPPING, @@ -192,11 +193,12 @@ class Redshift(Postgres): ), exp.DateAdd: date_delta_sql("DATEADD"), exp.DateDiff: date_delta_sql("DATEDIFF"), - exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", + exp.DistKeyProperty: lambda self, e: self.func("DISTKEY", e.this), exp.DistStyleProperty: lambda self, e: self.naked_property(e), exp.FromBase: rename_func("STRTOL"), exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH_TEXT"), + exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"), exp.GroupConcat: rename_func("LISTAGG"), exp.ParseJSON: rename_func("JSON_PARSE"), exp.Select: transforms.preprocess( diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index b4275ea..c773e50 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import ( date_delta_sql, date_trunc_to_time, datestrtodate_sql, - format_time_lambda, + build_formatted_time, if_sql, inline_array_sql, max_or_greatest, @@ -29,12 +29,12 @@ if t.TYPE_CHECKING: # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]: +def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]: if len(args) == 2: first_arg, second_arg = args if second_arg.is_string: # case: <string_expr> [ , <format> ] - return format_time_lambda(exp.StrToTime, "snowflake")(args) + return build_formatted_time(exp.StrToTime, "snowflake")(args) return exp.UnixToTime(this=first_arg, scale=second_arg) from sqlglot.optimizer.simplify import simplify_literals @@ -52,14 +52,14 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, return exp.UnixToTime.from_arg_list(args) # case: <date_expr> - return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) + return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args) # case: <numeric_expr> return exp.UnixToTime.from_arg_list(args) -def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: - expression = parser.parse_var_map(args) +def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: + expression = parser.build_var_map(args) if isinstance(expression, exp.StarMap): return expression @@ -71,48 +71,14 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: ) -def _parse_datediff(args: t.List) -> exp.DateDiff: +def _build_datediff(args: t.List) -> exp.DateDiff: return exp.DateDiff( this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0)) ) -# https://docs.snowflake.com/en/sql-reference/functions/date_part.html -# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts -def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]: - this = self._parse_var() or self._parse_type() - - if not this: - return None - - self._match(TokenType.COMMA) - expression = self._parse_bitwise() - this = _map_date_part(this) - name = this.name.upper() - - if name.startswith("EPOCH"): - if name == "EPOCH_MILLISECOND": - scale = 10**3 - elif name == "EPOCH_MICROSECOND": - scale = 10**6 - elif name == "EPOCH_NANOSECOND": - scale = 10**9 - else: - scale = None - - ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP")) - to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts) - - if scale: - to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale)) - - return to_unix - - return self.expression(exp.Extract, this=this, expression=expression) - - # https://docs.snowflake.com/en/sql-reference/functions/div0 -def _div0_to_if(args: t.List) -> exp.If: +def _build_if_from_div0(args: t.List) -> exp.If: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) true = exp.Literal.number(0) false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) @@ -120,13 +86,13 @@ def _div0_to_if(args: t.List) -> exp.If: # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _zeroifnull_to_if(args: t.List) -> exp.If: +def _build_if_from_zeroifnull(args: t.List) -> exp.If: cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _nullifzero_to_if(args: t.List) -> exp.If: +def _build_if_from_nullifzero(args: t.List) -> exp.If: cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) @@ -150,13 +116,13 @@ def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> ) -def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]: +def _build_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]: if len(args) == 3: return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args) return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0)) -def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace: +def _build_regexp_replace(args: t.List) -> exp.RegexpReplace: regexp_replace = exp.RegexpReplace.from_arg_list(args) if not regexp_replace.args.get("replacement"): @@ -266,38 +232,7 @@ def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: return trunc -def _parse_colon_get_path( - self: parser.Parser, this: t.Optional[exp.Expression] -) -> t.Optional[exp.Expression]: - while True: - path = self._parse_bitwise() - - # The cast :: operator has a lower precedence than the extraction operator :, so - # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH - if isinstance(path, exp.Cast): - target_type = path.to - path = path.this - else: - target_type = None - - if isinstance(path, exp.Expression): - path = exp.Literal.string(path.sql(dialect="snowflake")) - - # The extraction operator : is left-associative - this = self.expression( - exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path) - ) - - if target_type: - this = exp.cast(this, target_type) - - if not self._match(TokenType.COLON): - break - - return self._parse_range(this) - - -def _parse_timestamp_from_parts(args: t.List) -> exp.Func: +def _build_timestamp_from_parts(args: t.List) -> exp.Func: if len(args) == 2: # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, # so we parse this into Anonymous for now instead of introducing complexity @@ -396,15 +331,15 @@ class Snowflake(Dialect): "BITXOR": binary_from_function(exp.BitwiseXor), "BIT_XOR": binary_from_function(exp.BitwiseXor), "BOOLXOR": binary_from_function(exp.Xor), - "CONVERT_TIMEZONE": _parse_convert_timezone, + "CONVERT_TIMEZONE": _build_convert_timezone, "DATE_TRUNC": _date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0)), ), - "DATEDIFF": _parse_datediff, - "DIV0": _div0_to_if, + "DATEDIFF": _build_datediff, + "DIV0": _build_if_from_div0, "FLATTEN": exp.Explode.from_arg_list, "GET_PATH": lambda args, dialect: exp.JSONExtract( this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) @@ -414,24 +349,24 @@ class Snowflake(Dialect): this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1)) ), "LISTAGG": exp.GroupConcat.from_arg_list, - "NULLIFZERO": _nullifzero_to_if, - "OBJECT_CONSTRUCT": _parse_object_construct, - "REGEXP_REPLACE": _parse_regexp_replace, + "NULLIFZERO": _build_if_from_nullifzero, + "OBJECT_CONSTRUCT": _build_object_construct, + "REGEXP_REPLACE": _build_regexp_replace, "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), - "TIMEDIFF": _parse_datediff, - "TIMESTAMPDIFF": _parse_datediff, - "TIMESTAMPFROMPARTS": _parse_timestamp_from_parts, - "TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts, - "TO_TIMESTAMP": _parse_to_timestamp, + "TIMEDIFF": _build_datediff, + "TIMESTAMPDIFF": _build_datediff, + "TIMESTAMPFROMPARTS": _build_timestamp_from_parts, + "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts, + "TO_TIMESTAMP": _build_to_timestamp, "TO_VARCHAR": exp.ToChar.from_arg_list, - "ZEROIFNULL": _zeroifnull_to_if, + "ZEROIFNULL": _build_if_from_zeroifnull, } FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, - "DATE_PART": _parse_date_part, + "DATE_PART": lambda self: self._parse_date_part(), "OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(), } FUNCTION_PARSERS.pop("TRIM") @@ -442,7 +377,7 @@ class Snowflake(Dialect): **parser.Parser.RANGE_PARSERS, TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny), TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny), - TokenType.COLON: _parse_colon_get_path, + TokenType.COLON: lambda self, this: self._parse_colon_get_path(this), } ALTER_PARSERS = { @@ -489,6 +424,69 @@ class Snowflake(Dialect): FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"] + def _parse_colon_get_path( + self: parser.Parser, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + while True: + path = self._parse_bitwise() + + # The cast :: operator has a lower precedence than the extraction operator :, so + # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH + if isinstance(path, exp.Cast): + target_type = path.to + path = path.this + else: + target_type = None + + if isinstance(path, exp.Expression): + path = exp.Literal.string(path.sql(dialect="snowflake")) + + # The extraction operator : is left-associative + this = self.expression( + exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path) + ) + + if target_type: + this = exp.cast(this, target_type) + + if not self._match(TokenType.COLON): + break + + return self._parse_range(this) + + # https://docs.snowflake.com/en/sql-reference/functions/date_part.html + # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts + def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]: + this = self._parse_var() or self._parse_type() + + if not this: + return None + + self._match(TokenType.COMMA) + expression = self._parse_bitwise() + this = _map_date_part(this) + name = this.name.upper() + + if name.startswith("EPOCH"): + if name == "EPOCH_MILLISECOND": + scale = 10**3 + elif name == "EPOCH_MICROSECOND": + scale = 10**6 + elif name == "EPOCH_NANOSECOND": + scale = 10**9 + else: + scale = None + + ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP")) + to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts) + + if scale: + to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale)) + + return to_unix + + return self.expression(exp.Extract, this=this, expression=expression) + def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]: if is_map: # Keys are strings in Snowflake's objects, see also: @@ -665,6 +663,7 @@ class Snowflake(Dialect): "SAMPLE": TokenType.TABLE_SAMPLE, "SQL_DOUBLE": TokenType.DOUBLE, "SQL_VARCHAR": TokenType.VARCHAR, + "STORAGE INTEGRATION": TokenType.STORAGE_INTEGRATION, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, @@ -724,8 +723,10 @@ class Snowflake(Dialect): ), exp.GroupConcat: rename_func("LISTAGG"), exp.If: if_sql(name="IFF", false_value="NULL"), - exp.JSONExtract: rename_func("GET_PATH"), - exp.JSONExtractScalar: rename_func("JSON_EXTRACT_PATH_TEXT"), + exp.JSONExtract: lambda self, e: self.func("GET_PATH", e.this, e.expression), + exp.JSONExtractScalar: lambda self, e: self.func( + "JSON_EXTRACT_PATH_TEXT", e.this, e.expression + ), exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions), exp.JSONPathRoot: lambda *_: "", exp.LogicalAnd: rename_func("BOOLAND_AGG"), @@ -756,8 +757,7 @@ class Snowflake(Dialect): exp.StrPosition: lambda self, e: self.func( "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), - exp.StrToTime: lambda self, - e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), exp.Struct: lambda self, e: self.func( "OBJECT_CONSTRUCT", *(arg for expression in e.expressions for arg in expression.flatten()), @@ -901,12 +901,12 @@ class Snowflake(Dialect): ) def except_op(self, expression: exp.Except) -> str: - if not expression.args.get("distinct", False): + if not expression.args.get("distinct"): self.unsupported("EXCEPT with All is not supported in Snowflake") return super().except_op(expression) def intersect_op(self, expression: exp.Intersect) -> str: - if not expression.args.get("distinct", False): + if not expression.args.get("distinct"): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index c662ab5..20c0fce 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -4,7 +4,7 @@ import typing as t from sqlglot import exp from sqlglot.dialects.dialect import rename_func -from sqlglot.dialects.hive import _parse_ignore_nulls +from sqlglot.dialects.hive import _build_with_ignore_nulls from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider from sqlglot.helper import seq_get from sqlglot.transforms import ( @@ -15,7 +15,7 @@ from sqlglot.transforms import ( ) -def _parse_datediff(args: t.List) -> exp.Expression: +def _build_datediff(args: t.List) -> exp.Expression: """ Although Spark docs don't mention the "unit" argument, Spark3 added support for it at some point. Databricks also supports this variant (see below). @@ -61,8 +61,8 @@ class Spark(Spark2): class Parser(Spark2.Parser): FUNCTIONS = { **Spark2.Parser.FUNCTIONS, - "ANY_VALUE": _parse_ignore_nulls(exp.AnyValue), - "DATEDIFF": _parse_datediff, + "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue), + "DATEDIFF": _build_datediff, } def _parse_generated_as_identity( diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index fa55b51..60cf8e1 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -5,7 +5,7 @@ import typing as t from sqlglot import exp, transforms from sqlglot.dialects.dialect import ( binary_from_function, - format_time_lambda, + build_formatted_time, is_parse_json, pivot_column_names, rename_func, @@ -26,36 +26,37 @@ def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str: values = expression.args.get("values") if not keys or not values: - return "MAP()" + return self.func("MAP") - return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})" + return self.func("MAP_FROM_ARRAYS", keys, values) -def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]: +def _build_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]: return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type)) def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str: - this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format == Hive.DATE_FORMAT: - return f"TO_DATE({this})" - return f"TO_DATE({this}, {time_format})" + return self.func("TO_DATE", expression.this) + return self.func("TO_DATE", expression.this, time_format) def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") - timestamp = self.sql(expression, "this") + timestamp = expression.this + if scale is None: - return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)" + return self.sql(exp.cast(exp.func("from_unixtime", timestamp), "timestamp")) if scale == exp.UnixToTime.SECONDS: - return f"TIMESTAMP_SECONDS({timestamp})" + return self.func("TIMESTAMP_SECONDS", timestamp) if scale == exp.UnixToTime.MILLIS: - return f"TIMESTAMP_MILLIS({timestamp})" + return self.func("TIMESTAMP_MILLIS", timestamp) if scale == exp.UnixToTime.MICROS: - return f"TIMESTAMP_MICROS({timestamp})" + return self.func("TIMESTAMP_MICROS", timestamp) - return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))" + unix_seconds = exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)) + return self.func("TIMESTAMP_SECONDS", unix_seconds) def _unalias_pivot(expression: exp.Expression) -> exp.Expression: @@ -116,16 +117,16 @@ class Spark2(Hive): **Hive.Parser.FUNCTIONS, "AGGREGATE": exp.Reduce.from_arg_list, "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, - "BOOLEAN": _parse_as_cast("boolean"), - "DATE": _parse_as_cast("date"), + "BOOLEAN": _build_as_cast("boolean"), + "DATE": _build_as_cast("date"), "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=exp.var(seq_get(args, 0)) ), "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DOUBLE": _parse_as_cast("double"), - "FLOAT": _parse_as_cast("float"), + "DOUBLE": _build_as_cast("double"), + "FLOAT": _build_as_cast("float"), "FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone( this=exp.cast_unless( seq_get(args, 0) or exp.Var(this=""), @@ -134,17 +135,17 @@ class Spark2(Hive): ), zone=seq_get(args, 1), ), - "INT": _parse_as_cast("int"), + "INT": _build_as_cast("int"), "MAP_FROM_ARRAYS": exp.Map.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift), "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift), - "STRING": _parse_as_cast("string"), - "TIMESTAMP": _parse_as_cast("timestamp"), + "STRING": _build_as_cast("string"), + "TIMESTAMP": _build_as_cast("timestamp"), "TO_TIMESTAMP": lambda args: ( - _parse_as_cast("timestamp")(args) + _build_as_cast("timestamp")(args) if len(args) == 1 - else format_time_lambda(exp.StrToTime, "spark")(args) + else build_formatted_time(exp.StrToTime, "spark")(args) ), "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, "TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone( @@ -187,6 +188,7 @@ class Spark2(Hive): class Generator(Hive.Generator): QUERY_HINTS = True NVL2_SUPPORTED = True + CAN_IMPLEMENT_ARRAY_ANY = True PROPERTIES_LOCATION = { **Hive.Generator.PROPERTIES_LOCATION, @@ -201,8 +203,9 @@ class Spark2(Hive): exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", - exp.AtTimeZone: lambda self, - e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", + exp.AtTimeZone: lambda self, e: self.func( + "FROM_UTC_TIMESTAMP", e.this, e.args.get("zone") + ), exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), exp.Create: preprocess( @@ -221,8 +224,9 @@ class Spark2(Hive): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.From: transforms.preprocess([_unalias_pivot]), - exp.FromTimeZone: lambda self, - e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", + exp.FromTimeZone: lambda self, e: self.func( + "TO_UTC_TIMESTAMP", e.this, e.args.get("zone") + ), exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, @@ -236,8 +240,7 @@ class Spark2(Hive): e.args.get("position"), ), exp.StrToDate: _str_to_date, - exp.StrToTime: lambda self, - e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), exp.TimestampTrunc: lambda self, e: self.func( "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this ), @@ -263,10 +266,7 @@ class Spark2(Hive): args = [] for arg in expression.expressions: if isinstance(arg, self.KEY_VALUE_DEFINITIONS): - if isinstance(arg, exp.Bracket): - args.append(exp.alias_(arg.this, arg.expressions[0].name)) - else: - args.append(exp.alias_(arg.expression, arg.this.name)) + args.append(exp.alias_(arg.expression, arg.this.name)) else: args.append(arg) diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 8838f34..12ac600 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -4,7 +4,7 @@ from sqlglot import exp from sqlglot.dialects.dialect import ( approx_count_distinct_sql, arrow_json_extract_sql, - parse_timestamp_trunc, + build_timestamp_trunc, rename_func, ) from sqlglot.dialects.mysql import MySQL @@ -15,7 +15,7 @@ class StarRocks(MySQL): class Parser(MySQL.Parser): FUNCTIONS = { **MySQL.Parser.FUNCTIONS, - "DATE_TRUNC": parse_timestamp_trunc, + "DATE_TRUNC": build_timestamp_trunc, "DATEDIFF": lambda args: exp.DateDiff( this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") ), @@ -44,14 +44,12 @@ class StarRocks(MySQL): exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, exp.RegexpLike: rename_func("REGEXP"), - exp.StrToUnix: lambda self, - e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)), exp.TimestampTrunc: lambda self, e: self.func( "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this ), exp.TimeStrToDate: rename_func("TO_DATE"), - exp.UnixToStr: lambda self, - e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToStr: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)), exp.UnixToTime: rename_func("FROM_UNIXTIME"), } diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index e8ff249..b736918 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -34,8 +34,8 @@ class Tableau(Dialect): def count_sql(self, expression: exp.Count) -> str: this = expression.this if isinstance(this, exp.Distinct): - return f"COUNTD({self.expressions(this, flat=True)})" - return f"COUNT({self.sql(expression, 'this')})" + return self.func("COUNTD", *this.expressions) + return self.func("COUNT", this) class Parser(parser.Parser): FUNCTIONS = { diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 5b30cd4..0663a1d 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -178,6 +178,7 @@ class Teradata(Dialect): QUERY_HINTS = False TABLESAMPLE_KEYWORDS = "SAMPLE" LAST_DAY_SUPPORTS_DATE_PART = False + CAN_IMPLEMENT_ARRAY_ANY = True TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -195,6 +196,7 @@ class Teradata(Dialect): **generator.Generator.TRANSFORMS, exp.ArgMax: rename_func("MAX_BY"), exp.ArgMin: rename_func("MIN_BY"), + exp.ArraySize: rename_func("CARDINALITY"), exp.Max: max_or_greatest, exp.Min: min_or_least, exp.Pow: lambda self, e: self.binary(e, "**"), diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 85b2e12..5955352 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import ( generatedasidentitycolumnconstraint_sql, max_or_greatest, min_or_least, - parse_date_delta, + build_date_delta, rename_func, timestrtotime_sql, trim_sql, @@ -64,10 +64,10 @@ DEFAULT_START_DATE = datetime.date(1900, 1, 1) BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias} -def _format_time_lambda( +def _build_formatted_time( exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None ) -> t.Callable[[t.List], E]: - def _format_time(args: t.List) -> E: + def _builder(args: t.List) -> E: assert len(args) == 2 return exp_class( @@ -84,10 +84,10 @@ def _format_time_lambda( ), ) - return _format_time + return _builder -def _parse_format(args: t.List) -> exp.Expression: +def _build_format(args: t.List) -> exp.NumberToStr | exp.TimeToStr: this = seq_get(args, 0) fmt = seq_get(args, 1) culture = seq_get(args, 2) @@ -107,7 +107,7 @@ def _parse_format(args: t.List) -> exp.Expression: return exp.TimeToStr(this=this, format=fmt, culture=culture) -def _parse_eomonth(args: t.List) -> exp.LastDay: +def _build_eomonth(args: t.List) -> exp.LastDay: date = exp.TsOrDsToDate(this=seq_get(args, 0)) month_lag = seq_get(args, 1) @@ -120,7 +120,7 @@ def _parse_eomonth(args: t.List) -> exp.LastDay: return exp.LastDay(this=this) -def _parse_hashbytes(args: t.List) -> exp.Expression: +def _build_hashbytes(args: t.List) -> exp.Expression: kind, data = args kind = kind.name.upper() if kind.is_string else "" @@ -179,10 +179,10 @@ def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: return f"STRING_AGG({self.format_args(this, separator)}){order}" -def _parse_date_delta( +def _build_date_delta( exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None ) -> t.Callable[[t.List], E]: - def inner_func(args: t.List) -> E: + def _builder(args: t.List) -> E: unit = seq_get(args, 0) if unit and unit_mapping: unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) @@ -204,7 +204,7 @@ def _parse_date_delta( unit=unit, ) - return inner_func + return _builder def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: @@ -242,7 +242,7 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: # https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax -def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts: +def _build_datetimefromparts(args: t.List) -> exp.TimestampFromParts: return exp.TimestampFromParts( year=seq_get(args, 0), month=seq_get(args, 1), @@ -255,7 +255,7 @@ def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts: # https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax -def _parse_timefromparts(args: t.List) -> exp.TimeFromParts: +def _build_timefromparts(args: t.List) -> exp.TimeFromParts: return exp.TimeFromParts( hour=seq_get(args, 0), min=seq_get(args, 1), @@ -265,7 +265,7 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts: ) -def _parse_as_text( +def _build_with_arg_as_text( klass: t.Type[exp.Expression], ) -> t.Callable[[t.List[exp.Expression]], exp.Expression]: def _parse(args: t.List[exp.Expression]) -> exp.Expression: @@ -288,8 +288,8 @@ def _parse_as_text( def _json_extract_sql( self: TSQL.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar ) -> str: - json_query = rename_func("JSON_QUERY")(self, expression) - json_value = rename_func("JSON_VALUE")(self, expression) + json_query = self.func("JSON_QUERY", expression.this, expression.expression) + json_value = self.func("JSON_VALUE", expression.this, expression.expression) return self.func("ISNULL", json_query, json_value) @@ -448,28 +448,28 @@ class TSQL(Dialect): substr=seq_get(args, 0), position=seq_get(args, 2), ), - "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), - "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), - "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), - "DATEPART": _format_time_lambda(exp.TimeToStr), - "DATETIMEFROMPARTS": _parse_datetimefromparts, - "EOMONTH": _parse_eomonth, - "FORMAT": _parse_format, + "DATEADD": build_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), + "DATEDIFF": _build_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), + "DATENAME": _build_formatted_time(exp.TimeToStr, full_format_mapping=True), + "DATEPART": _build_formatted_time(exp.TimeToStr), + "DATETIMEFROMPARTS": _build_datetimefromparts, + "EOMONTH": _build_eomonth, + "FORMAT": _build_format, "GETDATE": exp.CurrentTimestamp.from_arg_list, - "HASHBYTES": _parse_hashbytes, + "HASHBYTES": _build_hashbytes, "ISNULL": exp.Coalesce.from_arg_list, - "JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract), - "JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar), - "LEN": _parse_as_text(exp.Length), - "LEFT": _parse_as_text(exp.Left), - "RIGHT": _parse_as_text(exp.Right), + "JSON_QUERY": parser.build_extract_json_with_path(exp.JSONExtract), + "JSON_VALUE": parser.build_extract_json_with_path(exp.JSONExtractScalar), + "LEN": _build_with_arg_as_text(exp.Length), + "LEFT": _build_with_arg_as_text(exp.Left), + "RIGHT": _build_with_arg_as_text(exp.Right), "REPLICATE": exp.Repeat.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, "SUSER_NAME": exp.CurrentUser.from_arg_list, "SUSER_SNAME": exp.CurrentUser.from_arg_list, "SYSTEM_USER": exp.CurrentUser.from_arg_list, - "TIMEFROMPARTS": _parse_timefromparts, + "TIMEFROMPARTS": _build_timefromparts, } JOIN_HINTS = { @@ -756,6 +756,9 @@ class TSQL(Dialect): transforms.eliminate_qualify, ] ), + exp.StrPosition: lambda self, e: self.func( + "CHARINDEX", e.args.get("substr"), e.this, e.args.get("position") + ), exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( @@ -855,7 +858,7 @@ class TSQL(Dialect): return sql def create_sql(self, expression: exp.Create) -> str: - kind = self.sql(expression, "kind").upper() + kind = expression.kind exists = expression.args.pop("exists", None) sql = super().create_sql(expression) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 8ef750e..1408d3c 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -67,8 +67,8 @@ class Expression(metaclass=_Expression): Attributes: key: a unique key for each class in the Expression hierarchy. This is useful for hashing and representing expressions as strings. - arg_types: determines what arguments (child nodes) are supported by an expression. It - maps arg keys to booleans that indicate whether the corresponding args are optional. + arg_types: determines the arguments (child nodes) supported by an expression. It maps + arg keys to booleans that indicate whether the corresponding args are optional. parent: a reference to the parent expression (or None, in case of root expressions). arg_key: the arg key an expression is associated with, i.e. the name its parent expression uses to refer to it. @@ -680,7 +680,7 @@ class Expression(metaclass=_Expression): *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. dialect: the dialect used to parse the input expression. - copy: whether or not to copy the involved expressions (only applies to Expressions). + copy: whether to copy the involved expressions (only applies to Expressions). opts: other options to use to parse the input expressions. Returns: @@ -706,7 +706,7 @@ class Expression(metaclass=_Expression): *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. dialect: the dialect used to parse the input expression. - copy: whether or not to copy the involved expressions (only applies to Expressions). + copy: whether to copy the involved expressions (only applies to Expressions). opts: other options to use to parse the input expressions. Returns: @@ -723,7 +723,7 @@ class Expression(metaclass=_Expression): 'NOT x = 1' Args: - copy: whether or not to copy this object. + copy: whether to copy this object. Returns: The new Not instance. @@ -3820,7 +3820,7 @@ class DataType(Expression): dialect: the dialect to use for parsing `dtype`, in case it's a string. udt: when set to True, `dtype` will be used as-is if it can't be parsed into a DataType, thus creating a user-defined type. - copy: whether or not to copy the data type. + copy: whether to copy the data type. kwargs: additional arguments to pass in the constructor of DataType. Returns: @@ -4309,9 +4309,9 @@ class Func(Condition): Attributes: is_var_len_args (bool): if set to True the last argument defined in arg_types will be treated as a variable length argument and the argument's value will be stored as a list. - _sql_names (list): determines the SQL name (1st item in the list) and aliases (subsequent items) - for this function expression. These values are used to map this node to a name during parsing - as well as to provide the function's name during SQL string generation. By default the SQL + _sql_names (list): the SQL name (1st item in the list) and aliases (subsequent items) for this + function expression. These values are used to map this node to a name during parsing as + well as to provide the function's name during SQL string generation. By default the SQL name is set to the expression's class name transformed to snake case. """ @@ -4449,6 +4449,7 @@ class ArrayAll(Func): arg_types = {"this": True, "expression": True} +# Represents Python's `any(f(x) for x in array)`, where `array` is `this` and `f` is `expression` class ArrayAny(Func): arg_types = {"this": True, "expression": True} @@ -4482,6 +4483,7 @@ class ArrayOverlaps(Binary, Func): class ArraySize(Func): arg_types = {"this": True, "expression": False} + _sql_names = ["ARRAY_SIZE", "ARRAY_LENGTH"] class ArraySort(Func): @@ -5033,7 +5035,7 @@ class JSONBContains(Binary): class JSONExtract(Binary, Func): - arg_types = {"this": True, "expression": True, "expressions": False} + arg_types = {"this": True, "expression": True, "only_json_types": False, "expressions": False} _sql_names = ["JSON_EXTRACT"] is_var_len_args = True @@ -5043,7 +5045,7 @@ class JSONExtract(Binary, Func): class JSONExtractScalar(Binary, Func): - arg_types = {"this": True, "expression": True, "expressions": False} + arg_types = {"this": True, "expression": True, "only_json_types": False, "expressions": False} _sql_names = ["JSON_EXTRACT_SCALAR"] is_var_len_args = True @@ -5626,7 +5628,7 @@ def maybe_parse( input expression is a SQL string). prefix: a string to prefix the sql with before it gets parsed (automatically includes a space) - copy: whether or not to copy the expression. + copy: whether to copy the expression. **opts: other options to use to parse the input expressions (again, in the case that an input expression is a SQL string). @@ -5897,7 +5899,7 @@ def union( If an `Expression` instance is passed, it will be used as-is. distinct: set the DISTINCT flag if and only if this is true. dialect: the dialect used to parse the input expression. - copy: whether or not to copy the expression. + copy: whether to copy the expression. opts: other options to use to parse the input expressions. Returns: @@ -5931,7 +5933,7 @@ def intersect( If an `Expression` instance is passed, it will be used as-is. distinct: set the DISTINCT flag if and only if this is true. dialect: the dialect used to parse the input expression. - copy: whether or not to copy the expression. + copy: whether to copy the expression. opts: other options to use to parse the input expressions. Returns: @@ -5965,7 +5967,7 @@ def except_( If an `Expression` instance is passed, it will be used as-is. distinct: set the DISTINCT flag if and only if this is true. dialect: the dialect used to parse the input expression. - copy: whether or not to copy the expression. + copy: whether to copy the expression. opts: other options to use to parse the input expressions. Returns: @@ -6127,7 +6129,7 @@ def insert( overwrite: whether to INSERT OVERWRITE or not. returning: sql conditional parsed into a RETURNING statement dialect: the dialect used to parse the input expressions. - copy: whether or not to copy the expression. + copy: whether to copy the expression. **opts: other options to use to parse the input expressions. Returns: @@ -6168,7 +6170,7 @@ def condition( If an Expression instance is passed, this is used as-is. dialect: the dialect used to parse the input expression (in the case that the input expression is a SQL string). - copy: Whether or not to copy `expression` (only applies to expressions). + copy: Whether to copy `expression` (only applies to expressions). **opts: other options to use to parse the input expressions (again, in the case that the input expression is a SQL string). @@ -6198,7 +6200,7 @@ def and_( *expressions: the SQL code strings to parse. If an Expression instance is passed, this is used as-is. dialect: the dialect used to parse the input expression. - copy: whether or not to copy `expressions` (only applies to Expressions). + copy: whether to copy `expressions` (only applies to Expressions). **opts: other options to use to parse the input expressions. Returns: @@ -6221,7 +6223,7 @@ def or_( *expressions: the SQL code strings to parse. If an Expression instance is passed, this is used as-is. dialect: the dialect used to parse the input expression. - copy: whether or not to copy `expressions` (only applies to Expressions). + copy: whether to copy `expressions` (only applies to Expressions). **opts: other options to use to parse the input expressions. Returns: @@ -6296,8 +6298,8 @@ def to_identifier(name, quoted=None, copy=True): Args: name: The name to turn into an identifier. - quoted: Whether or not force quote the identifier. - copy: Whether or not to copy name if it's an Identifier. + quoted: Whether to force quote the identifier. + copy: Whether to copy name if it's an Identifier. Returns: The identifier ast node. @@ -6379,7 +6381,7 @@ def to_table( Args: sql_path: a `[catalog].[schema].[table]` string. dialect: the source dialect according to which the table name will be parsed. - copy: Whether or not to copy a table if it is passed in. + copy: Whether to copy a table if it is passed in. kwargs: the kwargs to instantiate the resulting `Table` expression with. Returns: @@ -6418,7 +6420,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column: def alias_( expression: ExpOrStr, - alias: str | Identifier, + alias: t.Optional[str | Identifier], table: bool | t.Sequence[str | Identifier] = False, quoted: t.Optional[bool] = None, dialect: DialectType = None, @@ -6439,10 +6441,10 @@ def alias_( If an Expression instance is passed, this is used as-is. alias: the alias name to use. If the name has special characters it is quoted. - table: Whether or not to create a table alias, can also be a list of columns. - quoted: whether or not to quote the alias + table: Whether to create a table alias, can also be a list of columns. + quoted: whether to quote the alias dialect: the dialect used to parse the input expression. - copy: Whether or not to copy the expression. + copy: Whether to copy the expression. **opts: other options to use to parse the input expressions. Returns: @@ -6549,7 +6551,7 @@ def column( catalog: Catalog name. fields: Additional fields using dots. quoted: Whether to force quotes on the column's identifiers. - copy: Whether or not to copy identifiers if passed in. + copy: Whether to copy identifiers if passed in. Returns: The new Column instance. @@ -6576,7 +6578,7 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast Args: expression: The expression to cast. to: The datatype to cast to. - copy: Whether or not to copy the supplied expressions. + copy: Whether to copy the supplied expressions. Returns: The new Cast instance. @@ -6704,7 +6706,7 @@ def rename_column( table_name: Name of the table old_column: The old name of the column new_column: The new name of the column - exists: Whether or not to add the `IF EXISTS` clause + exists: Whether to add the `IF EXISTS` clause Returns: Alter table expression @@ -6727,7 +6729,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression: Args: value: A python object. - copy: Whether or not to copy `value` (only applies to Expressions and collections). + copy: Whether to copy `value` (only applies to Expressions and collections). Returns: Expression: the equivalent expression object. @@ -6847,7 +6849,7 @@ def normalize_table_name(table: str | Table, dialect: DialectType = None, copy: Args: table: the table to normalize dialect: the dialect to use for normalization rules - copy: whether or not to copy the expression. + copy: whether to copy the expression. Examples: >>> normalize_table_name("`A-B`.c", dialect="bigquery") @@ -6872,7 +6874,7 @@ def replace_tables( expression: expression node to be transformed and replaced. mapping: mapping of table names. dialect: the dialect of the mapping table - copy: whether or not to copy the expression. + copy: whether to copy the expression. Examples: >>> from sqlglot import exp, parse_one @@ -6959,7 +6961,7 @@ def expand( expression: The expression to expand. sources: A dictionary of name to Subqueryables. dialect: The dialect of the sources dict. - copy: Whether or not to copy the expression during transformation. Defaults to True. + copy: Whether to copy the expression during transformation. Defaults to True. Returns: The transformed expression. @@ -6993,7 +6995,7 @@ def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwa Args: name: the name of the function to build. args: the args used to instantiate the function of interest. - copy: whether or not to copy the argument expressions. + copy: whether to copy the argument expressions. dialect: the source dialect. kwargs: the kwargs used to instantiate the function of interest. @@ -7096,7 +7098,7 @@ def array( Args: expressions: the expressions to add to the array. - copy: whether or not to copy the argument expressions. + copy: whether to copy the argument expressions. dialect: the source dialect. kwargs: the kwargs used to instantiate the function of interest. @@ -7123,7 +7125,7 @@ def tuple_( Args: expressions: the expressions to add to the tuple. - copy: whether or not to copy the argument expressions. + copy: whether to copy the argument expressions. dialect: the source dialect. kwargs: the kwargs used to instantiate the function of interest. diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 4ff5a0e..4bb5005 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -38,19 +38,19 @@ class Generator(metaclass=_Generator): Generator converts a given syntax tree to the corresponding SQL string. Args: - pretty: Whether or not to format the produced SQL string. + pretty: Whether to format the produced SQL string. Default: False. identify: Determines when an identifier should be quoted. Possible values are: False (default): Never quote, except in cases where it's mandatory by the dialect. True or 'always': Always quote. 'safe': Only quote identifiers that are case insensitive. - normalize: Whether or not to normalize identifiers to lowercase. + normalize: Whether to normalize identifiers to lowercase. Default: False. - pad: Determines the pad size in a formatted string. + pad: The pad size in a formatted string. Default: 2. - indent: Determines the indentation size in a formatted string. + indent: The indentation size in a formatted string. Default: 2. - normalize_functions: Whether or not to normalize all function names. Possible values are: + normalize_functions: How to normalize function names. Possible values are: "upper" or True (default): Convert names to uppercase. "lower": Convert names to lowercase. False: Disables function name normalization. @@ -59,14 +59,14 @@ class Generator(metaclass=_Generator): max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError. This is only relevant if unsupported_level is ErrorLevel.RAISE. Default: 3 - leading_comma: Determines whether or not the comma is leading or trailing in select expressions. + leading_comma: Whether the comma is leading or trailing in select expressions. This is only relevant when generating in pretty mode. Default: False max_text_width: The max number of characters in a segment before creating new lines in pretty mode. The default is on the smaller end because the length only represents a segment and not the true line length. Default: 80 - comments: Whether or not to preserve comments in the output SQL code. + comments: Whether to preserve comments in the output SQL code. Default: True """ @@ -97,6 +97,12 @@ class Generator(metaclass=_Generator): exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}", exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}", + exp.JSONExtract: lambda self, e: self.func( + "JSON_EXTRACT", e.this, e.expression, *e.expressions + ), + exp.JSONExtractScalar: lambda self, e: self.func( + "JSON_EXTRACT_SCALAR", e.this, e.expression, *e.expressions + ), exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", @@ -134,15 +140,15 @@ class Generator(metaclass=_Generator): exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", } - # Whether or not null ordering is supported in order by + # Whether null ordering is supported in order by # True: Full Support, None: No support, False: No support in window specifications NULL_ORDERING_SUPPORTED: t.Optional[bool] = True - # Whether or not ignore nulls is inside the agg or outside. + # Whether ignore nulls is inside the agg or outside. # FIRST(x IGNORE NULLS) OVER vs FIRST (x) IGNORE NULLS OVER IGNORE_NULLS_IN_FUNC = False - # Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported + # Whether locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported LOCKING_READS_SUPPORTED = False # Always do union distinct or union all @@ -151,25 +157,25 @@ class Generator(metaclass=_Generator): # Wrap derived values in parens, usually standard but spark doesn't support it WRAP_DERIVED_VALUES = True - # Whether or not create function uses an AS before the RETURN + # Whether create function uses an AS before the RETURN CREATE_FUNCTION_RETURN_AS = True - # Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed + # Whether MERGE ... WHEN MATCHED BY SOURCE is allowed MATCHED_BY_SOURCE = True - # Whether or not the INTERVAL expression works only with values like '1 day' + # Whether the INTERVAL expression works only with values like '1 day' SINGLE_STRING_INTERVAL = False - # Whether or not the plural form of date parts like day (i.e. "days") is supported in INTERVALs + # Whether the plural form of date parts like day (i.e. "days") is supported in INTERVALs INTERVAL_ALLOWS_PLURAL_FORM = True - # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") + # Whether limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") LIMIT_FETCH = "ALL" - # Whether or not limit and fetch allows expresions or just limits + # Whether limit and fetch allows expresions or just limits LIMIT_ONLY_LITERALS = False - # Whether or not a table is allowed to be renamed with a db + # Whether a table is allowed to be renamed with a db RENAME_TABLE_WITH_DB = True # The separator for grouping sets and rollups @@ -178,105 +184,105 @@ class Generator(metaclass=_Generator): # The string used for creating an index on a table INDEX_ON = "ON" - # Whether or not join hints should be generated + # Whether join hints should be generated JOIN_HINTS = True - # Whether or not table hints should be generated + # Whether table hints should be generated TABLE_HINTS = True - # Whether or not query hints should be generated + # Whether query hints should be generated QUERY_HINTS = True # What kind of separator to use for query hints QUERY_HINT_SEP = ", " - # Whether or not comparing against booleans (e.g. x IS TRUE) is supported + # Whether comparing against booleans (e.g. x IS TRUE) is supported IS_BOOL_ALLOWED = True - # Whether or not to include the "SET" keyword in the "INSERT ... ON DUPLICATE KEY UPDATE" statement + # Whether to include the "SET" keyword in the "INSERT ... ON DUPLICATE KEY UPDATE" statement DUPLICATE_KEY_UPDATE_WITH_SET = True - # Whether or not to generate the limit as TOP <value> instead of LIMIT <value> + # Whether to generate the limit as TOP <value> instead of LIMIT <value> LIMIT_IS_TOP = False - # Whether or not to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ... + # Whether to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ... RETURNING_END = True - # Whether or not to generate the (+) suffix for columns used in old-style join conditions + # Whether to generate the (+) suffix for columns used in old-style join conditions COLUMN_JOIN_MARKS_SUPPORTED = False - # Whether or not to generate an unquoted value for EXTRACT's date part argument + # Whether to generate an unquoted value for EXTRACT's date part argument EXTRACT_ALLOWS_QUOTES = True - # Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax + # Whether TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax TZ_TO_WITH_TIME_ZONE = False - # Whether or not the NVL2 function is supported + # Whether the NVL2 function is supported NVL2_SUPPORTED = True # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE") - # Whether or not VALUES statements can be used as derived tables. + # Whether VALUES statements can be used as derived tables. # MySQL 5 and Redshift do not allow this, so when False, it will convert # SELECT * VALUES into SELECT UNION VALUES_AS_TABLE = True - # Whether or not the word COLUMN is included when adding a column with ALTER TABLE + # Whether the word COLUMN is included when adding a column with ALTER TABLE ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = True # UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery) UNNEST_WITH_ORDINALITY = True - # Whether or not FILTER (WHERE cond) can be used for conditional aggregation + # Whether FILTER (WHERE cond) can be used for conditional aggregation AGGREGATE_FILTER_SUPPORTED = True - # Whether or not JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds + # Whether JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds SEMI_ANTI_JOIN_WITH_SIDE = True - # Whether or not to include the type of a computed column in the CREATE DDL + # Whether to include the type of a computed column in the CREATE DDL COMPUTED_COLUMN_WITH_TYPE = True - # Whether or not CREATE TABLE .. COPY .. is supported. False means we'll generate CLONE instead of COPY + # Whether CREATE TABLE .. COPY .. is supported. False means we'll generate CLONE instead of COPY SUPPORTS_TABLE_COPY = True - # Whether or not parentheses are required around the table sample's expression + # Whether parentheses are required around the table sample's expression TABLESAMPLE_REQUIRES_PARENS = True - # Whether or not a table sample clause's size needs to be followed by the ROWS keyword + # Whether a table sample clause's size needs to be followed by the ROWS keyword TABLESAMPLE_SIZE_IS_ROWS = True # The keyword(s) to use when generating a sample clause TABLESAMPLE_KEYWORDS = "TABLESAMPLE" - # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI + # Whether the TABLESAMPLE clause supports a method name, like BERNOULLI TABLESAMPLE_WITH_METHOD = True # The keyword to use when specifying the seed of a sample clause TABLESAMPLE_SEED_KEYWORD = "SEED" - # Whether or not COLLATE is a function instead of a binary operator + # Whether COLLATE is a function instead of a binary operator COLLATE_IS_FUNC = False - # Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle) + # Whether data types support additional specifiers like e.g. CHAR or BYTE (oracle) DATA_TYPE_SPECIFIERS_ALLOWED = False - # Whether or not conditions require booleans WHERE x = 0 vs WHERE x + # Whether conditions require booleans WHERE x = 0 vs WHERE x ENSURE_BOOLS = False - # Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs + # Whether the "RECURSIVE" keyword is required when defining recursive CTEs CTE_RECURSIVE_KEYWORD_REQUIRED = True - # Whether or not CONCAT requires >1 arguments + # Whether CONCAT requires >1 arguments SUPPORTS_SINGLE_ARG_CONCAT = True - # Whether or not LAST_DAY function supports a date part argument + # Whether LAST_DAY function supports a date part argument LAST_DAY_SUPPORTS_DATE_PART = True - # Whether or not named columns are allowed in table aliases + # Whether named columns are allowed in table aliases SUPPORTS_TABLE_ALIAS_COLUMNS = True - # Whether or not UNPIVOT aliases are Identifiers (False means they're Literals) + # Whether UNPIVOT aliases are Identifiers (False means they're Literals) UNPIVOT_ALIASES_ARE_IDENTIFIERS = True # What delimiter to use for separating JSON key/value pairs @@ -285,34 +291,37 @@ class Generator(metaclass=_Generator): # INSERT OVERWRITE TABLE x override INSERT_OVERWRITE = " OVERWRITE TABLE" - # Whether or not the SELECT .. INTO syntax is used instead of CTAS + # Whether the SELECT .. INTO syntax is used instead of CTAS SUPPORTS_SELECT_INTO = False - # Whether or not UNLOGGED tables can be created + # Whether UNLOGGED tables can be created SUPPORTS_UNLOGGED_TABLES = False - # Whether or not the CREATE TABLE LIKE statement is supported + # Whether the CREATE TABLE LIKE statement is supported SUPPORTS_CREATE_TABLE_LIKE = True - # Whether or not the LikeProperty needs to be specified inside of the schema clause + # Whether the LikeProperty needs to be specified inside of the schema clause LIKE_PROPERTY_INSIDE_SCHEMA = False - # Whether or not DISTINCT can be followed by multiple args in an AggFunc. If not, it will be + # Whether DISTINCT can be followed by multiple args in an AggFunc. If not, it will be # transpiled into a series of CASE-WHEN-ELSE, ultimately using a tuple conseisting of the args MULTI_ARG_DISTINCT = True - # Whether or not the JSON extraction operators expect a value of type JSON + # Whether the JSON extraction operators expect a value of type JSON JSON_TYPE_REQUIRED_FOR_EXTRACTION = False - # Whether or not bracketed keys like ["foo"] are supported in JSON paths + # Whether bracketed keys like ["foo"] are supported in JSON paths JSON_PATH_BRACKETED_KEY_SUPPORTED = True - # Whether or not to escape keys using single quotes in JSON paths + # Whether to escape keys using single quotes in JSON paths JSON_PATH_SINGLE_QUOTE_ESCAPE = False # The JSONPathPart expressions supported by this dialect SUPPORTED_JSON_PATH_PARTS = ALL_JSON_PATH_PARTS.copy() + # Whether any(f(x) for x in array) can be implemented by this dialect + CAN_IMPLEMENT_ARRAY_ANY = False + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -453,7 +462,7 @@ class Generator(metaclass=_Generator): # Expressions that need to have all CTEs under them bubbled up to them EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set() - KEY_VALUE_DEFINITIONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice) + KEY_VALUE_DEFINITIONS = (exp.EQ, exp.PropertyEQ, exp.Slice) SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" @@ -524,7 +533,7 @@ class Generator(metaclass=_Generator): Args: expression: The syntax tree. - copy: Whether or not to copy the expression. The generator performs mutations so + copy: Whether to copy the expression. The generator performs mutations so it is safer to copy. Returns: @@ -3404,6 +3413,21 @@ class Generator(metaclass=_Generator): return self.func("LAST_DAY", expression.this) + def arrayany_sql(self, expression: exp.ArrayAny) -> str: + if self.CAN_IMPLEMENT_ARRAY_ANY: + filtered = exp.ArrayFilter(this=expression.this, expression=expression.expression) + filtered_not_empty = exp.ArraySize(this=filtered).neq(0) + original_is_empty = exp.ArraySize(this=expression.this).eq(0) + return self.sql(exp.paren(original_is_empty.or_(filtered_not_empty))) + + from sqlglot.dialects import Dialect + + # SQLGlot's executor supports ARRAY_ANY, so we don't wanna warn for the SQLGlot dialect + if self.dialect.__class__ != Dialect: + self.unsupported("ARRAY_ANY is unsupported") + + return self.function_fallback_sql(expression) + def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str: this = expression.this if isinstance(this, exp.JSONPathWildcard): diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 6df36af..6bf877b 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -76,7 +76,7 @@ def normalized(expression: exp.Expression, dnf: bool = False) -> bool: Args: expression: The expression to check if it's normalized. - dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). + dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). """ ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) @@ -99,7 +99,7 @@ def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int Args: expression: The expression to compute the normalization distance for. - dnf: Whether or not to check if the expression is in Disjunctive Normal Form (DNF). + dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). Returns: diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py index 8d83b47..e4f8b57 100644 --- a/sqlglot/optimizer/qualify.py +++ b/sqlglot/optimizer/qualify.py @@ -48,15 +48,15 @@ def qualify( db: Default database name for tables. catalog: Default catalog name for tables. schema: Schema to infer column names and types. - expand_alias_refs: Whether or not to expand references to aliases. - expand_stars: Whether or not to expand star queries. This is a necessary step + expand_alias_refs: Whether to expand references to aliases. + expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing! - infer_schema: Whether or not to infer the schema if missing. - isolate_tables: Whether or not to isolate table selects. - qualify_columns: Whether or not to qualify columns. - validate_qualify_columns: Whether or not to validate columns. - quote_identifiers: Whether or not to run the quote_identifiers step. + infer_schema: Whether to infer the schema if missing. + isolate_tables: Whether to isolate table selects. + qualify_columns: Whether to qualify columns. + validate_qualify_columns: Whether to validate columns. + quote_identifiers: Whether to run the quote_identifiers step. This step is necessary to ensure correctness for case sensitive queries. But this flag is provided in case this step is performed at a later time. identify: If True, quote all identifiers, else only necessary ones. diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 5c27bc3..ef589c9 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -35,11 +35,11 @@ def qualify_columns( Args: expression: Expression to qualify. schema: Database schema. - expand_alias_refs: Whether or not to expand references to aliases. - expand_stars: Whether or not to expand star queries. This is a necessary step + expand_alias_refs: Whether to expand references to aliases. + expand_stars: Whether to expand star queries. This is a necessary step for most of the optimizer's rules to work; do not set to False unless you know what you're doing! - infer_schema: Whether or not to infer the schema if missing. + infer_schema: Whether to infer the schema if missing. Returns: The qualified expression. @@ -164,12 +164,7 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: table = table or source_table conditions.append( - exp.condition( - exp.EQ( - this=exp.column(identifier, table=table), - expression=exp.column(identifier, table=join_table), - ) - ) + exp.column(identifier, table=table).eq(exp.column(identifier, table=join_table)) ) # Set all values in the dict to None, because we only care about the key ordering @@ -449,10 +444,9 @@ def _expand_stars( continue for name in columns: + if name in columns_to_exclude or name in coalesced_columns: + continue if name in using_column_tables and table in using_column_tables[name]: - if name in coalesced_columns: - continue - coalesced_columns.add(name) tables = using_column_tables[name] coalesce = [exp.column(name, table=table) for table in tables] @@ -464,7 +458,7 @@ def _expand_stars( copy=False, ) ) - elif name not in columns_to_exclude: + else: alias_ = replace_columns.get(table_id, {}).get(name, name) column = exp.column(name, table=table) new_selections.append( diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 16cd548..0eae979 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -254,7 +254,7 @@ class Scope: self._columns = [] for column in columns + external_columns: ancestor = column.find_ancestor( - exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table + exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star ) if ( not ancestor diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 90357dd..9ffddb5 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -41,7 +41,7 @@ def simplify( Args: expression (sqlglot.Expression): expression to simplify - constant_propagation: whether or not the constant propagation rule should be used + constant_propagation: whether the constant propagation rule should be used Returns: sqlglot.Expression: simplified expression diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 26f4159..b4c7475 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -248,7 +248,7 @@ def decorrelate(select, parent_select, external_columns, next_alias_name): key.replace(exp.to_identifier("_x")) parent_predicate = _replace( parent_predicate, - f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))', + f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))", ) parent_select.join( diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 25c5789..4e7f870 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -18,7 +18,7 @@ if t.TYPE_CHECKING: logger = logging.getLogger("sqlglot") -def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap: +def build_var_map(args: t.List) -> exp.StarMap | exp.VarMap: if len(args) == 1 and args[0].is_star: return exp.StarMap(this=args[0]) @@ -28,13 +28,10 @@ def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap: keys.append(args[i]) values.append(args[i + 1]) - return exp.VarMap( - keys=exp.array(*keys, copy=False), - values=exp.array(*values, copy=False), - ) + return exp.VarMap(keys=exp.array(*keys, copy=False), values=exp.array(*values, copy=False)) -def parse_like(args: t.List) -> exp.Escape | exp.Like: +def build_like(args: t.List) -> exp.Escape | exp.Like: like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0)) return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like @@ -47,7 +44,7 @@ def binary_range_parser( ) -def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func: +def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func: # Default argument order is base, expression this = seq_get(args, 0) expression = seq_get(args, 1) @@ -60,8 +57,8 @@ def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func: return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this) -def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: - def _parser(args: t.List, dialect: Dialect) -> E: +def build_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: + def _builder(args: t.List, dialect: Dialect) -> E: expression = expr_type( this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) ) @@ -70,7 +67,7 @@ def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Di return expression - return _parser + return _builder class _Parser(type): @@ -90,8 +87,8 @@ class Parser(metaclass=_Parser): Args: error_level: The desired error level. Default: ErrorLevel.IMMEDIATE - error_message_context: Determines the amount of context to capture from a - query string when displaying the error message (in number of characters). + error_message_context: The amount of context to capture from a query string when displaying + the error message (in number of characters). Default: 100 max_errors: Maximum number of error messages to include in a raised ParseError. This is only relevant if error_level is ErrorLevel.RAISE. @@ -115,11 +112,11 @@ class Parser(metaclass=_Parser): to=exp.DataType(this=exp.DataType.Type.TEXT), ), "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)), - "JSON_EXTRACT": parse_extract_json_with_path(exp.JSONExtract), - "JSON_EXTRACT_SCALAR": parse_extract_json_with_path(exp.JSONExtractScalar), - "JSON_EXTRACT_PATH_TEXT": parse_extract_json_with_path(exp.JSONExtractScalar), - "LIKE": parse_like, - "LOG": parse_logarithm, + "JSON_EXTRACT": build_extract_json_with_path(exp.JSONExtract), + "JSON_EXTRACT_SCALAR": build_extract_json_with_path(exp.JSONExtractScalar), + "JSON_EXTRACT_PATH_TEXT": build_extract_json_with_path(exp.JSONExtractScalar), + "LIKE": build_like, + "LOG": build_logarithm, "TIME_TO_TIME_STR": lambda args: exp.Cast( this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), @@ -132,7 +129,7 @@ class Parser(metaclass=_Parser): start=exp.Literal.number(1), length=exp.Literal.number(10), ), - "VAR_MAP": parse_var_map, + "VAR_MAP": build_var_map, } NO_PAREN_FUNCTIONS = { @@ -292,6 +289,7 @@ class Parser(metaclass=_Parser): TokenType.VIEW, TokenType.MODEL, TokenType.DICTIONARY, + TokenType.STORAGE_INTEGRATION, } CREATABLES = { @@ -550,11 +548,13 @@ class Parser(metaclass=_Parser): exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path), + only_json_types=self.JSON_ARROWS_REQUIRE_JSON_TYPE, ), TokenType.DARROW: lambda self, this, path: self.expression( exp.JSONExtractScalar, this=this, expression=self.dialect.to_json_path(path), + only_json_types=self.JSON_ARROWS_REQUIRE_JSON_TYPE, ), TokenType.HASH_ARROW: lambda self, this, path: self.expression( exp.JSONBExtract, @@ -983,28 +983,31 @@ class Parser(metaclass=_Parser): LOG_DEFAULTS_TO_LN = False - # Whether or not ADD is present for each column added by ALTER TABLE + # Whether ADD is present for each column added by ALTER TABLE ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = True - # Whether or not the table sample clause expects CSV syntax + # Whether the table sample clause expects CSV syntax TABLESAMPLE_CSV = False - # Whether or not the SET command needs a delimiter (e.g. "=") for assignments + # Whether the SET command needs a delimiter (e.g. "=") for assignments SET_REQUIRES_ASSIGNMENT_DELIMITER = True # Whether the TRIM function expects the characters to trim as its first argument TRIM_PATTERN_FIRST = False - # Whether or not string aliases are supported `SELECT COUNT(*) 'count'` + # Whether string aliases are supported `SELECT COUNT(*) 'count'` STRING_ALIASES = False # Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand) MODIFIERS_ATTACHED_TO_UNION = True UNION_MODIFIERS = {"order", "limit", "offset"} - # Parses no parenthesis if statements as commands + # Whether to parse IF statements that aren't followed by a left parenthesis as commands NO_PAREN_IF_COMMANDS = True + # Whether the -> and ->> operators expect documents of type JSON (e.g. Postgres) + JSON_ARROWS_REQUIRE_JSON_TYPE = False + # Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause. # If this is True and '(' is not found, the keyword will be treated as an identifier VALUES_FOLLOWED_BY_PAREN = True diff --git a/sqlglot/schema.py b/sqlglot/schema.py index dbd0caa..36022b9 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -92,7 +92,7 @@ class Schema(abc.ABC): normalize: t.Optional[bool] = None, ) -> bool: """ - Returns whether or not `column` appears in `table`'s schema. + Returns whether `column` appears in `table`'s schema. Args: table: the source table. @@ -115,7 +115,7 @@ class Schema(abc.ABC): @property def empty(self) -> bool: - """Returns whether or not the schema is empty.""" + """Returns whether the schema is empty.""" return True @@ -162,7 +162,7 @@ class AbstractMappingSchema: Args: table: the target table. - raise_on_missing: whether or not to raise in case the schema is not found. + raise_on_missing: whether to raise in case the schema is not found. Returns: The schema of the target table. diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 2cfcfa6..939ca18 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -346,6 +346,7 @@ class TokenType(AutoName): SOME = auto() SORT_BY = auto() START_WITH = auto() + STORAGE_INTEGRATION = auto() STRUCT = auto() TABLE_SAMPLE = auto() TEMPORARY = auto() @@ -577,7 +578,7 @@ class Tokenizer(metaclass=_Tokenizer): STRING_ESCAPES = ["'"] VAR_SINGLE_TOKENS: t.Set[str] = set() - # Whether or not the heredoc tags follow the same lexical rules as unquoted identifiers + # Whether the heredoc tags follow the same lexical rules as unquoted identifiers HEREDOC_TAG_IS_IDENTIFIER = False # Token that we'll generate as a fallback if the heredoc prefix doesn't correspond to a heredoc |