summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/snowflake.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r--sqlglot/dialects/snowflake.py196
1 files changed, 98 insertions, 98 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index b4275ea..c773e50 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import (
date_delta_sql,
date_trunc_to_time,
datestrtodate_sql,
- format_time_lambda,
+ build_formatted_time,
if_sql,
inline_array_sql,
max_or_greatest,
@@ -29,12 +29,12 @@ if t.TYPE_CHECKING:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
+def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
if len(args) == 2:
first_arg, second_arg = args
if second_arg.is_string:
# case: <string_expr> [ , <format> ]
- return format_time_lambda(exp.StrToTime, "snowflake")(args)
+ return build_formatted_time(exp.StrToTime, "snowflake")(args)
return exp.UnixToTime(this=first_arg, scale=second_arg)
from sqlglot.optimizer.simplify import simplify_literals
@@ -52,14 +52,14 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
return exp.UnixToTime.from_arg_list(args)
# case: <date_expr>
- return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
+ return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args)
# case: <numeric_expr>
return exp.UnixToTime.from_arg_list(args)
-def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
- expression = parser.parse_var_map(args)
+def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
+ expression = parser.build_var_map(args)
if isinstance(expression, exp.StarMap):
return expression
@@ -71,48 +71,14 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
)
-def _parse_datediff(args: t.List) -> exp.DateDiff:
+def _build_datediff(args: t.List) -> exp.DateDiff:
return exp.DateDiff(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
)
-# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
-# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
-def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
- this = self._parse_var() or self._parse_type()
-
- if not this:
- return None
-
- self._match(TokenType.COMMA)
- expression = self._parse_bitwise()
- this = _map_date_part(this)
- name = this.name.upper()
-
- if name.startswith("EPOCH"):
- if name == "EPOCH_MILLISECOND":
- scale = 10**3
- elif name == "EPOCH_MICROSECOND":
- scale = 10**6
- elif name == "EPOCH_NANOSECOND":
- scale = 10**9
- else:
- scale = None
-
- ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
- to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
-
- if scale:
- to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
-
- return to_unix
-
- return self.expression(exp.Extract, this=this, expression=expression)
-
-
# https://docs.snowflake.com/en/sql-reference/functions/div0
-def _div0_to_if(args: t.List) -> exp.If:
+def _build_if_from_div0(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
true = exp.Literal.number(0)
false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -120,13 +86,13 @@ def _div0_to_if(args: t.List) -> exp.If:
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _zeroifnull_to_if(args: t.List) -> exp.If:
+def _build_if_from_zeroifnull(args: t.List) -> exp.If:
cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _nullifzero_to_if(args: t.List) -> exp.If:
+def _build_if_from_nullifzero(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
@@ -150,13 +116,13 @@ def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) ->
)
-def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
+def _build_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]:
if len(args) == 3:
return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args)
return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0))
-def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace:
+def _build_regexp_replace(args: t.List) -> exp.RegexpReplace:
regexp_replace = exp.RegexpReplace.from_arg_list(args)
if not regexp_replace.args.get("replacement"):
@@ -266,38 +232,7 @@ def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
return trunc
-def _parse_colon_get_path(
- self: parser.Parser, this: t.Optional[exp.Expression]
-) -> t.Optional[exp.Expression]:
- while True:
- path = self._parse_bitwise()
-
- # The cast :: operator has a lower precedence than the extraction operator :, so
- # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
- if isinstance(path, exp.Cast):
- target_type = path.to
- path = path.this
- else:
- target_type = None
-
- if isinstance(path, exp.Expression):
- path = exp.Literal.string(path.sql(dialect="snowflake"))
-
- # The extraction operator : is left-associative
- this = self.expression(
- exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
- )
-
- if target_type:
- this = exp.cast(this, target_type)
-
- if not self._match(TokenType.COLON):
- break
-
- return self._parse_range(this)
-
-
-def _parse_timestamp_from_parts(args: t.List) -> exp.Func:
+def _build_timestamp_from_parts(args: t.List) -> exp.Func:
if len(args) == 2:
# Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
# so we parse this into Anonymous for now instead of introducing complexity
@@ -396,15 +331,15 @@ class Snowflake(Dialect):
"BITXOR": binary_from_function(exp.BitwiseXor),
"BIT_XOR": binary_from_function(exp.BitwiseXor),
"BOOLXOR": binary_from_function(exp.Xor),
- "CONVERT_TIMEZONE": _parse_convert_timezone,
+ "CONVERT_TIMEZONE": _build_convert_timezone,
"DATE_TRUNC": _date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
expression=seq_get(args, 1),
unit=_map_date_part(seq_get(args, 0)),
),
- "DATEDIFF": _parse_datediff,
- "DIV0": _div0_to_if,
+ "DATEDIFF": _build_datediff,
+ "DIV0": _build_if_from_div0,
"FLATTEN": exp.Explode.from_arg_list,
"GET_PATH": lambda args, dialect: exp.JSONExtract(
this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
@@ -414,24 +349,24 @@ class Snowflake(Dialect):
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
),
"LISTAGG": exp.GroupConcat.from_arg_list,
- "NULLIFZERO": _nullifzero_to_if,
- "OBJECT_CONSTRUCT": _parse_object_construct,
- "REGEXP_REPLACE": _parse_regexp_replace,
+ "NULLIFZERO": _build_if_from_nullifzero,
+ "OBJECT_CONSTRUCT": _build_object_construct,
+ "REGEXP_REPLACE": _build_regexp_replace,
"REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
- "TIMEDIFF": _parse_datediff,
- "TIMESTAMPDIFF": _parse_datediff,
- "TIMESTAMPFROMPARTS": _parse_timestamp_from_parts,
- "TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts,
- "TO_TIMESTAMP": _parse_to_timestamp,
+ "TIMEDIFF": _build_datediff,
+ "TIMESTAMPDIFF": _build_datediff,
+ "TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
+ "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
+ "TO_TIMESTAMP": _build_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
- "ZEROIFNULL": _zeroifnull_to_if,
+ "ZEROIFNULL": _build_if_from_zeroifnull,
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
- "DATE_PART": _parse_date_part,
+ "DATE_PART": lambda self: self._parse_date_part(),
"OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(),
}
FUNCTION_PARSERS.pop("TRIM")
@@ -442,7 +377,7 @@ class Snowflake(Dialect):
**parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
- TokenType.COLON: _parse_colon_get_path,
+ TokenType.COLON: lambda self, this: self._parse_colon_get_path(this),
}
ALTER_PARSERS = {
@@ -489,6 +424,69 @@ class Snowflake(Dialect):
FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]
+ def _parse_colon_get_path(
+ self: parser.Parser, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ while True:
+ path = self._parse_bitwise()
+
+ # The cast :: operator has a lower precedence than the extraction operator :, so
+ # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
+ if isinstance(path, exp.Cast):
+ target_type = path.to
+ path = path.this
+ else:
+ target_type = None
+
+ if isinstance(path, exp.Expression):
+ path = exp.Literal.string(path.sql(dialect="snowflake"))
+
+ # The extraction operator : is left-associative
+ this = self.expression(
+ exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
+ )
+
+ if target_type:
+ this = exp.cast(this, target_type)
+
+ if not self._match(TokenType.COLON):
+ break
+
+ return self._parse_range(this)
+
+ # https://docs.snowflake.com/en/sql-reference/functions/date_part.html
+ # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
+ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
+ this = self._parse_var() or self._parse_type()
+
+ if not this:
+ return None
+
+ self._match(TokenType.COMMA)
+ expression = self._parse_bitwise()
+ this = _map_date_part(this)
+ name = this.name.upper()
+
+ if name.startswith("EPOCH"):
+ if name == "EPOCH_MILLISECOND":
+ scale = 10**3
+ elif name == "EPOCH_MICROSECOND":
+ scale = 10**6
+ elif name == "EPOCH_NANOSECOND":
+ scale = 10**9
+ else:
+ scale = None
+
+ ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP"))
+ to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts)
+
+ if scale:
+ to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale))
+
+ return to_unix
+
+ return self.expression(exp.Extract, this=this, expression=expression)
+
def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
if is_map:
# Keys are strings in Snowflake's objects, see also:
@@ -665,6 +663,7 @@ class Snowflake(Dialect):
"SAMPLE": TokenType.TABLE_SAMPLE,
"SQL_DOUBLE": TokenType.DOUBLE,
"SQL_VARCHAR": TokenType.VARCHAR,
+ "STORAGE INTEGRATION": TokenType.STORAGE_INTEGRATION,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@@ -724,8 +723,10 @@ class Snowflake(Dialect):
),
exp.GroupConcat: rename_func("LISTAGG"),
exp.If: if_sql(name="IFF", false_value="NULL"),
- exp.JSONExtract: rename_func("GET_PATH"),
- exp.JSONExtractScalar: rename_func("JSON_EXTRACT_PATH_TEXT"),
+ exp.JSONExtract: lambda self, e: self.func("GET_PATH", e.this, e.expression),
+ exp.JSONExtractScalar: lambda self, e: self.func(
+ "JSON_EXTRACT_PATH_TEXT", e.this, e.expression
+ ),
exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions),
exp.JSONPathRoot: lambda *_: "",
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
@@ -756,8 +757,7 @@ class Snowflake(Dialect):
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
- exp.StrToTime: lambda self,
- e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.Struct: lambda self, e: self.func(
"OBJECT_CONSTRUCT",
*(arg for expression in e.expressions for arg in expression.flatten()),
@@ -901,12 +901,12 @@ class Snowflake(Dialect):
)
def except_op(self, expression: exp.Except) -> str:
- if not expression.args.get("distinct", False):
+ if not expression.args.get("distinct"):
self.unsupported("EXCEPT with All is not supported in Snowflake")
return super().except_op(expression)
def intersect_op(self, expression: exp.Intersect) -> str:
- if not expression.args.get("distinct", False):
+ if not expression.args.get("distinct"):
self.unsupported("INTERSECT with All is not supported in Snowflake")
return super().intersect_op(expression)