Diffstat (limited to 'sqlglot')
-rw-r--r--   sqlglot/dialects/duckdb.py            1
-rw-r--r--   sqlglot/dialects/postgres.py          4
-rw-r--r--   sqlglot/dialects/presto.py            1
-rw-r--r--   sqlglot/dialects/spark.py             3
-rw-r--r--   sqlglot/expressions.py               11
-rw-r--r--   sqlglot/generator.py                 58
-rw-r--r--   sqlglot/optimizer/annotate_types.py  16
-rw-r--r--   sqlglot/parser.py                    50
-rw-r--r--   sqlglot/schema.py                    23
-rw-r--r--   sqlglot/tokens.py                    20
10 files changed, 133 insertions, 54 deletions
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index d7ba729..e61ac4f 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -333,6 +333,7 @@ class DuckDB(Dialect):
IGNORE_NULLS_IN_FUNC = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False
+ MULTI_ARG_DISTINCT = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 0404c78..68e2c6d 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -232,6 +232,9 @@ class Postgres(Dialect):
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
HEREDOC_STRINGS = ["$"]
+ HEREDOC_TAG_IS_IDENTIFIER = True
+ HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER
+
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"~~": TokenType.LIKE,
@@ -381,6 +384,7 @@ class Postgres(Dialect):
JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
SUPPORTS_UNLOGGED_TABLES = True
LIKE_PROPERTY_INSIDE_SCHEMA = True
+ MULTI_ARG_DISTINCT = False
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 8691192..609103e 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -292,6 +292,7 @@ class Presto(Dialect):
LIMIT_ONLY_LITERALS = True
SUPPORTS_SINGLE_ARG_CONCAT = False
LIKE_PROPERTY_INSIDE_SCHEMA = True
+ MULTI_ARG_DISTINCT = False
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 4c5c131..44bd12d 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -50,9 +50,6 @@ class Spark(Spark2):
"DATEDIFF": _parse_datediff,
}
- FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy()
- FUNCTION_PARSERS.pop("ANY_VALUE")
-
def _parse_generated_as_identity(
self,
) -> (
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 3234c99..11ebbaf 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1796,7 +1796,7 @@ class Lambda(Expression):
class Limit(Expression):
- arg_types = {"this": False, "expression": True, "offset": False}
+ arg_types = {"this": False, "expression": True, "offset": False, "expressions": False}
class Literal(Condition):
@@ -1969,7 +1969,7 @@ class Final(Expression):
class Offset(Expression):
- arg_types = {"this": False, "expression": True}
+ arg_types = {"this": False, "expression": True, "expressions": False}
class Order(Expression):
@@ -4291,6 +4291,11 @@ class RespectNulls(Expression):
pass
+# https://cloud.google.com/bigquery/docs/reference/standard-sql/aggregate-function-calls#max_min_clause
+class HavingMax(Expression):
+ arg_types = {"this": True, "expression": True, "max": True}
+
+
# Functions
class Func(Condition):
"""
@@ -4491,7 +4496,7 @@ class Avg(AggFunc):
class AnyValue(AggFunc):
- arg_types = {"this": True, "having": False, "max": False}
+ pass
class Lag(AggFunc):
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 568dcb4..318d782 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -296,6 +296,10 @@ class Generator(metaclass=_Generator):
# Whether or not the LikeProperty needs to be specified inside of the schema clause
LIKE_PROPERTY_INSIDE_SCHEMA = False
+ # Whether or not DISTINCT can be followed by multiple args in an AggFunc. If not, it will be
+ # transpiled into a series of CASE-WHEN-ELSE, ultimately using a tuple consisting of the args
+ MULTI_ARG_DISTINCT = True
+
# Whether or not the JSON extraction operators expect a value of type JSON
JSON_TYPE_REQUIRED_FOR_EXTRACTION = False
@@ -1841,15 +1845,18 @@ class Generator(metaclass=_Generator):
args_sql = ", ".join(self.sql(e) for e in args)
args_sql = f"({args_sql})" if any(top and not e.is_number for e in args) else args_sql
- return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}"
+ expressions = self.expressions(expression, flat=True)
+ expressions = f" BY {expressions}" if expressions else ""
+
+ return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}{expressions}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
- expression = expression.expression
- expression = (
- self._simplify_unless_literal(expression) if self.LIMIT_ONLY_LITERALS else expression
- )
- return f"{this}{self.seg('OFFSET')} {self.sql(expression)}"
+ value = expression.expression
+ value = self._simplify_unless_literal(value) if self.LIMIT_ONLY_LITERALS else value
+ expressions = self.expressions(expression, flat=True)
+ expressions = f" BY {expressions}" if expressions else ""
+ return f"{this}{self.seg('OFFSET')} {self.sql(value)}{expressions}"
def setitem_sql(self, expression: exp.SetItem) -> str:
kind = self.sql(expression, "kind")
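Note: a minimal round-trip sketch of the new BY expressions on exp.Limit; ClickHouse is the motivating dialect here, and the exact output string is an assumption rather than something taken from the commit's tests:

    import sqlglot

    # The generator now renders the Limit node's "expressions" arg as a
    # trailing BY clause after the limit count:
    sqlglot.transpile(
        "SELECT x FROM t LIMIT 1 BY x", read="clickhouse", write="clickhouse"
    )[0]
    # expected: 'SELECT x FROM t LIMIT 1 BY x'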
@@ -2834,6 +2841,13 @@ class Generator(metaclass=_Generator):
def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
+
+ if not self.MULTI_ARG_DISTINCT and len(expression.expressions) > 1:
+ case = exp.case()
+ for arg in expression.expressions:
+ case = case.when(arg.is_(exp.null()), exp.null())
+ this = self.sql(case.else_(f"({this})"))
+
this = f" {this}" if this else ""
on = self.sql(expression, "on")
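Note: a minimal sketch of the transpilation that MULTI_ARG_DISTINCT = False enables; Presto as the write dialect and the exact output shape are assumptions:

    import sqlglot

    # Multiple DISTINCT args collapse into a NULL-guarded tuple inside a
    # CASE-WHEN-ELSE, as described by the flag's comment above:
    sqlglot.transpile("SELECT COUNT(DISTINCT a, b) FROM t", write="presto")[0]
    # expected shape:
    # SELECT COUNT(DISTINCT CASE WHEN a IS NULL THEN NULL
    #   WHEN b IS NULL THEN NULL ELSE (a, b) END) FROM t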
@@ -2846,13 +2860,33 @@ class Generator(metaclass=_Generator):
def respectnulls_sql(self, expression: exp.RespectNulls) -> str:
return self._embed_ignore_nulls(expression, "RESPECT NULLS")
+ def havingmax_sql(self, expression: exp.HavingMax) -> str:
+ this_sql = self.sql(expression, "this")
+ expression_sql = self.sql(expression, "expression")
+ kind = "MAX" if expression.args.get("max") else "MIN"
+ return f"{this_sql} HAVING {kind} {expression_sql}"
+
def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
- if self.IGNORE_NULLS_IN_FUNC:
- this = expression.find(exp.AggFunc)
- if this:
- sql = self.sql(this)
- sql = sql[:-1] + f" {text})"
- return sql
+ if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
+ # The first modifier here will be the one closest to the AggFunc's arg
+ mods = sorted(
+ expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
+ key=lambda x: 0
+ if isinstance(x, exp.HavingMax)
+ else (1 if isinstance(x, exp.Order) else 2),
+ )
+
+ if mods:
+ mod = mods[0]
+ this = expression.__class__(this=mod.this.copy())
+ this.meta["inline"] = True
+ mod.this.replace(this)
+ return self.sql(expression.this)
+
+ agg_func = expression.find(exp.AggFunc)
+
+ if agg_func:
+ return self.sql(agg_func)[:-1] + f" {text})"
return f"{self.sql(expression, 'this')} {text}"
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index a2a86cd..cb9312c 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -263,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Div: lambda self, e: self._annotate_div(e),
+ exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
@@ -333,9 +334,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._visited: t.Set[int] = set()
def _set_type(
- self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
+ self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type]
) -> None:
- expression.type = target_type # type: ignore
+ expression.type = target_type or exp.DataType.Type.UNKNOWN # type: ignore
self._visited.add(id(expression))
def annotate(self, expression: E) -> E:
@@ -564,13 +565,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
if isinstance(bracket_arg, exp.Slice):
self._set_type(expression, this.type)
elif this.type.is_type(exp.DataType.Type.ARRAY):
- contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
- self._set_type(expression, contained_type)
+ self._set_type(expression, seq_get(this.type.expressions, 0))
elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
index = this.keys.index(bracket_arg)
value = seq_get(this.values, index)
- value_type = value.type if value else exp.DataType.Type.UNKNOWN
- self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
+ self._set_type(expression, value.type if value else None)
else:
self._set_type(expression, exp.DataType.Type.UNKNOWN)
@@ -591,3 +590,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, self._maybe_coerce(left_type, right_type))
return expression
+
+ def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
+ self._annotate_args(expression)
+ self._set_type(expression, seq_get(expression.this.type.expressions, 0))
+ return expression
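Note: a minimal sketch of the new Explode annotation; the schema, the column names, and the ARRAY<INT> spelling are illustrative assumptions:

    from sqlglot import exp, parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    # EXPLODE over an ARRAY<INT> column should now be typed as the array's
    # element type rather than UNKNOWN:
    expr = annotate_types(
        parse_one("SELECT EXPLODE(xs) AS x FROM t"),
        schema={"t": {"xs": "ARRAY<INT>"}},
    )
    expr.selects[0].type  # expected: DataType(this=Type.INT)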
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index a89e4fa..dfa3024 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -872,7 +872,6 @@ class Parser(metaclass=_Parser):
FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"}
FUNCTION_PARSERS = {
- "ANY_VALUE": lambda self: self._parse_any_value(),
"CAST": lambda self: self._parse_cast(self.STRICT_CAST),
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"DECODE": lambda self: self._parse_decode(),
@@ -2465,8 +2464,14 @@ class Parser(metaclass=_Parser):
this.set(key, expression)
if key == "limit":
offset = expression.args.pop("offset", None)
+
if offset:
- this.set("offset", exp.Offset(expression=offset))
+ offset = exp.Offset(expression=offset)
+ this.set("offset", offset)
+
+ limit_by_expressions = expression.expressions
+ expression.set("expressions", None)
+ offset.set("expressions", limit_by_expressions)
continue
break
return this
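Note: a sketch of the re-attachment for ClickHouse's "LIMIT offset, count BY cols" form; the AST access below is illustrative:

    import sqlglot

    # BY expressions parsed on the Limit move to the derived Offset node,
    # so they are emitted once, next to OFFSET, when generating SQL:
    ast = sqlglot.parse_one("SELECT x FROM t LIMIT 1, 2 BY x", read="clickhouse")
    ast.args["offset"].expressions  # expected: the BY columns, e.g. [Column(x)]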
@@ -3341,7 +3346,12 @@ class Parser(metaclass=_Parser):
offset = None
limit_exp = self.expression(
- exp.Limit, this=this, expression=expression, offset=offset, comments=comments
+ exp.Limit,
+ this=this,
+ expression=expression,
+ offset=offset,
+ comments=comments,
+ expressions=self._parse_limit_by(),
)
return limit_exp
@@ -3377,7 +3387,13 @@ class Parser(metaclass=_Parser):
count = self._parse_term()
self._match_set((TokenType.ROW, TokenType.ROWS))
- return self.expression(exp.Offset, this=this, expression=count)
+
+ return self.expression(
+ exp.Offset, this=this, expression=count, expressions=self._parse_limit_by()
+ )
+
+ def _parse_limit_by(self) -> t.Optional[t.List[exp.Expression]]:
+ return self._match_text_seq("BY") and self._parse_csv(self._parse_bitwise)
def _parse_locks(self) -> t.List[exp.Lock]:
locks = []
@@ -4115,7 +4131,9 @@ class Parser(metaclass=_Parser):
else:
this = self._parse_select_or_expression(alias=alias)
- return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this)))
+ return self._parse_limit(
+ self._parse_order(self._parse_having_max(self._parse_respect_or_ignore_nulls(this)))
+ )
def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
@@ -4549,18 +4567,6 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
- def _parse_any_value(self) -> exp.AnyValue:
- this = self._parse_lambda()
- is_max = None
- having = None
-
- if self._match(TokenType.HAVING):
- self._match_texts(("MAX", "MIN"))
- is_max = self._prev.text == "MAX"
- having = self._parse_column()
-
- return self.expression(exp.AnyValue, this=this, having=having, max=is_max)
-
def _parse_cast(self, strict: bool, safe: t.Optional[bool] = None) -> exp.Expression:
this = self._parse_conjunction()
@@ -4941,6 +4947,16 @@ class Parser(metaclass=_Parser):
return self.expression(exp.RespectNulls, this=this)
return this
+ def _parse_having_max(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ if self._match(TokenType.HAVING):
+ self._match_texts(("MAX", "MIN"))
+ max = self._prev.text.upper() != "MIN"
+ return self.expression(
+ exp.HavingMax, this=this, expression=self._parse_column(), max=max
+ )
+
+ return this
+
def _parse_window(
self, this: t.Optional[exp.Expression], alias: bool = False
) -> t.Optional[exp.Expression]:
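Note: a sketch of the generalized parse that replaces the old ANY_VALUE-only special case; BigQuery syntax, per the doc link added in expressions.py:

    import sqlglot
    from sqlglot import exp

    # HAVING MAX/MIN now parses into a standalone exp.HavingMax wrapping any
    # aggregate argument, instead of being folded into exp.AnyValue:
    ast = sqlglot.parse_one("SELECT ANY_VALUE(x HAVING MAX y) FROM t", read="bigquery")
    ast.find(exp.HavingMax)  # expected: HavingMax(this=x, expression=y, max=True)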
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 13f72d6..1fd4025 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -106,19 +106,6 @@ class Schema(abc.ABC):
name = column if isinstance(column, str) else column.name
return name in self.column_names(table, dialect=dialect, normalize=normalize)
- @abc.abstractmethod
- def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
- """
- Returns the schema of a given table.
-
- Args:
- table: the target table.
- raise_on_missing: whether or not to raise in case the schema is not found.
-
- Returns:
- The schema of the target table.
- """
-
@property
@abc.abstractmethod
def supported_table_args(self) -> t.Tuple[str, ...]:
@@ -170,6 +157,16 @@ class AbstractMappingSchema:
return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)]
def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]:
+ """
+ Returns the schema of a given table.
+
+ Args:
+ table: the target table.
+ raise_on_missing: whether or not to raise in case the schema is not found.
+
+ Returns:
+ The schema of the target table.
+ """
parts = self.table_parts(table)[0 : len(self.supported_table_args)]
value, trie = in_trie(self.mapping_trie, parts)
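Note: usage is unchanged by the docstring move, since AbstractMappingSchema.find was already the only concrete implementation; a minimal sketch with illustrative names:

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"t": {"x": "INT"}}})
    schema.find(exp.to_table("db.t"))  # expected: {'x': 'INT'}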
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 87a4924..b064957 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -504,6 +504,7 @@ class _Tokenizer(type):
command_prefix_tokens={
_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS
},
+ heredoc_tag_is_identifier=klass.HEREDOC_TAG_IS_IDENTIFIER,
)
token_types = RsTokenTypeSettings(
bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING],
@@ -517,6 +518,7 @@ class _Tokenizer(type):
semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON],
string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING],
var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR],
+ heredoc_string_alternative=_TOKEN_TYPE_TO_INDEX[klass.HEREDOC_STRING_ALTERNATIVE],
)
klass._RS_TOKENIZER = RsTokenizer(settings, token_types)
else:
@@ -573,6 +575,12 @@ class Tokenizer(metaclass=_Tokenizer):
STRING_ESCAPES = ["'"]
VAR_SINGLE_TOKENS: t.Set[str] = set()
+ # Whether or not the heredoc tags follow the same lexical rules as unquoted identifiers
+ HEREDOC_TAG_IS_IDENTIFIER = False
+
+ # Token that we'll generate as a fallback if the heredoc prefix doesn't correspond to a heredoc
+ HEREDOC_STRING_ALTERNATIVE = TokenType.VAR
+
# Autofilled
_COMMENTS: t.Dict[str, str] = {}
_FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {}
@@ -1249,6 +1257,18 @@ class Tokenizer(metaclass=_Tokenizer):
elif token_type == TokenType.BIT_STRING:
base = 2
elif token_type == TokenType.HEREDOC_STRING:
+ if (
+ self.HEREDOC_TAG_IS_IDENTIFIER
+ and not self._peek.isidentifier()
+ and self._peek != end
+ ):
+ if self.HEREDOC_STRING_ALTERNATIVE != TokenType.VAR:
+ self._add(self.HEREDOC_STRING_ALTERNATIVE)
+ else:
+ self._scan_var()
+
+ return True
+
self._advance()
tag = "" if self._char == end else self._extract_string(end)
end = f"{start}{tag}{end}"