summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-04-07 12:35:04 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-04-07 12:35:04 +0000
commitfb7e79eb4c8d6e22b7324de4bb1ea9cd11b8da7c (patch)
tree476513580a6824dfe34364f98f0dbf7f66d188f4 /sqlglot
parentReleasing debian version 11.4.5-1. (diff)
downloadsqlglot-fb7e79eb4c8d6e22b7324de4bb1ea9cd11b8da7c.tar.xz
sqlglot-fb7e79eb4c8d6e22b7324de4bb1ea9cd11b8da7c.zip
Merging upstream version 11.5.2.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/__init__.py2
-rw-r--r--sqlglot/dataframe/sql/functions.py4
-rw-r--r--sqlglot/dialects/bigquery.py3
-rw-r--r--sqlglot/dialects/clickhouse.py13
-rw-r--r--sqlglot/dialects/hive.py4
-rw-r--r--sqlglot/dialects/mysql.py6
-rw-r--r--sqlglot/dialects/oracle.py2
-rw-r--r--sqlglot/dialects/redshift.py3
-rw-r--r--sqlglot/dialects/snowflake.py9
-rw-r--r--sqlglot/dialects/tsql.py2
-rw-r--r--sqlglot/executor/env.py7
-rw-r--r--sqlglot/expressions.py41
-rw-r--r--sqlglot/generator.py15
-rw-r--r--sqlglot/optimizer/qualify_columns.py21
-rw-r--r--sqlglot/optimizer/simplify.py5
-rw-r--r--sqlglot/parser.py96
-rw-r--r--sqlglot/tokens.py9
-rw-r--r--sqlglot/transforms.py4
18 files changed, 199 insertions, 47 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index b53b261..1feb464 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -47,7 +47,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
-__version__ = "11.4.5"
+__version__ = "11.5.2"
pretty = False
"""Whether to format generated SQL by default."""
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 3c98f42..f77b4f8 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -1036,8 +1036,8 @@ def from_json(
def to_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
if options is not None:
options_col = create_map([lit(x) for x in _flatten(options.items())])
- return Column.invoke_anonymous_function(col, "TO_JSON", options_col)
- return Column.invoke_anonymous_function(col, "TO_JSON")
+ return Column.invoke_expression_over_column(col, expression.JSONFormat, options=options_col)
+ return Column.invoke_expression_over_column(col, expression.JSONFormat)
def schema_of_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index a3f9e6d..701377b 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -221,6 +221,9 @@ class BigQuery(Dialect):
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"),
+ exp.AtTimeZone: lambda self, e: self.func(
+ "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
+ ),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 89e2296..b06462c 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -46,18 +46,22 @@ class ClickHouse(Dialect):
time=seq_get(args, 1),
decay=seq_get(params, 0),
),
- "MAP": parse_var_map,
- "HISTOGRAM": lambda params, args: exp.Histogram(
- this=seq_get(args, 0), bins=seq_get(params, 0)
- ),
"GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray(
this=seq_get(args, 0), size=seq_get(params, 0)
),
+ "HISTOGRAM": lambda params, args: exp.Histogram(
+ this=seq_get(args, 0), bins=seq_get(params, 0)
+ ),
+ "MAP": parse_var_map,
+ "MATCH": exp.RegexpLike.from_arg_list,
"QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
"QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
"QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
}
+ FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
+ FUNCTION_PARSERS.pop("MATCH")
+
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN)
@@ -135,6 +139,7 @@ class ClickHouse(Dialect):
exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
+ exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
}
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 68137ae..c39656e 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -261,6 +261,7 @@ class Hive(Dialect):
"SIZE": exp.ArraySize.from_arg_list,
"SPLIT": exp.RegexpSplit.from_arg_list,
"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
+ "TO_JSON": exp.JSONFormat.from_arg_list,
"UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
@@ -281,6 +282,7 @@ class Hive(Dialect):
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+ exp.DataType.Type.BIT: "BOOLEAN",
}
TRANSFORMS = {
@@ -305,6 +307,7 @@ class Hive(Dialect):
exp.Join: _unnest_to_explode_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
+ exp.JSONFormat: rename_func("TO_JSON"),
exp.Map: var_map_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
@@ -343,6 +346,7 @@ class Hive(Dialect):
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
exp.LastDateOfMonth: rename_func("LAST_DAY"),
+ exp.National: lambda self, e: self.sql(e, "this"),
}
PROPERTIES_LOCATION = {
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 5dfa811..d64efbf 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -429,7 +429,7 @@ class MySQL(Dialect):
LIMIT_FETCH = "LIMIT"
- def show_sql(self, expression):
+ def show_sql(self, expression: exp.Show) -> str:
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""
global_ = " GLOBAL" if expression.args.get("global") else ""
@@ -469,13 +469,13 @@ class MySQL(Dialect):
return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}"
- def _prefixed_sql(self, prefix, expression, arg):
+ def _prefixed_sql(self, prefix: str, expression: exp.Expression, arg: str) -> str:
sql = self.sql(expression, arg)
if not sql:
return ""
return f" {prefix} {sql}"
- def _oldstyle_limit_sql(self, expression):
+ def _oldstyle_limit_sql(self, expression: exp.Show) -> str:
limit = self.sql(expression, "limit")
offset = self.sql(expression, "offset")
if limit:
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index fad6c4a..3819b76 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -70,7 +70,6 @@ class Oracle(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
- "DECODE": exp.Matches.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
}
@@ -122,7 +121,6 @@ class Oracle(Dialect):
**transforms.UNALIAS_GROUP, # type: ignore
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.ILike: no_ilike_sql,
- exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index ebd5216..63c14f4 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -3,7 +3,6 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, transforms
-from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -30,7 +29,6 @@ class Redshift(Postgres):
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
- "DECODE": exp.Matches.from_arg_list,
"NVL": exp.Coalesce.from_arg_list,
}
@@ -89,7 +87,6 @@ class Redshift(Postgres):
),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
- exp.Matches: rename_func("DECODE"),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index c50961c..34bc3bd 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -179,6 +179,10 @@ class Snowflake(Dialect):
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
+ "CONVERT_TIMEZONE": lambda args: exp.AtTimeZone(
+ this=seq_get(args, 1),
+ zone=seq_get(args, 0),
+ ),
"DATE_TRUNC": date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
@@ -190,7 +194,6 @@ class Snowflake(Dialect):
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
- "DECODE": exp.Matches.from_arg_list,
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
@@ -275,6 +278,9 @@ class Snowflake(Dialect):
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
+ exp.AtTimeZone: lambda self, e: self.func(
+ "CONVERT_TIMEZONE", e.args.get("zone"), e.this
+ ),
exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
exp.DateDiff: lambda self, e: self.func(
"DATEDIFF", e.text("unit"), e.expression, e.this
@@ -287,7 +293,6 @@ class Snowflake(Dialect):
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
- exp.Matches: rename_func("DECODE"),
exp.StrPosition: lambda self, e: self.func(
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 8e9b6c3..b8a227b 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -294,6 +294,8 @@ class TSQL(Dialect):
"REPLICATE": exp.Repeat.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
+ "SUSER_NAME": exp.CurrentUser.from_arg_list,
+ "SUSER_SNAME": exp.CurrentUser.from_arg_list,
}
VAR_LENGTH_DATATYPES = {
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index ba9cbbd..8f64cce 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -173,4 +173,11 @@ ENV = {
"SUBSTRING": substring,
"TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)),
"UPPER": null_if_any(lambda arg: arg.upper()),
+ "YEAR": null_if_any(lambda arg: arg.year),
+ "MONTH": null_if_any(lambda arg: arg.month),
+ "DAY": null_if_any(lambda arg: arg.day),
+ "CURRENTDATETIME": datetime.datetime.now,
+ "CURRENTTIMESTAMP": datetime.datetime.now,
+ "CURRENTTIME": datetime.datetime.now,
+ "CURRENTDATE": datetime.date.today,
}
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index f4aae47..9011dce 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -948,12 +948,17 @@ class Column(Condition):
return Dot.build(parts)
+class ColumnPosition(Expression):
+ arg_types = {"this": False, "position": True}
+
+
class ColumnDef(Expression):
arg_types = {
"this": True,
"kind": False,
"constraints": False,
"exists": False,
+ "position": False,
}
@@ -3290,6 +3295,13 @@ class Anonymous(Func):
is_var_len_args = True
+# https://docs.snowflake.com/en/sql-reference/functions/hll
+# https://docs.aws.amazon.com/redshift/latest/dg/r_HLL_function.html
+class Hll(AggFunc):
+ arg_types = {"this": True, "expressions": False}
+ is_var_len_args = True
+
+
class ApproxDistinct(AggFunc):
arg_types = {"this": True, "accuracy": False}
@@ -3440,6 +3452,10 @@ class CurrentTimestamp(Func):
arg_types = {"this": False}
+class CurrentUser(Func):
+ arg_types = {"this": False}
+
+
class DateAdd(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
@@ -3647,6 +3663,11 @@ class JSONBExtractScalar(JSONExtract):
_sql_names = ["JSONB_EXTRACT_SCALAR"]
+class JSONFormat(Func):
+ arg_types = {"this": False, "options": False}
+ _sql_names = ["JSON_FORMAT"]
+
+
class Least(Func):
arg_types = {"expressions": False}
is_var_len_args = True
@@ -3703,14 +3724,9 @@ class VarMap(Func):
is_var_len_args = True
-class Matches(Func):
- """Oracle/Snowflake decode.
- https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions040.htm
- Pattern matching MATCHES(value, search1, result1, ...searchN, resultN, else)
- """
-
- arg_types = {"this": True, "expressions": True}
- is_var_len_args = True
+# https://dev.mysql.com/doc/refman/8.0/en/fulltext-search.html
+class MatchAgainst(Func):
+ arg_types = {"this": True, "expressions": True, "modifier": False}
class Max(AggFunc):
@@ -4989,9 +5005,10 @@ def replace_placeholders(expression, *args, **kwargs):
Examples:
>>> from sqlglot import exp, parse_one
>>> replace_placeholders(
- ... parse_one("select * from :tbl where ? = ?"), "a", "b", tbl="foo"
+ ... parse_one("select * from :tbl where ? = ?"),
+ ... exp.to_identifier("str_col"), "b", tbl=exp.to_identifier("foo")
... ).sql()
- 'SELECT * FROM foo WHERE a = b'
+ "SELECT * FROM foo WHERE str_col = 'b'"
Returns:
The mapped expression.
@@ -5002,10 +5019,10 @@ def replace_placeholders(expression, *args, **kwargs):
if node.name:
new_name = kwargs.get(node.name)
if new_name:
- return to_identifier(new_name)
+ return convert(new_name)
else:
try:
- return to_identifier(next(args))
+ return convert(next(args))
except StopIteration:
pass
return node
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 6871dd8..8a49d55 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -466,6 +466,12 @@ class Generator:
if part
)
+ def columnposition_sql(self, expression: exp.ColumnPosition) -> str:
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+ position = self.sql(expression, "position")
+ return f"{position}{this}"
+
def columndef_sql(self, expression: exp.ColumnDef) -> str:
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
@@ -473,8 +479,10 @@ class Generator:
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
kind = f" {kind}" if kind else ""
constraints = f" {constraints}" if constraints else ""
+ position = self.sql(expression, "position")
+ position = f" {position}" if position else ""
- return f"{exists}{column}{kind}{constraints}"
+ return f"{exists}{column}{kind}{constraints}{position}"
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
@@ -1591,6 +1599,11 @@ class Generator:
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
+ def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
+ modifier = expression.args.get("modifier")
+ modifier = f" {modifier}" if modifier else ""
+ return f"{self.func('MATCH', *expression.expressions)} AGAINST({self.sql(expression, 'this')}{modifier})"
+
def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 5e40cf3..6eae2b5 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -143,12 +143,12 @@ def _expand_alias_refs(scope, resolver):
selects = {}
# Replace references to select aliases
- def transform(node, *_):
+ def transform(node, source_first=True):
if isinstance(node, exp.Column) and not node.table:
table = resolver.get_table(node.name)
# Source columns get priority over select aliases
- if table:
+ if source_first and table:
node.set("table", table)
return node
@@ -163,16 +163,21 @@ def _expand_alias_refs(scope, resolver):
select = select.this
return select.copy()
+ node.set("table", table)
+ elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable):
+ exp.replace_children(node, transform, source_first)
+
return node
for select in scope.expression.selects:
- select.transform(transform, copy=False)
-
- for modifier in ("where", "group"):
- part = scope.expression.args.get(modifier)
+ transform(select)
- if part:
- part.transform(transform, copy=False)
+ for modifier, source_first in (
+ ("where", True),
+ ("group", True),
+ ("having", False),
+ ):
+ transform(scope.expression.args.get(modifier), source_first=source_first)
def _expand_group_by(scope, resolver):
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 1ed3ca2..28ae86d 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -347,8 +347,9 @@ def _simplify_binary(expression, a, b):
if isinstance(expression, exp.Mul):
return exp.Literal.number(a * b)
if isinstance(expression, exp.Div):
+ # engines have differing int div behavior so intdiv is not safe
if isinstance(a, int) and isinstance(b, int):
- return exp.Literal.number(a // b)
+ return None
return exp.Literal.number(a / b)
boolean = eval_boolean(expression, a, b)
@@ -491,7 +492,7 @@ def _flat_simplify(expression, simplifier, root=True):
if result:
queue.remove(b)
- queue.append(result)
+ queue.appendleft(result)
break
else:
operands.append(a)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 8269525..b3b899c 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -105,6 +105,7 @@ class Parser(metaclass=_Parser):
TokenType.CURRENT_DATETIME: exp.CurrentDate,
TokenType.CURRENT_TIME: exp.CurrentTime,
TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp,
+ TokenType.CURRENT_USER: exp.CurrentUser,
}
NESTED_TYPE_TOKENS = {
@@ -285,6 +286,7 @@ class Parser(metaclass=_Parser):
TokenType.CURRENT_DATETIME,
TokenType.CURRENT_TIMESTAMP,
TokenType.CURRENT_TIME,
+ TokenType.CURRENT_USER,
TokenType.FILTER,
TokenType.FIRST,
TokenType.FORMAT,
@@ -674,9 +676,11 @@ class Parser(metaclass=_Parser):
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
+ "DECODE": lambda self: self._parse_decode(),
"EXTRACT": lambda self: self._parse_extract(),
"JSON_OBJECT": lambda self: self._parse_json_object(),
"LOG": lambda self: self._parse_logarithm(),
+ "MATCH": lambda self: self._parse_match_against(),
"POSITION": lambda self: self._parse_position(),
"STRING_AGG": lambda self: self._parse_string_agg(),
"SUBSTRING": lambda self: self._parse_substring(),
@@ -2634,7 +2638,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
maybe_func = True
- if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
+ if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
this = exp.DataType(
this=exp.DataType.Type.ARRAY,
expressions=[exp.DataType.build(type_token.value, expressions=expressions)],
@@ -2959,6 +2963,11 @@ class Parser(metaclass=_Parser):
else:
this = self._parse_select_or_expression()
+ if isinstance(this, exp.EQ):
+ left = this.this
+ if isinstance(left, exp.Column):
+ left.replace(exp.Var(this=left.text("this")))
+
if self._match(TokenType.IGNORE_NULLS):
this = self.expression(exp.IgnoreNulls, this=this)
else:
@@ -2968,8 +2977,16 @@ class Parser(metaclass=_Parser):
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
- if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT):
+
+ try:
+ if self._parse_select(nested=True):
+ return this
+ except Exception:
+ pass
+ finally:
self._retreat(index)
+
+ if not self._match(TokenType.L_PAREN):
return this
args = self._parse_csv(
@@ -3344,6 +3361,51 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ def _parse_decode(self) -> t.Optional[exp.Expression]:
+ """
+ There are generally two variants of the DECODE function:
+
+ - DECODE(bin, charset)
+ - DECODE(expression, search, result [, search, result] ... [, default])
+
+ The second variant will always be parsed into a CASE expression. Note that NULL
+ needs special treatment, since we need to explicitly check for it with `IS NULL`,
+ instead of relying on pattern matching.
+ """
+ args = self._parse_csv(self._parse_conjunction)
+
+ if len(args) < 3:
+ return self.expression(exp.Decode, this=seq_get(args, 0), charset=seq_get(args, 1))
+
+ expression, *expressions = args
+ if not expression:
+ return None
+
+ ifs = []
+ for search, result in zip(expressions[::2], expressions[1::2]):
+ if not search or not result:
+ return None
+
+ if isinstance(search, exp.Literal):
+ ifs.append(
+ exp.If(this=exp.EQ(this=expression.copy(), expression=search), true=result)
+ )
+ elif isinstance(search, exp.Null):
+ ifs.append(
+ exp.If(this=exp.Is(this=expression.copy(), expression=exp.Null()), true=result)
+ )
+ else:
+ cond = exp.or_(
+ exp.EQ(this=expression.copy(), expression=search),
+ exp.and_(
+ exp.Is(this=expression.copy(), expression=exp.Null()),
+ exp.Is(this=search.copy(), expression=exp.Null()),
+ ),
+ )
+ ifs.append(exp.If(this=cond, true=result))
+
+ return exp.Case(ifs=ifs, default=expressions[-1] if len(expressions) % 2 == 1 else None)
+
def _parse_json_key_value(self) -> t.Optional[exp.Expression]:
self._match_text_seq("KEY")
key = self._parse_field()
@@ -3398,6 +3460,28 @@ class Parser(metaclass=_Parser):
exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0)
)
+ def _parse_match_against(self) -> exp.Expression:
+ expressions = self._parse_csv(self._parse_column)
+
+ self._match_text_seq(")", "AGAINST", "(")
+
+ this = self._parse_string()
+
+ if self._match_text_seq("IN", "NATURAL", "LANGUAGE", "MODE"):
+ modifier = "IN NATURAL LANGUAGE MODE"
+ if self._match_text_seq("WITH", "QUERY", "EXPANSION"):
+ modifier = f"{modifier} WITH QUERY EXPANSION"
+ elif self._match_text_seq("IN", "BOOLEAN", "MODE"):
+ modifier = "IN BOOLEAN MODE"
+ elif self._match_text_seq("WITH", "QUERY", "EXPANSION"):
+ modifier = "WITH QUERY EXPANSION"
+ else:
+ modifier = None
+
+ return self.expression(
+ exp.MatchAgainst, this=this, expressions=expressions, modifier=modifier
+ )
+
def _parse_position(self, haystack_first: bool = False) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise)
@@ -3791,6 +3875,14 @@ class Parser(metaclass=_Parser):
if expression:
expression.set("exists", exists_column)
+ # https://docs.databricks.com/delta/update-schema.html#explicitly-update-schema-to-add-columns
+ if self._match_texts(("FIRST", "AFTER")):
+ position = self._prev.text
+ column_position = self.expression(
+ exp.ColumnPosition, this=self._parse_column(), position=position
+ )
+ expression.set("position", column_position)
+
return expression
def _parse_drop_column(self) -> t.Optional[exp.Expression]:
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index e5b44e7..cf2e31f 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -163,6 +163,7 @@ class TokenType(AutoName):
CURRENT_ROW = auto()
CURRENT_TIME = auto()
CURRENT_TIMESTAMP = auto()
+ CURRENT_USER = auto()
DEFAULT = auto()
DELETE = auto()
DESC = auto()
@@ -506,6 +507,7 @@ class Tokenizer(metaclass=_Tokenizer):
"CURRENT ROW": TokenType.CURRENT_ROW,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
+ "CURRENT_USER": TokenType.CURRENT_USER,
"DATABASE": TokenType.DATABASE,
"DEFAULT": TokenType.DEFAULT,
"DELETE": TokenType.DELETE,
@@ -908,7 +910,7 @@ class Tokenizer(metaclass=_Tokenizer):
if not word:
if self._char in self.SINGLE_TOKENS:
- self._add(self.SINGLE_TOKENS[self._char]) # type: ignore
+ self._add(self.SINGLE_TOKENS[self._char], text=self._char) # type: ignore
return
self._scan_var()
return
@@ -921,7 +923,8 @@ class Tokenizer(metaclass=_Tokenizer):
return
self._advance(size - 1)
- self._add(self.KEYWORDS[word.upper()])
+ word = word.upper()
+ self._add(self.KEYWORDS[word], text=word)
def _scan_comment(self, comment_start: str) -> bool:
if comment_start not in self._COMMENTS: # type: ignore
@@ -946,7 +949,7 @@ class Tokenizer(metaclass=_Tokenizer):
# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
# Multiple consecutive comments are preserved by appending them to the current comments list.
- if comment_start_line == self._prev_token_line:
+ if comment_start_line == self._prev_token_line or self._end:
self.tokens[-1].comments.extend(self._comments)
self._comments = []
self._prev_token_line = self._line
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 2eafb0b..62728d5 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -114,8 +114,8 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
"""
- Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions.
- This transforms removes the precision from parameterized types in expressions.
+ Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
+ other expressions. This transforms removes the precision from parameterized types in expressions.
"""
return expression.transform(
lambda node: exp.DataType(