diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 196 |
1 files changed, 98 insertions, 98 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index b4275ea..c773e50 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import ( date_delta_sql, date_trunc_to_time, datestrtodate_sql, - format_time_lambda, + build_formatted_time, if_sql, inline_array_sql, max_or_greatest, @@ -29,12 +29,12 @@ if t.TYPE_CHECKING: # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]: +def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]: if len(args) == 2: first_arg, second_arg = args if second_arg.is_string: # case: <string_expr> [ , <format> ] - return format_time_lambda(exp.StrToTime, "snowflake")(args) + return build_formatted_time(exp.StrToTime, "snowflake")(args) return exp.UnixToTime(this=first_arg, scale=second_arg) from sqlglot.optimizer.simplify import simplify_literals @@ -52,14 +52,14 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, return exp.UnixToTime.from_arg_list(args) # case: <date_expr> - return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) + return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args) # case: <numeric_expr> return exp.UnixToTime.from_arg_list(args) -def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: - expression = parser.parse_var_map(args) +def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: + expression = parser.build_var_map(args) if isinstance(expression, exp.StarMap): return expression @@ -71,48 +71,14 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: ) -def _parse_datediff(args: t.List) -> exp.DateDiff: +def _build_datediff(args: t.List) -> exp.DateDiff: return exp.DateDiff( this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0)) ) -# https://docs.snowflake.com/en/sql-reference/functions/date_part.html -# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts -def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]: - this = self._parse_var() or self._parse_type() - - if not this: - return None - - self._match(TokenType.COMMA) - expression = self._parse_bitwise() - this = _map_date_part(this) - name = this.name.upper() - - if name.startswith("EPOCH"): - if name == "EPOCH_MILLISECOND": - scale = 10**3 - elif name == "EPOCH_MICROSECOND": - scale = 10**6 - elif name == "EPOCH_NANOSECOND": - scale = 10**9 - else: - scale = None - - ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP")) - to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts) - - if scale: - to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale)) - - return to_unix - - return self.expression(exp.Extract, this=this, expression=expression) - - # https://docs.snowflake.com/en/sql-reference/functions/div0 -def _div0_to_if(args: t.List) -> exp.If: +def _build_if_from_div0(args: t.List) -> exp.If: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) true = exp.Literal.number(0) false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) @@ -120,13 +86,13 @@ def _div0_to_if(args: t.List) -> exp.If: # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _zeroifnull_to_if(args: t.List) -> exp.If: +def _build_if_from_zeroifnull(args: t.List) -> exp.If: cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _nullifzero_to_if(args: t.List) -> exp.If: +def _build_if_from_nullifzero(args: t.List) -> exp.If: cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) @@ -150,13 +116,13 @@ def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> ) -def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]: +def _build_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]: if len(args) == 3: return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args) return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0)) -def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace: +def _build_regexp_replace(args: t.List) -> exp.RegexpReplace: regexp_replace = exp.RegexpReplace.from_arg_list(args) if not regexp_replace.args.get("replacement"): @@ -266,38 +232,7 @@ def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: return trunc -def _parse_colon_get_path( - self: parser.Parser, this: t.Optional[exp.Expression] -) -> t.Optional[exp.Expression]: - while True: - path = self._parse_bitwise() - - # The cast :: operator has a lower precedence than the extraction operator :, so - # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH - if isinstance(path, exp.Cast): - target_type = path.to - path = path.this - else: - target_type = None - - if isinstance(path, exp.Expression): - path = exp.Literal.string(path.sql(dialect="snowflake")) - - # The extraction operator : is left-associative - this = self.expression( - exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path) - ) - - if target_type: - this = exp.cast(this, target_type) - - if not self._match(TokenType.COLON): - break - - return self._parse_range(this) - - -def _parse_timestamp_from_parts(args: t.List) -> exp.Func: +def _build_timestamp_from_parts(args: t.List) -> exp.Func: if len(args) == 2: # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, # so we parse this into Anonymous for now instead of introducing complexity @@ -396,15 +331,15 @@ class Snowflake(Dialect): "BITXOR": binary_from_function(exp.BitwiseXor), "BIT_XOR": binary_from_function(exp.BitwiseXor), "BOOLXOR": binary_from_function(exp.Xor), - "CONVERT_TIMEZONE": _parse_convert_timezone, + "CONVERT_TIMEZONE": _build_convert_timezone, "DATE_TRUNC": _date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0)), ), - "DATEDIFF": _parse_datediff, - "DIV0": _div0_to_if, + "DATEDIFF": _build_datediff, + "DIV0": _build_if_from_div0, "FLATTEN": exp.Explode.from_arg_list, "GET_PATH": lambda args, dialect: exp.JSONExtract( this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) @@ -414,24 +349,24 @@ class Snowflake(Dialect): this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1)) ), "LISTAGG": exp.GroupConcat.from_arg_list, - "NULLIFZERO": _nullifzero_to_if, - "OBJECT_CONSTRUCT": _parse_object_construct, - "REGEXP_REPLACE": _parse_regexp_replace, + "NULLIFZERO": _build_if_from_nullifzero, + "OBJECT_CONSTRUCT": _build_object_construct, + "REGEXP_REPLACE": _build_regexp_replace, "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), - "TIMEDIFF": _parse_datediff, - "TIMESTAMPDIFF": _parse_datediff, - "TIMESTAMPFROMPARTS": _parse_timestamp_from_parts, - "TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts, - "TO_TIMESTAMP": _parse_to_timestamp, + "TIMEDIFF": _build_datediff, + "TIMESTAMPDIFF": _build_datediff, + "TIMESTAMPFROMPARTS": _build_timestamp_from_parts, + "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts, + "TO_TIMESTAMP": _build_to_timestamp, "TO_VARCHAR": exp.ToChar.from_arg_list, - "ZEROIFNULL": _zeroifnull_to_if, + "ZEROIFNULL": _build_if_from_zeroifnull, } FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, - "DATE_PART": _parse_date_part, + "DATE_PART": lambda self: self._parse_date_part(), "OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(), } FUNCTION_PARSERS.pop("TRIM") @@ -442,7 +377,7 @@ class Snowflake(Dialect): **parser.Parser.RANGE_PARSERS, TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny), TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny), - TokenType.COLON: _parse_colon_get_path, + TokenType.COLON: lambda self, this: self._parse_colon_get_path(this), } ALTER_PARSERS = { @@ -489,6 +424,69 @@ class Snowflake(Dialect): FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"] + def _parse_colon_get_path( + self: parser.Parser, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + while True: + path = self._parse_bitwise() + + # The cast :: operator has a lower precedence than the extraction operator :, so + # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH + if isinstance(path, exp.Cast): + target_type = path.to + path = path.this + else: + target_type = None + + if isinstance(path, exp.Expression): + path = exp.Literal.string(path.sql(dialect="snowflake")) + + # The extraction operator : is left-associative + this = self.expression( + exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path) + ) + + if target_type: + this = exp.cast(this, target_type) + + if not self._match(TokenType.COLON): + break + + return self._parse_range(this) + + # https://docs.snowflake.com/en/sql-reference/functions/date_part.html + # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts + def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]: + this = self._parse_var() or self._parse_type() + + if not this: + return None + + self._match(TokenType.COMMA) + expression = self._parse_bitwise() + this = _map_date_part(this) + name = this.name.upper() + + if name.startswith("EPOCH"): + if name == "EPOCH_MILLISECOND": + scale = 10**3 + elif name == "EPOCH_MICROSECOND": + scale = 10**6 + elif name == "EPOCH_NANOSECOND": + scale = 10**9 + else: + scale = None + + ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP")) + to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts) + + if scale: + to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale)) + + return to_unix + + return self.expression(exp.Extract, this=this, expression=expression) + def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]: if is_map: # Keys are strings in Snowflake's objects, see also: @@ -665,6 +663,7 @@ class Snowflake(Dialect): "SAMPLE": TokenType.TABLE_SAMPLE, "SQL_DOUBLE": TokenType.DOUBLE, "SQL_VARCHAR": TokenType.VARCHAR, + "STORAGE INTEGRATION": TokenType.STORAGE_INTEGRATION, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, @@ -724,8 +723,10 @@ class Snowflake(Dialect): ), exp.GroupConcat: rename_func("LISTAGG"), exp.If: if_sql(name="IFF", false_value="NULL"), - exp.JSONExtract: rename_func("GET_PATH"), - exp.JSONExtractScalar: rename_func("JSON_EXTRACT_PATH_TEXT"), + exp.JSONExtract: lambda self, e: self.func("GET_PATH", e.this, e.expression), + exp.JSONExtractScalar: lambda self, e: self.func( + "JSON_EXTRACT_PATH_TEXT", e.this, e.expression + ), exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions), exp.JSONPathRoot: lambda *_: "", exp.LogicalAnd: rename_func("BOOLAND_AGG"), @@ -756,8 +757,7 @@ class Snowflake(Dialect): exp.StrPosition: lambda self, e: self.func( "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), - exp.StrToTime: lambda self, - e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), exp.Struct: lambda self, e: self.func( "OBJECT_CONSTRUCT", *(arg for expression in e.expressions for arg in expression.flatten()), @@ -901,12 +901,12 @@ class Snowflake(Dialect): ) def except_op(self, expression: exp.Except) -> str: - if not expression.args.get("distinct", False): + if not expression.args.get("distinct"): self.unsupported("EXCEPT with All is not supported in Snowflake") return super().except_op(expression) def intersect_op(self, expression: exp.Intersect) -> str: - if not expression.args.get("distinct", False): + if not expression.args.get("distinct"): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) |