diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-08 08:11:53 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-08 08:12:02 +0000 |
commit | 8d36f5966675e23bee7026ba37ae0647fbf47300 (patch) | |
tree | df4227bbb3b07cb70df87237bcff03c8efd7822d /sqlglot/dialects/snowflake.py | |
parent | Releasing debian version 22.2.0-1. (diff) | |
download | sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.tar.xz sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.zip |
Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 201 |
1 file changed, 140 insertions, 61 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 20fdfb7..73a9166 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -20,8 +20,7 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, var_map_sql, ) -from sqlglot.expressions import Literal -from sqlglot.helper import flatten, is_int, seq_get +from sqlglot.helper import flatten, is_float, is_int, seq_get from sqlglot.tokens import TokenType if t.TYPE_CHECKING: @@ -29,33 +28,35 @@ if t.TYPE_CHECKING: # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]: - if len(args) == 2: - first_arg, second_arg = args - if second_arg.is_string: - # case: <string_expr> [ , <format> ] - return build_formatted_time(exp.StrToTime, "snowflake")(args) - return exp.UnixToTime(this=first_arg, scale=second_arg) +def _build_datetime( + name: str, kind: exp.DataType.Type, safe: bool = False +) -> t.Callable[[t.List], exp.Func]: + def _builder(args: t.List) -> exp.Func: + value = seq_get(args, 0) + + if isinstance(value, exp.Literal): + int_value = is_int(value.this) - from sqlglot.optimizer.simplify import simplify_literals + # Converts calls like `TO_TIME('01:02:03')` into casts + if len(args) == 1 and value.is_string and not int_value: + return exp.cast(value, kind) - # The first argument might be an expression like 40 * 365 * 86400, so we try to - # reduce it using `simplify_literals` first and then check if it's a Literal. 
- first_arg = seq_get(args, 0) - if not isinstance(simplify_literals(first_arg, root=True), Literal): - # case: <variant_expr> or other expressions such as columns - return exp.TimeStrToTime.from_arg_list(args) + # Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special + # cases so we can transpile them, since they're relatively common + if kind == exp.DataType.Type.TIMESTAMP: + if int_value: + return exp.UnixToTime(this=value, scale=seq_get(args, 1)) + if not is_float(value.this): + return build_formatted_time(exp.StrToTime, "snowflake")(args) - if first_arg.is_string: - if is_int(first_arg.this): - # case: <integer> - return exp.UnixToTime.from_arg_list(args) + if len(args) == 2 and kind == exp.DataType.Type.DATE: + formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args) + formatted_exp.set("safe", safe) + return formatted_exp - # case: <date_expr> - return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args) + return exp.Anonymous(this=name, expressions=args) - # case: <numeric_expr> - return exp.UnixToTime.from_arg_list(args) + return _builder def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: @@ -77,6 +78,17 @@ def _build_datediff(args: t.List) -> exp.DateDiff: ) +def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]: + def _builder(args: t.List) -> E: + return expr_type( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=_map_date_part(seq_get(args, 0)), + ) + + return _builder + + # https://docs.snowflake.com/en/sql-reference/functions/div0 def _build_if_from_div0(args: t.List) -> exp.If: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) @@ -97,14 +109,6 @@ def _build_if_from_nullifzero(args: t.List) -> exp.If: return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) -def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str: - if expression.is_type("array"): - return "ARRAY" - elif 
expression.is_type("map"): - return "OBJECT" - return self.datatype_sql(expression) - - def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str: flag = expression.text("flag") @@ -258,6 +262,25 @@ def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression: return expression +def _flatten_structured_types_unless_iceberg(expression: exp.Expression) -> exp.Expression: + assert isinstance(expression, exp.Create) + + def _flatten_structured_type(expression: exp.DataType) -> exp.DataType: + if expression.this in exp.DataType.NESTED_TYPES: + expression.set("expressions", None) + return expression + + props = expression.args.get("properties") + if isinstance(expression.this, exp.Schema) and not (props and props.find(exp.IcebergProperty)): + for schema_expression in expression.this.expressions: + if isinstance(schema_expression, exp.ColumnDef): + column_type = schema_expression.kind + if isinstance(column_type, exp.DataType): + column_type.transform(_flatten_structured_type, copy=False) + + return expression + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE @@ -312,7 +335,13 @@ class Snowflake(Dialect): class Parser(parser.Parser): IDENTIFY_PIVOT_STRINGS = True + ID_VAR_TOKENS = { + *parser.Parser.ID_VAR_TOKENS, + TokenType.MATCH_CONDITION, + } + TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW} + TABLE_ALIAS_TOKENS.discard(TokenType.MATCH_CONDITION) FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -327,17 +356,13 @@ class Snowflake(Dialect): end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)), step=seq_get(args, 2), ), - "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, "BITXOR": binary_from_function(exp.BitwiseXor), "BIT_XOR": binary_from_function(exp.BitwiseXor), "BOOLXOR": binary_from_function(exp.Xor), "CONVERT_TIMEZONE": _build_convert_timezone, + "DATE": _build_datetime("DATE", 
exp.DataType.Type.DATE), "DATE_TRUNC": _date_trunc_to_time, - "DATEADD": lambda args: exp.DateAdd( - this=seq_get(args, 2), - expression=seq_get(args, 1), - unit=_map_date_part(seq_get(args, 0)), - ), + "DATEADD": _build_date_time_add(exp.DateAdd), "DATEDIFF": _build_datediff, "DIV0": _build_if_from_div0, "FLATTEN": exp.Explode.from_arg_list, @@ -349,17 +374,34 @@ class Snowflake(Dialect): this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1)) ), "LISTAGG": exp.GroupConcat.from_arg_list, + "MEDIAN": lambda args: exp.PercentileCont( + this=seq_get(args, 0), expression=exp.Literal.number(0.5) + ), "NULLIFZERO": _build_if_from_nullifzero, "OBJECT_CONSTRUCT": _build_object_construct, "REGEXP_REPLACE": _build_regexp_replace, "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), + "TIMEADD": _build_date_time_add(exp.TimeAdd), "TIMEDIFF": _build_datediff, + "TIMESTAMPADD": _build_date_time_add(exp.DateAdd), "TIMESTAMPDIFF": _build_datediff, "TIMESTAMPFROMPARTS": _build_timestamp_from_parts, "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts, - "TO_TIMESTAMP": _build_to_timestamp, + "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True), + "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE), + "TO_NUMBER": lambda args: exp.ToNumber( + this=seq_get(args, 0), + format=seq_get(args, 1), + precision=seq_get(args, 2), + scale=seq_get(args, 3), + ), + "TO_TIME": _build_datetime("TO_TIME", exp.DataType.Type.TIME), + "TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP), + "TO_TIMESTAMP_LTZ": _build_datetime("TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ), + "TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMP), + "TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ), "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": 
_build_if_from_zeroifnull, } @@ -377,7 +419,6 @@ class Snowflake(Dialect): **parser.Parser.RANGE_PARSERS, TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny), TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny), - TokenType.COLON: lambda self, this: self._parse_colon_get_path(this), } ALTER_PARSERS = { @@ -434,35 +475,35 @@ class Snowflake(Dialect): SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"} - def _parse_colon_get_path( - self: parser.Parser, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: - while True: - path = self._parse_bitwise() or self._parse_var(any_token=True) + def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + this = super()._parse_column_ops(this) + + casts = [] + json_path = [] + + while self._match(TokenType.COLON): + path = super()._parse_column_ops(self._parse_field(any_token=True)) # The cast :: operator has a lower precedence than the extraction operator :, so # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH - if isinstance(path, exp.Cast): - target_type = path.to + while isinstance(path, exp.Cast): + casts.append(path.to) path = path.this - else: - target_type = None - if isinstance(path, exp.Expression): - path = exp.Literal.string(path.sql(dialect="snowflake")) + if path: + json_path.append(path.sql(dialect="snowflake", copy=False)) - # The extraction operator : is left-associative + if json_path: this = self.expression( - exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path) + exp.JSONExtract, + this=this, + expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))), ) - if target_type: - this = exp.cast(this, target_type) + while casts: + this = self.expression(exp.Cast, this=this, to=casts.pop()) - if not self._match(TokenType.COLON): - break - - return self._parse_range(this) + return this # 
https://docs.snowflake.com/en/sql-reference/functions/date_part.html # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts @@ -663,6 +704,7 @@ class Snowflake(Dialect): "EXCLUDE": TokenType.EXCEPT, "ILIKE ANY": TokenType.ILIKE_ANY, "LIKE ANY": TokenType.LIKE_ANY, + "MATCH_CONDITION": TokenType.MATCH_CONDITION, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "MINUS": TokenType.EXCEPT, "NCHAR VARYING": TokenType.VARCHAR, @@ -703,6 +745,7 @@ class Snowflake(Dialect): LIMIT_ONLY_LITERALS = True JSON_KEY_VALUE_PAIR_SEP = "," INSERT_OVERWRITE = " OVERWRITE INTO" + STRUCT_DELIMITER = ("(", ")") TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -711,15 +754,14 @@ class Snowflake(Dialect): exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this), - exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), exp.AtTimeZone: lambda self, e: self.func( "CONVERT_TIMEZONE", e.args.get("zone"), e.this ), exp.BitwiseXor: rename_func("BITXOR"), + exp.Create: transforms.preprocess([_flatten_structured_types_unless_iceberg]), exp.DateAdd: date_delta_sql("DATEADD"), exp.DateDiff: date_delta_sql("DATEDIFF"), exp.DateStrToDate: datestrtodate_sql, - exp.DataType: _datatype_sql, exp.DayOfMonth: rename_func("DAYOFMONTH"), exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), @@ -769,6 +811,7 @@ class Snowflake(Dialect): ), exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), exp.Stuff: rename_func("INSERT"), + exp.TimeAdd: date_delta_sql("TIMEADD"), exp.TimestampDiff: lambda self, e: self.func( "TIMESTAMPDIFF", e.unit, e.expression, e.this ), @@ -783,6 +826,9 @@ class Snowflake(Dialect): exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), + exp.TsOrDsToDate: lambda self, e: self.func( 
+ "TRY_TO_DATE" if e.args.get("safe") else "TO_DATE", e.this, self.format_time(e) + ), exp.UnixToTime: rename_func("TO_TIMESTAMP"), exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.WeekOfYear: rename_func("WEEKOFYEAR"), @@ -797,6 +843,8 @@ class Snowflake(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, + exp.DataType.Type.NESTED: "OBJECT", + exp.DataType.Type.STRUCT: "OBJECT", exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } @@ -811,6 +859,37 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + UNSUPPORTED_VALUES_EXPRESSIONS = { + exp.Struct, + } + + def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str: + if expression.find(*self.UNSUPPORTED_VALUES_EXPRESSIONS): + values_as_table = False + + return super().values_sql(expression, values_as_table=values_as_table) + + def datatype_sql(self, expression: exp.DataType) -> str: + expressions = expression.expressions + if ( + expressions + and expression.is_type(*exp.DataType.STRUCT_TYPES) + and any(isinstance(field_type, exp.DataType) for field_type in expressions) + ): + # The correct syntax is OBJECT [ (<key> <value_type [NOT NULL] [, ...]) ] + return "OBJECT" + + return super().datatype_sql(expression) + + def tonumber_sql(self, expression: exp.ToNumber) -> str: + return self.func( + "TO_NUMBER", + expression.this, + expression.args.get("format"), + expression.args.get("precision"), + expression.args.get("scale"), + ) + def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: milli = expression.args.get("milli") if milli is not None: |