diff options
Diffstat (limited to 'sqlglot/dialects/snowflake.py')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 266 |
1 file changed, 216 insertions, 50 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 8925181..a8e4a42 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -19,7 +19,6 @@ from sqlglot.dialects.dialect import ( rename_func, timestamptrunc_sql, timestrtotime_sql, - ts_or_ds_to_date_sql, var_map_sql, ) from sqlglot.expressions import Literal @@ -40,21 +39,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, if second_arg.is_string: # case: <string_expr> [ , <format> ] return format_time_lambda(exp.StrToTime, "snowflake")(args) - - # case: <numeric_expr> [ , <scale> ] - if second_arg.name not in ["0", "3", "9"]: - raise ValueError( - f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9" - ) - - if second_arg.name == "0": - timescale = exp.UnixToTime.SECONDS - elif second_arg.name == "3": - timescale = exp.UnixToTime.MILLIS - elif second_arg.name == "9": - timescale = exp.UnixToTime.NANOS - - return exp.UnixToTime(this=first_arg, scale=timescale) + return exp.UnixToTime(this=first_arg, scale=second_arg) from sqlglot.optimizer.simplify import simplify_literals @@ -91,23 +76,9 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: def _parse_datediff(args: t.List) -> exp.DateDiff: - return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) - - -def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str: - scale = expression.args.get("scale") - timestamp = self.sql(expression, "this") - if scale in (None, exp.UnixToTime.SECONDS): - return f"TO_TIMESTAMP({timestamp})" - if scale == exp.UnixToTime.MILLIS: - return f"TO_TIMESTAMP({timestamp}, 3)" - if scale == exp.UnixToTime.MICROS: - return f"TO_TIMESTAMP({timestamp} / 1000, 3)" - if scale == exp.UnixToTime.NANOS: - return f"TO_TIMESTAMP({timestamp}, 9)" - - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return exp.DateDiff( 
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0)) + ) # https://docs.snowflake.com/en/sql-reference/functions/date_part.html @@ -120,14 +91,15 @@ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]: self._match(TokenType.COMMA) expression = self._parse_bitwise() - + this = _map_date_part(this) name = this.name.upper() + if name.startswith("EPOCH"): - if name.startswith("EPOCH_MILLISECOND"): + if name == "EPOCH_MILLISECOND": scale = 10**3 - elif name.startswith("EPOCH_MICROSECOND"): + elif name == "EPOCH_MICROSECOND": scale = 10**6 - elif name.startswith("EPOCH_NANOSECOND"): + elif name == "EPOCH_NANOSECOND": scale = 10**9 else: scale = None @@ -204,6 +176,159 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser] return _parse +DATE_PART_MAPPING = { + "Y": "YEAR", + "YY": "YEAR", + "YYY": "YEAR", + "YYYY": "YEAR", + "YR": "YEAR", + "YEARS": "YEAR", + "YRS": "YEAR", + "MM": "MONTH", + "MON": "MONTH", + "MONS": "MONTH", + "MONTHS": "MONTH", + "D": "DAY", + "DD": "DAY", + "DAYS": "DAY", + "DAYOFMONTH": "DAY", + "WEEKDAY": "DAYOFWEEK", + "DOW": "DAYOFWEEK", + "DW": "DAYOFWEEK", + "WEEKDAY_ISO": "DAYOFWEEKISO", + "DOW_ISO": "DAYOFWEEKISO", + "DW_ISO": "DAYOFWEEKISO", + "YEARDAY": "DAYOFYEAR", + "DOY": "DAYOFYEAR", + "DY": "DAYOFYEAR", + "W": "WEEK", + "WK": "WEEK", + "WEEKOFYEAR": "WEEK", + "WOY": "WEEK", + "WY": "WEEK", + "WEEK_ISO": "WEEKISO", + "WEEKOFYEARISO": "WEEKISO", + "WEEKOFYEAR_ISO": "WEEKISO", + "Q": "QUARTER", + "QTR": "QUARTER", + "QTRS": "QUARTER", + "QUARTERS": "QUARTER", + "H": "HOUR", + "HH": "HOUR", + "HR": "HOUR", + "HOURS": "HOUR", + "HRS": "HOUR", + "M": "MINUTE", + "MI": "MINUTE", + "MIN": "MINUTE", + "MINUTES": "MINUTE", + "MINS": "MINUTE", + "S": "SECOND", + "SEC": "SECOND", + "SECONDS": "SECOND", + "SECS": "SECOND", + "MS": "MILLISECOND", + "MSEC": "MILLISECOND", + "MILLISECONDS": "MILLISECOND", + "US": "MICROSECOND", + "USEC": "MICROSECOND", + 
"MICROSECONDS": "MICROSECOND", + "NS": "NANOSECOND", + "NSEC": "NANOSECOND", + "NANOSEC": "NANOSECOND", + "NSECOND": "NANOSECOND", + "NSECONDS": "NANOSECOND", + "NANOSECS": "NANOSECOND", + "NSECONDS": "NANOSECOND", + "EPOCH": "EPOCH_SECOND", + "EPOCH_SECONDS": "EPOCH_SECOND", + "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", + "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", + "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", + "TZH": "TIMEZONE_HOUR", + "TZM": "TIMEZONE_MINUTE", +} + + +@t.overload +def _map_date_part(part: exp.Expression) -> exp.Var: + pass + + +@t.overload +def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + pass + + +def _map_date_part(part): + mapped = DATE_PART_MAPPING.get(part.name.upper()) if part else None + return exp.var(mapped) if mapped else part + + +def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: + trunc = date_trunc_to_time(args) + trunc.set("unit", _map_date_part(trunc.args["unit"])) + return trunc + + +def _parse_colon_get_path( + self: parser.Parser, this: t.Optional[exp.Expression] +) -> t.Optional[exp.Expression]: + while True: + path = self._parse_bitwise() + + # The cast :: operator has a lower precedence than the extraction operator :, so + # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH + if isinstance(path, exp.Cast): + target_type = path.to + path = path.this + else: + target_type = None + + if isinstance(path, exp.Expression): + path = exp.Literal.string(path.sql(dialect="snowflake")) + + # The extraction operator : is left-associative + this = self.expression(exp.GetPath, this=this, expression=path) + + if target_type: + this = exp.cast(this, target_type) + + if not self._match(TokenType.COLON): + break + + if self._match_set(self.RANGE_PARSERS): + this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this + + return this + + +def _parse_timestamp_from_parts(args: t.List) -> exp.Func: + if len(args) == 2: + # Other dialects don't have 
the TIMESTAMP_FROM_PARTS(date, time) concept, + # so we parse this into Anonymous for now instead of introducing complexity + return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) + + return exp.TimestampFromParts.from_arg_list(args) + + +def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression: + """ + Snowflake doesn't allow columns referenced in UNPIVOT to be qualified, + so we need to unqualify them. + + Example: + >>> from sqlglot import parse_one + >>> expr = parse_one("SELECT * FROM m_sales UNPIVOT(sales FOR month IN (m_sales.jan, feb, mar, april))") + >>> print(_unqualify_unpivot_columns(expr).sql(dialect="snowflake")) + SELECT * FROM m_sales UNPIVOT(sales FOR month IN (jan, feb, mar, april)) + """ + if isinstance(expression, exp.Pivot) and expression.unpivot: + expression = transforms.unqualify_columns(expression) + + return expression + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE @@ -211,6 +336,8 @@ class Snowflake(Dialect): TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" SUPPORTS_USER_DEFINED_TYPES = False SUPPORTS_SEMI_ANTI_JOIN = False + PREFER_CTE_ALIAS_COLUMN = True + TABLESAMPLE_SIZE_IS_PERCENT = True TIME_MAPPING = { "YYYY": "%Y", @@ -276,14 +403,19 @@ class Snowflake(Dialect): "BIT_XOR": binary_from_function(exp.BitwiseXor), "BOOLXOR": binary_from_function(exp.Xor), "CONVERT_TIMEZONE": _parse_convert_timezone, - "DATE_TRUNC": date_trunc_to_time, + "DATE_TRUNC": _date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( - this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=_map_date_part(seq_get(args, 0)), ), "DATEDIFF": _parse_datediff, "DIV0": _div0_to_if, "FLATTEN": exp.Explode.from_arg_list, "IFF": exp.If.from_arg_list, + "LAST_DAY": lambda args: exp.LastDay( + this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1)) + ), 
"LISTAGG": exp.GroupConcat.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, @@ -293,6 +425,8 @@ class Snowflake(Dialect): "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TIMEDIFF": _parse_datediff, "TIMESTAMPDIFF": _parse_datediff, + "TIMESTAMPFROMPARTS": _parse_timestamp_from_parts, + "TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts, "TO_TIMESTAMP": _parse_to_timestamp, "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": _zeroifnull_to_if, @@ -301,22 +435,17 @@ class Snowflake(Dialect): FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, "DATE_PART": _parse_date_part, + "OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(), } FUNCTION_PARSERS.pop("TRIM") - COLUMN_OPERATORS = { - **parser.Parser.COLUMN_OPERATORS, - TokenType.COLON: lambda self, this, path: self.expression( - exp.Bracket, this=this, expressions=[path] - ), - } - TIMESTAMPS = parser.Parser.TIMESTAMPS - {TokenType.TIME} RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny), TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny), + TokenType.COLON: _parse_colon_get_path, } ALTER_PARSERS = { @@ -344,6 +473,7 @@ class Snowflake(Dialect): SHOW_PARSERS = { "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), + "COLUMNS": _show_parser("COLUMNS"), } STAGED_FILE_SINGLE_TOKENS = { @@ -351,8 +481,18 @@ class Snowflake(Dialect): TokenType.MOD, TokenType.SLASH, } + FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"] + def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]: + if is_map: + # Keys are strings in Snowflake's objects, see also: + # - https://docs.snowflake.com/en/sql-reference/data-types-semistructured + # - https://docs.snowflake.com/en/sql-reference/functions/object_construct + return self._parse_slice(self._parse_string()) + + return 
self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True)) + def _parse_lateral(self) -> t.Optional[exp.Lateral]: lateral = super()._parse_lateral() if not lateral: @@ -440,6 +580,8 @@ class Snowflake(Dialect): scope = None scope_kind = None + like = self._parse_string() if self._match(TokenType.LIKE) else None + if self._match(TokenType.IN): if self._match_text_seq("ACCOUNT"): scope_kind = "ACCOUNT" @@ -451,7 +593,9 @@ class Snowflake(Dialect): scope_kind = "TABLE" scope = self._parse_table() - return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind) + return self.expression( + exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind + ) def _parse_alter_table_swap(self) -> exp.SwapTable: self._match_text_seq("WITH") @@ -489,8 +633,12 @@ class Snowflake(Dialect): "MINUS": TokenType.EXCEPT, "NCHAR VARYING": TokenType.VARCHAR, "PUT": TokenType.COMMAND, + "REMOVE": TokenType.COMMAND, "RENAME": TokenType.REPLACE, + "RM": TokenType.COMMAND, "SAMPLE": TokenType.TABLE_SAMPLE, + "SQL_DOUBLE": TokenType.DOUBLE, + "SQL_VARCHAR": TokenType.VARCHAR, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, @@ -518,6 +666,8 @@ class Snowflake(Dialect): SUPPORTS_TABLE_COPY = False COLLATE_IS_FUNC = True LIMIT_ONLY_LITERALS = True + JSON_KEY_VALUE_PAIR_SEP = "," + INSERT_OVERWRITE = " OVERWRITE INTO" TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -545,6 +695,8 @@ class Snowflake(Dialect): ), exp.GroupConcat: rename_func("LISTAGG"), exp.If: if_sql(name="IFF", false_value="NULL"), + exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]", + exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions), exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), @@ -557,6 +709,7 @@ class Snowflake(Dialect): 
exp.PercentileDisc: transforms.preprocess( [transforms.add_within_group_for_percentiles] ), + exp.Pivot: transforms.preprocess([_unqualify_unpivot_columns]), exp.RegexpILike: _regexpilike_sql, exp.Rand: rename_func("RANDOM"), exp.Select: transforms.preprocess( @@ -578,6 +731,9 @@ class Snowflake(Dialect): *(arg for expression in e.expressions for arg in expression.flatten()), ), exp.Stuff: rename_func("INSERT"), + exp.TimestampDiff: lambda self, e: self.func( + "TIMESTAMPDIFF", e.unit, e.expression, e.this + ), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: lambda self, e: self.func( @@ -589,8 +745,7 @@ class Snowflake(Dialect): exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), - exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), - exp.UnixToTime: _unix_to_time_sql, + exp.UnixToTime: rename_func("TO_TIMESTAMP"), exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.WeekOfYear: rename_func("WEEKOFYEAR"), exp.Xor: rename_func("BOOLXOR"), @@ -612,6 +767,14 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: + milli = expression.args.get("milli") + if milli is not None: + milli_to_nano = milli.pop() * exp.Literal.number(1000000) + expression.set("nano", milli_to_nano) + + return rename_func("TIMESTAMP_FROM_PARTS")(self, expression) + def trycast_sql(self, expression: exp.TryCast) -> str: value = expression.this @@ -657,6 +820,9 @@ class Snowflake(Dialect): return f"{explode}{alias}" def show_sql(self, expression: exp.Show) -> str: + like = self.sql(expression, "like") + like = f" LIKE {like}" if like else "" + scope = self.sql(expression, "scope") scope = f" {scope}" if scope else "" @@ -664,7 +830,7 @@ class Snowflake(Dialect): if scope_kind: scope_kind = f" IN {scope_kind}" 
- return f"SHOW {expression.name}{scope_kind}{scope}" + return f"SHOW {expression.name}{like}{scope_kind}{scope}" def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: # Other dialects don't support all of the following parameters, so we need to |