diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dialects/snowflake.py | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 715a84c..499e085 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -30,7 +30,7 @@ def _check_int(s: str) -> bool: # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]: +def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]: if len(args) == 2: first_arg, second_arg = args if second_arg.is_string: @@ -137,7 +137,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]: # https://docs.snowflake.com/en/sql-reference/functions/div0 -def _div0_to_if(args: t.List) -> exp.Expression: +def _div0_to_if(args: t.List) -> exp.If: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) true = exp.Literal.number(0) false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) @@ -145,13 +145,13 @@ def _div0_to_if(args: t.List) -> exp.Expression: # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _zeroifnull_to_if(args: t.List) -> exp.Expression: +def _zeroifnull_to_if(args: t.List) -> exp.If: cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _nullifzero_to_if(args: t.List) -> exp.Expression: +def _nullifzero_to_if(args: t.List) -> exp.If: cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) @@ -164,12 +164,21 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) -def _parse_convert_timezone(args: t.List) -> exp.Expression: +def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]: if len(args) == 3: return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args) return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0)) +def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace: + regexp_replace = exp.RegexpReplace.from_arg_list(args) + + if not regexp_replace.args.get("replacement"): + regexp_replace.set("replacement", exp.Literal.string("")) + + return regexp_replace + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax RESOLVES_IDENTIFIERS_AS_UPPERCASE = True @@ -223,13 +232,14 @@ class Snowflake(Dialect): "IFF": exp.If.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, + "REGEXP_REPLACE": _parse_regexp_replace, "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TIMEDIFF": _parse_datediff, "TIMESTAMPDIFF": _parse_datediff, "TO_ARRAY": exp.Array.from_arg_list, - "TO_TIMESTAMP": _snowflake_to_timestamp, + "TO_TIMESTAMP": _parse_to_timestamp, "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": _zeroifnull_to_if, } @@ -242,7 +252,6 @@ class Snowflake(Dialect): FUNC_TOKENS = { *parser.Parser.FUNC_TOKENS, - TokenType.RLIKE, TokenType.TABLE, } |