From fb7e79eb4c8d6e22b7324de4bb1ea9cd11b8da7c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 7 Apr 2023 14:35:04 +0200 Subject: Merging upstream version 11.5.2. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/dataframe/sql/functions.py | 4 +- sqlglot/dialects/bigquery.py | 3 ++ sqlglot/dialects/clickhouse.py | 13 +++-- sqlglot/dialects/hive.py | 4 ++ sqlglot/dialects/mysql.py | 6 +-- sqlglot/dialects/oracle.py | 2 - sqlglot/dialects/redshift.py | 3 -- sqlglot/dialects/snowflake.py | 9 +++- sqlglot/dialects/tsql.py | 2 + sqlglot/executor/env.py | 7 +++ sqlglot/expressions.py | 41 ++++++++++----- sqlglot/generator.py | 15 +++++- sqlglot/optimizer/qualify_columns.py | 21 +++++--- sqlglot/optimizer/simplify.py | 5 +- sqlglot/parser.py | 96 +++++++++++++++++++++++++++++++++++- sqlglot/tokens.py | 9 ++-- sqlglot/transforms.py | 4 +- 18 files changed, 199 insertions(+), 47 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index b53b261..1feb464 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -47,7 +47,7 @@ if t.TYPE_CHECKING: T = t.TypeVar("T", bound=Expression) -__version__ = "11.4.5" +__version__ = "11.5.2" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 3c98f42..f77b4f8 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -1036,8 +1036,8 @@ def from_json( def to_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: if options is not None: options_col = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_anonymous_function(col, "TO_JSON", options_col) - return Column.invoke_anonymous_function(col, "TO_JSON") + return Column.invoke_expression_over_column(col, expression.JSONFormat, options=options_col) + return Column.invoke_expression_over_column(col, expression.JSONFormat) def schema_of_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index a3f9e6d..701377b 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -221,6 +221,9 @@ class BigQuery(Dialect): **generator.Generator.TRANSFORMS, # type: ignore **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore exp.ArraySize: rename_func("ARRAY_LENGTH"), + exp.AtTimeZone: lambda self, e: self.func( + "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone")) + ), exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 89e2296..b06462c 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -46,18 +46,22 @@ class ClickHouse(Dialect): time=seq_get(args, 1), decay=seq_get(params, 0), ), - "MAP": parse_var_map, - "HISTOGRAM": lambda params, args: exp.Histogram( - this=seq_get(args, 0), bins=seq_get(params, 0) - ), "GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray( this=seq_get(args, 0), size=seq_get(params, 0) ), + "HISTOGRAM": lambda params, args: exp.Histogram( + this=seq_get(args, 0), bins=seq_get(params, 0) + ), + "MAP": parse_var_map, + "MATCH": exp.RegexpLike.from_arg_list, "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params), "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args), "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args), } + FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() + FUNCTION_PARSERS.pop("MATCH") + RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN) @@ -135,6 +139,7 @@ class ClickHouse(Dialect): exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}", exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}", exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}", + exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), } diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 68137ae..c39656e 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -261,6 +261,7 @@ class Hive(Dialect): "SIZE": exp.ArraySize.from_arg_list, "SPLIT": exp.RegexpSplit.from_arg_list, "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), + "TO_JSON": exp.JSONFormat.from_arg_list, "UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True), "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), } @@ -281,6 +282,7 @@ class Hive(Dialect): exp.DataType.Type.DATETIME: "TIMESTAMP", exp.DataType.Type.VARBINARY: "BINARY", exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", + exp.DataType.Type.BIT: "BOOLEAN", } TRANSFORMS = { @@ -305,6 +307,7 @@ class Hive(Dialect): exp.Join: _unnest_to_explode_sql, exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), + exp.JSONFormat: rename_func("TO_JSON"), exp.Map: var_map_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, @@ -343,6 +346,7 @@ class Hive(Dialect): exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), exp.LastDateOfMonth: rename_func("LAST_DAY"), + exp.National: lambda self, e: self.sql(e, "this"), } PROPERTIES_LOCATION = { diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 5dfa811..d64efbf 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -429,7 +429,7 @@ class MySQL(Dialect): LIMIT_FETCH = "LIMIT" - def show_sql(self, expression): + def show_sql(self, expression: exp.Show) -> str: this = f" {expression.name}" full = " FULL" if expression.args.get("full") else "" global_ = " GLOBAL" if expression.args.get("global") else "" @@ -469,13 +469,13 @@ class MySQL(Dialect): return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}" - def _prefixed_sql(self, prefix, expression, arg): + def _prefixed_sql(self, prefix: str, expression: exp.Expression, arg: str) -> str: sql = self.sql(expression, arg) if not sql: return "" return f" {prefix} {sql}" - def _oldstyle_limit_sql(self, expression): + def _oldstyle_limit_sql(self, expression: exp.Show) -> str: limit = self.sql(expression, "limit") offset = self.sql(expression, "offset") if limit: diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index fad6c4a..3819b76 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -70,7 +70,6 @@ class Oracle(Dialect): class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "DECODE": exp.Matches.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), } @@ -122,7 +121,6 @@ class Oracle(Dialect): **transforms.UNALIAS_GROUP, # type: ignore exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.ILike: no_ilike_sql, - exp.Matches: rename_func("DECODE"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), exp.Substring: rename_func("SUBSTR"), diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index ebd5216..63c14f4 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -3,7 +3,6 @@ from __future__ import annotations import typing as t from sqlglot import exp, transforms -from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.postgres import Postgres from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -30,7 +29,6 @@ class Redshift(Postgres): expression=seq_get(args, 1), unit=seq_get(args, 0), ), - "DECODE": exp.Matches.from_arg_list, "NVL": exp.Coalesce.from_arg_list, } @@ -89,7 +87,6 @@ class Redshift(Postgres): ), exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), - exp.Matches: rename_func("DECODE"), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index c50961c..34bc3bd 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -179,6 +179,10 @@ class Snowflake(Dialect): "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, + "CONVERT_TIMEZONE": lambda args: exp.AtTimeZone( + this=seq_get(args, 1), + zone=seq_get(args, 0), + ), "DATE_TRUNC": date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), @@ -190,7 +194,6 @@ class Snowflake(Dialect): expression=seq_get(args, 1), unit=seq_get(args, 0), ), - "DECODE": exp.Matches.from_arg_list, "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, "NULLIFZERO": _nullifzero_to_if, @@ -275,6 +278,9 @@ class Snowflake(Dialect): exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), + exp.AtTimeZone: lambda self, e: self.func( + "CONVERT_TIMEZONE", e.args.get("zone"), e.this + ), exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this), exp.DateDiff: lambda self, e: self.func( "DATEDIFF", e.text("unit"), e.expression, e.this @@ -287,7 +293,6 @@ class Snowflake(Dialect): exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", - exp.Matches: rename_func("DECODE"), exp.StrPosition: lambda self, e: self.func( "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 8e9b6c3..b8a227b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -294,6 +294,8 @@ class TSQL(Dialect): "REPLICATE": exp.Repeat.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, + "SUSER_NAME": exp.CurrentUser.from_arg_list, + "SUSER_SNAME": exp.CurrentUser.from_arg_list, } VAR_LENGTH_DATATYPES = { diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index ba9cbbd..8f64cce 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -173,4 +173,11 @@ ENV = { "SUBSTRING": substring, "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)), "UPPER": null_if_any(lambda arg: arg.upper()), + "YEAR": null_if_any(lambda arg: arg.year), + "MONTH": null_if_any(lambda arg: arg.month), + "DAY": null_if_any(lambda arg: arg.day), + "CURRENTDATETIME": datetime.datetime.now, + "CURRENTTIMESTAMP": datetime.datetime.now, + "CURRENTTIME": datetime.datetime.now, + "CURRENTDATE": datetime.date.today, } diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index f4aae47..9011dce 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -948,12 +948,17 @@ class Column(Condition): return Dot.build(parts) +class ColumnPosition(Expression): + arg_types = {"this": False, "position": True} + + class ColumnDef(Expression): arg_types = { "this": True, "kind": False, "constraints": False, "exists": False, + "position": False, } @@ -3290,6 +3295,13 @@ class Anonymous(Func): is_var_len_args = True +# https://docs.snowflake.com/en/sql-reference/functions/hll +# https://docs.aws.amazon.com/redshift/latest/dg/r_HLL_function.html +class Hll(AggFunc): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + class ApproxDistinct(AggFunc): arg_types = {"this": True, "accuracy": False} @@ -3440,6 +3452,10 @@ class CurrentTimestamp(Func): arg_types = {"this": False} +class CurrentUser(Func): + arg_types = {"this": False} + + class DateAdd(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} @@ -3647,6 +3663,11 @@ class JSONBExtractScalar(JSONExtract): _sql_names = ["JSONB_EXTRACT_SCALAR"] +class JSONFormat(Func): + arg_types = {"this": False, "options": False} + _sql_names = ["JSON_FORMAT"] + + class Least(Func): arg_types = {"expressions": False} is_var_len_args = True @@ -3703,14 +3724,9 @@ class VarMap(Func): is_var_len_args = True -class Matches(Func): - """Oracle/Snowflake decode. - https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions040.htm - Pattern matching MATCHES(value, search1, result1, ...searchN, resultN, else) - """ - - arg_types = {"this": True, "expressions": True} - is_var_len_args = True +# https://dev.mysql.com/doc/refman/8.0/en/fulltext-search.html +class MatchAgainst(Func): + arg_types = {"this": True, "expressions": True, "modifier": False} class Max(AggFunc): @@ -4989,9 +5005,10 @@ def replace_placeholders(expression, *args, **kwargs): Examples: >>> from sqlglot import exp, parse_one >>> replace_placeholders( - ... parse_one("select * from :tbl where ? = ?"), "a", "b", tbl="foo" + ... parse_one("select * from :tbl where ? = ?"), + ... exp.to_identifier("str_col"), "b", tbl=exp.to_identifier("foo") ... ).sql() - 'SELECT * FROM foo WHERE a = b' + "SELECT * FROM foo WHERE str_col = 'b'" Returns: The mapped expression. @@ -5002,10 +5019,10 @@ def replace_placeholders(expression, *args, **kwargs): if node.name: new_name = kwargs.get(node.name) if new_name: - return to_identifier(new_name) + return convert(new_name) else: try: - return to_identifier(next(args)) + return convert(next(args)) except StopIteration: pass return node diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 6871dd8..8a49d55 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -466,6 +466,12 @@ class Generator: if part ) + def columnposition_sql(self, expression: exp.ColumnPosition) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + position = self.sql(expression, "position") + return f"{position}{this}" + def columndef_sql(self, expression: exp.ColumnDef) -> str: column = self.sql(expression, "this") kind = self.sql(expression, "kind") @@ -473,8 +479,10 @@ class Generator: exists = "IF NOT EXISTS " if expression.args.get("exists") else "" kind = f" {kind}" if kind else "" constraints = f" {constraints}" if constraints else "" + position = self.sql(expression, "position") + position = f" {position}" if position else "" - return f"{exists}{column}{kind}{constraints}" + return f"{exists}{column}{kind}{constraints}{position}" def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: this = self.sql(expression, "this") @@ -1591,6 +1599,11 @@ class Generator: exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) ) + def matchagainst_sql(self, expression: exp.MatchAgainst) -> str: + modifier = expression.args.get("modifier") + modifier = f" {modifier}" if modifier else "" + return f"{self.func('MATCH', *expression.expressions)} AGAINST({self.sql(expression, 'this')}{modifier})" + def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str: return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}" diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 5e40cf3..6eae2b5 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -143,12 +143,12 @@ def _expand_alias_refs(scope, resolver): selects = {} # Replace references to select aliases - def transform(node, *_): + def transform(node, source_first=True): if isinstance(node, exp.Column) and not node.table: table = resolver.get_table(node.name) # Source columns get priority over select aliases - if table: + if source_first and table: node.set("table", table) return node @@ -163,16 +163,21 @@ def _expand_alias_refs(scope, resolver): select = select.this return select.copy() + node.set("table", table) + elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable): + exp.replace_children(node, transform, source_first) + return node for select in scope.expression.selects: - select.transform(transform, copy=False) - - for modifier in ("where", "group"): - part = scope.expression.args.get(modifier) + transform(select) - if part: - part.transform(transform, copy=False) + for modifier, source_first in ( + ("where", True), + ("group", True), + ("having", False), + ): + transform(scope.expression.args.get(modifier), source_first=source_first) def _expand_group_by(scope, resolver): diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 1ed3ca2..28ae86d 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -347,8 +347,9 @@ def _simplify_binary(expression, a, b): if isinstance(expression, exp.Mul): return exp.Literal.number(a * b) if isinstance(expression, exp.Div): + # engines have differing int div behavior so intdiv is not safe if isinstance(a, int) and isinstance(b, int): - return exp.Literal.number(a // b) + return None return exp.Literal.number(a / b) boolean = eval_boolean(expression, a, b) @@ -491,7 +492,7 @@ def _flat_simplify(expression, simplifier, root=True): if result: queue.remove(b) - queue.append(result) + queue.appendleft(result) break else: operands.append(a) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 8269525..b3b899c 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -105,6 +105,7 @@ class Parser(metaclass=_Parser): TokenType.CURRENT_DATETIME: exp.CurrentDate, TokenType.CURRENT_TIME: exp.CurrentTime, TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp, + TokenType.CURRENT_USER: exp.CurrentUser, } NESTED_TYPE_TOKENS = { @@ -285,6 +286,7 @@ class Parser(metaclass=_Parser): TokenType.CURRENT_DATETIME, TokenType.CURRENT_TIMESTAMP, TokenType.CURRENT_TIME, + TokenType.CURRENT_USER, TokenType.FILTER, TokenType.FIRST, TokenType.FORMAT, @@ -674,9 +676,11 @@ class Parser(metaclass=_Parser): FUNCTION_PARSERS: t.Dict[str, t.Callable] = { "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), + "DECODE": lambda self: self._parse_decode(), "EXTRACT": lambda self: self._parse_extract(), "JSON_OBJECT": lambda self: self._parse_json_object(), "LOG": lambda self: self._parse_logarithm(), + "MATCH": lambda self: self._parse_match_against(), "POSITION": lambda self: self._parse_position(), "STRING_AGG": lambda self: self._parse_string_agg(), "SUBSTRING": lambda self: self._parse_substring(), @@ -2634,7 +2638,7 @@ class Parser(metaclass=_Parser): self._match_r_paren() maybe_func = True - if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): + if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): this = exp.DataType( this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value, expressions=expressions)], @@ -2959,6 +2963,11 @@ class Parser(metaclass=_Parser): else: this = self._parse_select_or_expression() + if isinstance(this, exp.EQ): + left = this.this + if isinstance(left, exp.Column): + left.replace(exp.Var(this=left.text("this"))) + if self._match(TokenType.IGNORE_NULLS): this = self.expression(exp.IgnoreNulls, this=this) else: @@ -2968,8 +2977,16 @@ class Parser(metaclass=_Parser): def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: index = self._index - if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT): + + try: + if self._parse_select(nested=True): + return this + except Exception: + pass + finally: self._retreat(index) + + if not self._match(TokenType.L_PAREN): return this args = self._parse_csv( @@ -3344,6 +3361,51 @@ class Parser(metaclass=_Parser): return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + def _parse_decode(self) -> t.Optional[exp.Expression]: + """ + There are generally two variants of the DECODE function: + + - DECODE(bin, charset) + - DECODE(expression, search, result [, search, result] ... [, default]) + + The second variant will always be parsed into a CASE expression. Note that NULL + needs special treatment, since we need to explicitly check for it with `IS NULL`, + instead of relying on pattern matching. + """ + args = self._parse_csv(self._parse_conjunction) + + if len(args) < 3: + return self.expression(exp.Decode, this=seq_get(args, 0), charset=seq_get(args, 1)) + + expression, *expressions = args + if not expression: + return None + + ifs = [] + for search, result in zip(expressions[::2], expressions[1::2]): + if not search or not result: + return None + + if isinstance(search, exp.Literal): + ifs.append( + exp.If(this=exp.EQ(this=expression.copy(), expression=search), true=result) + ) + elif isinstance(search, exp.Null): + ifs.append( + exp.If(this=exp.Is(this=expression.copy(), expression=exp.Null()), true=result) + ) + else: + cond = exp.or_( + exp.EQ(this=expression.copy(), expression=search), + exp.and_( + exp.Is(this=expression.copy(), expression=exp.Null()), + exp.Is(this=search.copy(), expression=exp.Null()), + ), + ) + ifs.append(exp.If(this=cond, true=result)) + + return exp.Case(ifs=ifs, default=expressions[-1] if len(expressions) % 2 == 1 else None) + def _parse_json_key_value(self) -> t.Optional[exp.Expression]: self._match_text_seq("KEY") key = self._parse_field() @@ -3398,6 +3460,28 @@ class Parser(metaclass=_Parser): exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0) ) + def _parse_match_against(self) -> exp.Expression: + expressions = self._parse_csv(self._parse_column) + + self._match_text_seq(")", "AGAINST", "(") + + this = self._parse_string() + + if self._match_text_seq("IN", "NATURAL", "LANGUAGE", "MODE"): + modifier = "IN NATURAL LANGUAGE MODE" + if self._match_text_seq("WITH", "QUERY", "EXPANSION"): + modifier = f"{modifier} WITH QUERY EXPANSION" + elif self._match_text_seq("IN", "BOOLEAN", "MODE"): + modifier = "IN BOOLEAN MODE" + elif self._match_text_seq("WITH", "QUERY", "EXPANSION"): + modifier = "WITH QUERY EXPANSION" + else: + modifier = None + + return self.expression( + exp.MatchAgainst, this=this, expressions=expressions, modifier=modifier + ) + def _parse_position(self, haystack_first: bool = False) -> exp.Expression: args = self._parse_csv(self._parse_bitwise) @@ -3791,6 +3875,14 @@ class Parser(metaclass=_Parser): if expression: expression.set("exists", exists_column) + # https://docs.databricks.com/delta/update-schema.html#explicitly-update-schema-to-add-columns + if self._match_texts(("FIRST", "AFTER")): + position = self._prev.text + column_position = self.expression( + exp.ColumnPosition, this=self._parse_column(), position=position + ) + expression.set("position", column_position) + return expression def _parse_drop_column(self) -> t.Optional[exp.Expression]: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index e5b44e7..cf2e31f 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -163,6 +163,7 @@ class TokenType(AutoName): CURRENT_ROW = auto() CURRENT_TIME = auto() CURRENT_TIMESTAMP = auto() + CURRENT_USER = auto() DEFAULT = auto() DELETE = auto() DESC = auto() @@ -506,6 +507,7 @@ class Tokenizer(metaclass=_Tokenizer): "CURRENT ROW": TokenType.CURRENT_ROW, "CURRENT_TIME": TokenType.CURRENT_TIME, "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP, + "CURRENT_USER": TokenType.CURRENT_USER, "DATABASE": TokenType.DATABASE, "DEFAULT": TokenType.DEFAULT, "DELETE": TokenType.DELETE, @@ -908,7 +910,7 @@ class Tokenizer(metaclass=_Tokenizer): if not word: if self._char in self.SINGLE_TOKENS: - self._add(self.SINGLE_TOKENS[self._char]) # type: ignore + self._add(self.SINGLE_TOKENS[self._char], text=self._char) # type: ignore return self._scan_var() return @@ -921,7 +923,8 @@ class Tokenizer(metaclass=_Tokenizer): return self._advance(size - 1) - self._add(self.KEYWORDS[word.upper()]) + word = word.upper() + self._add(self.KEYWORDS[word], text=word) def _scan_comment(self, comment_start: str) -> bool: if comment_start not in self._COMMENTS: # type: ignore @@ -946,7 +949,7 @@ class Tokenizer(metaclass=_Tokenizer): # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. # Multiple consecutive comments are preserved by appending them to the current comments list. - if comment_start_line == self._prev_token_line: + if comment_start_line == self._prev_token_line or self._end: self.tokens[-1].comments.extend(self._comments) self._comments = [] self._prev_token_line = self._line diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 2eafb0b..62728d5 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -114,8 +114,8 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression: def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: """ - Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. - This transforms removes the precision from parameterized types in expressions. + Some dialects only allow the precision for parameterized types to be defined in the DDL and not in + other expressions. This transforms removes the precision from parameterized types in expressions. """ return expression.transform( lambda node: exp.DataType( -- cgit v1.2.3