summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-12 15:42:38 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-12 15:42:38 +0000
commitbea2635be022e272ddac349f5e396ec901fc37e5 (patch)
tree24dbe11c9d462ff55f9b3af4b4da4cd1ae02e8a3
parentReleasing debian version 10.1.3-1. (diff)
downloadsqlglot-bea2635be022e272ddac349f5e396ec901fc37e5.tar.xz
sqlglot-bea2635be022e272ddac349f5e396ec901fc37e5.zip
Merging upstream version 10.2.6.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
-rw-r--r--CHANGELOG.md37
-rw-r--r--sqlglot/__init__.py2
-rw-r--r--sqlglot/dataframe/sql/dataframe.py2
-rw-r--r--sqlglot/dialects/bigquery.py33
-rw-r--r--sqlglot/dialects/hive.py15
-rw-r--r--sqlglot/dialects/oracle.py1
-rw-r--r--sqlglot/dialects/redshift.py10
-rw-r--r--sqlglot/dialects/snowflake.py1
-rw-r--r--sqlglot/executor/env.py9
-rw-r--r--sqlglot/executor/python.py4
-rw-r--r--sqlglot/expressions.py106
-rw-r--r--sqlglot/generator.py452
-rw-r--r--sqlglot/helper.py8
-rw-r--r--sqlglot/optimizer/annotate_types.py37
-rw-r--r--sqlglot/optimizer/canonicalize.py25
-rw-r--r--sqlglot/optimizer/simplify.py235
-rw-r--r--sqlglot/parser.py136
-rw-r--r--sqlglot/schema.py22
-rw-r--r--sqlglot/tokens.py19
-rw-r--r--tests/dataframe/unit/dataframe_sql_validator.py5
-rw-r--r--tests/dataframe/unit/test_dataframe_writer.py34
-rw-r--r--tests/dataframe/unit/test_session.py4
-rw-r--r--tests/dialects/test_bigquery.py19
-rw-r--r--tests/dialects/test_dialect.py36
-rw-r--r--tests/dialects/test_hive.py4
-rw-r--r--tests/dialects/test_postgres.py4
-rw-r--r--tests/dialects/test_redshift.py26
-rw-r--r--tests/dialects/test_snowflake.py9
-rw-r--r--tests/fixtures/identity.sql1
-rw-r--r--tests/fixtures/optimizer/canonicalize.sql6
-rw-r--r--tests/fixtures/optimizer/simplify.sql180
-rw-r--r--tests/fixtures/optimizer/tpc-h/tpc-h.sql51
-rw-r--r--tests/test_executor.py21
-rw-r--r--tests/test_optimizer.py155
-rw-r--r--tests/test_schema.py18
-rw-r--r--tests/test_tokens.py47
36 files changed, 1281 insertions, 493 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a439c2c..7dfca94 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,43 @@
Changelog
=========
+v10.2.0
+------
+
+Changes:
+
+- Breaking: types inferred from annotate_types are now DataType objects, instead of DataType.Type.
+
+- New: the optimizer can now simplify [BETWEEN expressions expressed as explicit comparisons](https://github.com/tobymao/sqlglot/commit/e24d0317dfa644104ff21d009b790224bf84d698).
+
+- New: the optimizer now removes redundant casts.
+
+- New: added support for Redshift's ENCODE/DECODE.
+
+- New: the optimizer now [treats identifiers as case-insensitive](https://github.com/tobymao/sqlglot/commit/638ed265f195219d7226f4fbae128f1805ae8988).
+
+- New: the optimizer now [handles nested CTEs](https://github.com/tobymao/sqlglot/commit/1bdd652792889a8aaffb1c6d2c8aa1fe4a066281).
+
+- New: the executor can now execute SELECT DISTINCT expressions.
+
+- New: added support for Redshift's COPY and UNLOAD commands.
+
+- New: added ability to parse LIKE in CREATE TABLE statement.
+
+- New: the optimizer now [unnests scalar subqueries as cross joins](https://github.com/tobymao/sqlglot/commit/4373ad8518ede4ef1fda8b247b648c680a93d12d).
+
+- Improvement: fixed Bigquery's ARRAY function parsing, so that it can now handle a SELECT expression as an argument.
+
+- Improvement: improved Snowflake's [ARRAY and MAP constructs](https://github.com/tobymao/sqlglot/commit/0506657dba55fe71d004c81c907e23cdd2b37d82).
+
+- Improvement: fixed transpilation between STRING_AGG and GROUP_CONCAT.
+
+- Improvement: the INTO clause can now be parsed in SELECT expressions.
+
+- Improvement: improve executor; it currently executes all TPC-H queries up to TPC-H 17 (inclusive).
+
+- Improvement: DISTINCT ON is now transpiled to a SELECT expression from a subquery for Redshift.
+
v10.1.0
------
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index b027ac7..3733b20 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -30,7 +30,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.1.3"
+__version__ = "10.2.6"
pretty = False
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 548c322..3c45741 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -317,7 +317,7 @@ class DataFrame:
sqlglot.schema.add_table(
cache_table_name,
{
- expression.alias_or_name: expression.type.name
+ expression.alias_or_name: expression.type.sql("spark")
for expression in select_expression.expressions
},
)
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 5b44912..6be68ac 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -110,17 +110,17 @@ class BigQuery(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "BEGIN": TokenType.COMMAND,
+ "BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"GEOGRAPHY": TokenType.GEOGRAPHY,
- "INT64": TokenType.BIGINT,
"FLOAT64": TokenType.DOUBLE,
+ "INT64": TokenType.BIGINT,
+ "NOT DETERMINISTIC": TokenType.VOLATILE,
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
- "NOT DETERMINISTIC": TokenType.VOLATILE,
- "BEGIN": TokenType.COMMAND,
- "BEGIN TRANSACTION": TokenType.BEGIN,
}
KEYWORDS.pop("DIV")
@@ -131,6 +131,7 @@ class BigQuery(Dialect):
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
+ "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
@@ -144,6 +145,7 @@ class BigQuery(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
+ "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
}
FUNCTION_PARSERS.pop("TRIM")
@@ -161,7 +163,6 @@ class BigQuery(Dialect):
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
- exp.Array: inline_array_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
@@ -183,6 +184,7 @@ class BigQuery(Dialect):
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
+ exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
}
TYPE_MAPPING = {
@@ -210,24 +212,31 @@ class BigQuery(Dialect):
EXPLICIT_UNION = True
- def transaction_sql(self, *_):
+ def array_sql(self, expression: exp.Array) -> str:
+ first_arg = seq_get(expression.expressions, 0)
+ if isinstance(first_arg, exp.Subqueryable):
+ return f"ARRAY{self.wrap(self.sql(first_arg))}"
+
+ return inline_array_sql(self, expression)
+
+ def transaction_sql(self, *_) -> str:
return "BEGIN TRANSACTION"
- def commit_sql(self, *_):
+ def commit_sql(self, *_) -> str:
return "COMMIT TRANSACTION"
- def rollback_sql(self, *_):
+ def rollback_sql(self, *_) -> str:
return "ROLLBACK TRANSACTION"
- def in_unnest_op(self, unnest):
- return self.sql(unnest)
+ def in_unnest_op(self, expression: exp.Unnest) -> str:
+ return self.sql(expression)
- def except_op(self, expression):
+ def except_op(self, expression: exp.Except) -> str:
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
- def intersect_op(self, expression):
+ def intersect_op(self, expression: exp.Intersect) -> str:
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index cbb39c2..70c1c6c 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -190,6 +190,7 @@ class Hive(Dialect):
"ADD FILES": TokenType.COMMAND,
"ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND,
+ "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
}
class Parser(parser.Parser):
@@ -238,6 +239,13 @@ class Hive(Dialect):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
+ PROPERTY_PARSERS = {
+ **parser.Parser.PROPERTY_PARSERS,
+ TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties(
+ expressions=self._parse_wrapped_csv(self._parse_property)
+ ),
+ }
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -297,6 +305,8 @@ class Hive(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
+ exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
+ exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
}
@@ -308,12 +318,15 @@ class Hive(Dialect):
exp.SchemaCommentProperty,
exp.LocationProperty,
exp.TableFormatProperty,
+ exp.RowFormatDelimitedProperty,
+ exp.RowFormatSerdeProperty,
+ exp.SerdeProperties,
}
def with_properties(self, properties):
return self.properties(
properties,
- prefix="TBLPROPERTIES",
+ prefix=self.seg("TBLPROPERTIES"),
)
def datatype_sql(self, expression):
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index ceaf9ba..f507513 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -98,6 +98,7 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "MINUS": TokenType.EXCEPT,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index cd50979..55ed0a6 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from sqlglot import exp, transforms
+from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
@@ -13,12 +14,20 @@ class Redshift(Postgres):
"HH": "%H",
}
+ class Parser(Postgres.Parser):
+ FUNCTIONS = {
+ **Postgres.Parser.FUNCTIONS, # type: ignore
+ "DECODE": exp.Matches.from_arg_list,
+ "NVL": exp.Coalesce.from_arg_list,
+ }
+
class Tokenizer(Postgres.Tokenizer):
ESCAPES = ["\\"]
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore
"COPY": TokenType.COMMAND,
+ "ENCODE": TokenType.ENCODE,
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
@@ -50,4 +59,5 @@ class Redshift(Postgres):
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
+ exp.Matches: rename_func("DECODE"),
}
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 46155ff..75dc9dc 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -198,6 +198,7 @@ class Snowflake(Dialect):
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
+ "MINUS": TokenType.EXCEPT,
"SAMPLE": TokenType.TABLE_SAMPLE,
}
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index e6cfcdd..ad9397e 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -19,10 +19,13 @@ class reverse_key:
return other.obj < self.obj
-def filter_nulls(func):
+def filter_nulls(func, empty_null=True):
@wraps(func)
def _func(values):
- return func(v for v in values if v is not None)
+ filtered = tuple(v for v in values if v is not None)
+ if not filtered and empty_null:
+ return None
+ return func(filtered)
return _func
@@ -126,7 +129,7 @@ ENV = {
# aggs
"SUM": filter_nulls(sum),
"AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
- "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
+ "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
"MAX": filter_nulls(max),
"MIN": filter_nulls(min),
# scalar functions
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 908b80a..9f22c45 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -310,9 +310,9 @@ class PythonExecutor:
if i == length - 1:
context.set_range(start, end - 1)
add_row()
- elif step.limit > 0:
+ elif step.limit > 0 and not group_by:
context.set_range(0, 0)
- table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations))
+ table.append(context.eval_tuple(aggregations))
context = self.context({step.name: table, **{name: table for name in context.tables}})
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 96b32f1..7249574 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):
key = "Expression"
arg_types = {"this": True}
- __slots__ = ("args", "parent", "arg_key", "type", "comments")
+ __slots__ = ("args", "parent", "arg_key", "comments", "_type")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
- self.type = None
self.comments = None
+ self._type: t.Optional[DataType] = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
@@ -122,6 +122,16 @@ class Expression(metaclass=_Expression):
return "NULL"
return self.alias or self.name
+ @property
+ def type(self) -> t.Optional[DataType]:
+ return self._type
+
+ @type.setter
+ def type(self, dtype: t.Optional[DataType | DataType.Type | str]) -> None:
+ if dtype and not isinstance(dtype, DataType):
+ dtype = DataType.build(dtype)
+ self._type = dtype # type: ignore
+
def __deepcopy__(self, memo):
copy = self.__class__(**deepcopy(self.args))
copy.comments = self.comments
@@ -348,7 +358,7 @@ class Expression(metaclass=_Expression):
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
- args = {
+ args: t.Dict[str, t.Any] = {
k: ", ".join(
v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
for v in ensure_collection(vs)
@@ -612,6 +622,7 @@ class Create(Expression):
"properties": False,
"temporary": False,
"transient": False,
+ "external": False,
"replace": False,
"unique": False,
"materialized": False,
@@ -744,13 +755,17 @@ class DefaultColumnConstraint(ColumnConstraintKind):
pass
+class EncodeColumnConstraint(ColumnConstraintKind):
+ pass
+
+
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
arg_types = {"this": True, "expression": False}
class NotNullColumnConstraint(ColumnConstraintKind):
- pass
+ arg_types = {"allow_null": False}
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
@@ -766,7 +781,7 @@ class Constraint(Expression):
class Delete(Expression):
- arg_types = {"with": False, "this": True, "using": False, "where": False}
+ arg_types = {"with": False, "this": False, "using": False, "where": False}
class Drop(Expression):
@@ -850,7 +865,7 @@ class Insert(Expression):
arg_types = {
"with": False,
"this": True,
- "expression": True,
+ "expression": False,
"overwrite": False,
"exists": False,
"partition": False,
@@ -1125,6 +1140,27 @@ class VolatilityProperty(Property):
arg_types = {"this": True}
+class RowFormatDelimitedProperty(Property):
+ # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
+ arg_types = {
+ "fields": False,
+ "escaped": False,
+ "collection_items": False,
+ "map_keys": False,
+ "lines": False,
+ "null": False,
+ "serde": False,
+ }
+
+
+class RowFormatSerdeProperty(Property):
+ arg_types = {"this": True}
+
+
+class SerdeProperties(Property):
+ arg_types = {"expressions": True}
+
+
class Properties(Expression):
arg_types = {"expressions": True}
@@ -1169,18 +1205,6 @@ class Reference(Expression):
arg_types = {"this": True, "expressions": True}
-class RowFormat(Expression):
- # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
- arg_types = {
- "fields": False,
- "escaped": False,
- "collection_items": False,
- "map_keys": False,
- "lines": False,
- "null": False,
- }
-
-
class Tuple(Expression):
arg_types = {"expressions": False}
@@ -1208,6 +1232,9 @@ class Subqueryable(Unionable):
alias=TableAlias(this=to_identifier(alias)),
)
+ def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+ raise NotImplementedError
+
@property
def ctes(self):
with_ = self.args.get("with")
@@ -1320,6 +1347,32 @@ class Union(Subqueryable):
**QUERY_MODIFIERS,
}
+ def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+ """
+ Set the LIMIT expression.
+
+ Example:
+ >>> select("1").union(select("1")).limit(1).sql()
+ 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1'
+
+ Args:
+ expression (str | int | Expression): the SQL code string to parse.
+ This can also be an integer.
+ If a `Limit` instance is passed, this is used as-is.
+ If another `Expression` instance is passed, it will be wrapped in a `Limit`.
+ dialect (str): the dialect used to parse the input expression.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: The limited subqueryable.
+ """
+ return (
+ select("*")
+ .from_(self.subquery(alias="_l_0", copy=copy))
+ .limit(expression, dialect=dialect, copy=False, **opts)
+ )
+
@property
def named_selects(self):
return self.this.unnest().named_selects
@@ -1356,7 +1409,7 @@ class Unnest(UDTF):
class Update(Expression):
arg_types = {
"with": False,
- "this": True,
+ "this": False,
"expressions": True,
"from": False,
"where": False,
@@ -2057,15 +2110,20 @@ class DataType(Expression):
Type.TEXT,
}
- NUMERIC_TYPES = {
+ INTEGER_TYPES = {
Type.INT,
Type.TINYINT,
Type.SMALLINT,
Type.BIGINT,
+ }
+
+ FLOAT_TYPES = {
Type.FLOAT,
Type.DOUBLE,
}
+ NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
+
TEMPORAL_TYPES = {
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
@@ -2968,6 +3026,14 @@ class Use(Expression):
pass
+class Merge(Expression):
+ arg_types = {"this": True, "using": True, "on": True, "expressions": True}
+
+
+class When(Func):
+ arg_types = {"this": True, "then": True}
+
+
def _norm_args(expression):
args = {}
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 47774fc..beffb91 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -189,12 +189,12 @@ class Generator:
self._max_text_width = max_text_width
self._comments = comments
- def generate(self, expression):
+ def generate(self, expression: t.Optional[exp.Expression]) -> str:
"""
Generates a SQL string by interpreting the given syntax tree.
Args
- expression (Expression): the syntax tree.
+ expression: the syntax tree.
Returns
the SQL string.
@@ -213,23 +213,23 @@ class Generator:
return sql
- def unsupported(self, message):
+ def unsupported(self, message: str) -> None:
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
self.unsupported_messages.append(message)
- def sep(self, sep=" "):
+ def sep(self, sep: str = " ") -> str:
return f"{sep.strip()}\n" if self.pretty else sep
- def seg(self, sql, sep=" "):
+ def seg(self, sql: str, sep: str = " ") -> str:
return f"{self.sep(sep)}{sql}"
- def pad_comment(self, comment):
+ def pad_comment(self, comment: str) -> str:
comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment
return comment
- def maybe_comment(self, sql, expression):
+ def maybe_comment(self, sql: str, expression: exp.Expression) -> str:
comments = expression.comments if self._comments else None
if not comments:
@@ -243,7 +243,7 @@ class Generator:
return f"{sql} {comments}"
- def wrap(self, expression):
+ def wrap(self, expression: exp.Expression | str) -> str:
this_sql = self.indent(
self.sql(expression)
if isinstance(expression, (exp.Select, exp.Union))
@@ -253,21 +253,28 @@ class Generator:
)
return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}"
- def no_identify(self, func):
+ def no_identify(self, func: t.Callable[[], str]) -> str:
original = self.identify
self.identify = False
result = func()
self.identify = original
return result
- def normalize_func(self, name):
+ def normalize_func(self, name: str) -> str:
if self.normalize_functions == "upper":
return name.upper()
if self.normalize_functions == "lower":
return name.lower()
return name
- def indent(self, sql, level=0, pad=None, skip_first=False, skip_last=False):
+ def indent(
+ self,
+ sql: str,
+ level: int = 0,
+ pad: t.Optional[int] = None,
+ skip_first: bool = False,
+ skip_last: bool = False,
+ ) -> str:
if not self.pretty:
return sql
@@ -281,7 +288,12 @@ class Generator:
for i, line in enumerate(lines)
)
- def sql(self, expression, key=None, comment=True):
+ def sql(
+ self,
+ expression: t.Optional[str | exp.Expression],
+ key: t.Optional[str] = None,
+ comment: bool = True,
+ ) -> str:
if not expression:
return ""
@@ -313,12 +325,12 @@ class Generator:
return self.maybe_comment(sql, expression) if self._comments and comment else sql
- def uncache_sql(self, expression):
+ def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
exists_sql = " IF EXISTS" if expression.args.get("exists") else ""
return f"UNCACHE TABLE{exists_sql} {table}"
- def cache_sql(self, expression):
+ def cache_sql(self, expression: exp.Cache) -> str:
lazy = " LAZY" if expression.args.get("lazy") else ""
table = self.sql(expression, "this")
options = expression.args.get("options")
@@ -328,13 +340,13 @@ class Generator:
sql = f"CACHE{lazy} TABLE {table}{options}{sql}"
return self.prepend_ctes(expression, sql)
- def characterset_sql(self, expression):
+ def characterset_sql(self, expression: exp.CharacterSet) -> str:
if isinstance(expression.parent, exp.Cast):
return f"CHAR CHARACTER SET {self.sql(expression, 'this')}"
default = "DEFAULT " if expression.args.get("default") else ""
return f"{default}CHARACTER SET={self.sql(expression, 'this')}"
- def column_sql(self, expression):
+ def column_sql(self, expression: exp.Column) -> str:
return ".".join(
part
for part in [
@@ -345,7 +357,7 @@ class Generator:
if part
)
- def columndef_sql(self, expression):
+ def columndef_sql(self, expression: exp.ColumnDef) -> str:
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
@@ -354,46 +366,52 @@ class Generator:
return f"{column} {kind}"
return f"{column} {kind} {constraints}"
- def columnconstraint_sql(self, expression):
+ def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
kind_sql = self.sql(expression, "kind")
return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql
- def autoincrementcolumnconstraint_sql(self, _):
+ def autoincrementcolumnconstraint_sql(self, _) -> str:
return self.token_sql(TokenType.AUTO_INCREMENT)
- def checkcolumnconstraint_sql(self, expression):
+ def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str:
this = self.sql(expression, "this")
return f"CHECK ({this})"
- def commentcolumnconstraint_sql(self, expression):
+ def commentcolumnconstraint_sql(self, expression: exp.CommentColumnConstraint) -> str:
comment = self.sql(expression, "this")
return f"COMMENT {comment}"
- def collatecolumnconstraint_sql(self, expression):
+ def collatecolumnconstraint_sql(self, expression: exp.CollateColumnConstraint) -> str:
collate = self.sql(expression, "this")
return f"COLLATE {collate}"
- def defaultcolumnconstraint_sql(self, expression):
+ def encodecolumnconstraint_sql(self, expression: exp.EncodeColumnConstraint) -> str:
+ encode = self.sql(expression, "this")
+ return f"ENCODE {encode}"
+
+ def defaultcolumnconstraint_sql(self, expression: exp.DefaultColumnConstraint) -> str:
default = self.sql(expression, "this")
return f"DEFAULT {default}"
- def generatedasidentitycolumnconstraint_sql(self, expression):
+ def generatedasidentitycolumnconstraint_sql(
+ self, expression: exp.GeneratedAsIdentityColumnConstraint
+ ) -> str:
return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
- def notnullcolumnconstraint_sql(self, _):
- return "NOT NULL"
+ def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
+ return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
- def primarykeycolumnconstraint_sql(self, expression):
+ def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str:
desc = expression.args.get("desc")
if desc is not None:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
return f"PRIMARY KEY"
- def uniquecolumnconstraint_sql(self, _):
+ def uniquecolumnconstraint_sql(self, _) -> str:
return "UNIQUE"
- def create_sql(self, expression):
+ def create_sql(self, expression: exp.Create) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind").upper()
expression_sql = self.sql(expression, "expression")
@@ -402,47 +420,58 @@ class Generator:
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
)
+ external = " EXTERNAL" if expression.args.get("external") else ""
replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
properties = self.sql(expression, "properties")
- expression_sql = f"CREATE{replace}{temporary}{transient}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}"
+ modifiers = "".join(
+ (
+ replace,
+ temporary,
+ transient,
+ external,
+ unique,
+ materialized,
+ )
+ )
+ expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}"
return self.prepend_ctes(expression, expression_sql)
- def describe_sql(self, expression):
+ def describe_sql(self, expression: exp.Describe) -> str:
return f"DESCRIBE {self.sql(expression, 'this')}"
- def prepend_ctes(self, expression, sql):
+ def prepend_ctes(self, expression: exp.Expression, sql: str) -> str:
with_ = self.sql(expression, "with")
if with_:
sql = f"{with_}{self.sep()}{sql}"
return sql
- def with_sql(self, expression):
+ def with_sql(self, expression: exp.With) -> str:
sql = self.expressions(expression, flat=True)
recursive = "RECURSIVE " if expression.args.get("recursive") else ""
return f"WITH {recursive}{sql}"
- def cte_sql(self, expression):
+ def cte_sql(self, expression: exp.CTE) -> str:
alias = self.sql(expression, "alias")
return f"{alias} AS {self.wrap(expression)}"
- def tablealias_sql(self, expression):
+ def tablealias_sql(self, expression: exp.TableAlias) -> str:
alias = self.sql(expression, "this")
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
return f"{alias}{columns}"
- def bitstring_sql(self, expression):
+ def bitstring_sql(self, expression: exp.BitString) -> str:
return self.sql(expression, "this")
- def hexstring_sql(self, expression):
+ def hexstring_sql(self, expression: exp.HexString) -> str:
return self.sql(expression, "this")
- def datatype_sql(self, expression):
+ def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
nested = ""
@@ -455,13 +484,13 @@ class Generator:
)
return f"{type_sql}{nested}"
- def directory_sql(self, expression):
+ def directory_sql(self, expression: exp.Directory) -> str:
local = "LOCAL " if expression.args.get("local") else ""
row_format = self.sql(expression, "row_format")
row_format = f" {row_format}" if row_format else ""
return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}"
- def delete_sql(self, expression):
+ def delete_sql(self, expression: exp.Delete) -> str:
this = self.sql(expression, "this")
using_sql = (
f" USING {self.expressions(expression, 'using', sep=', USING ')}"
@@ -472,7 +501,7 @@ class Generator:
sql = f"DELETE FROM {this}{using_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
- def drop_sql(self, expression):
+ def drop_sql(self, expression: exp.Drop) -> str:
this = self.sql(expression, "this")
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
@@ -481,46 +510,46 @@ class Generator:
cascade = " CASCADE" if expression.args.get("cascade") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
- def except_sql(self, expression):
+ def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.except_op(expression)),
)
- def except_op(self, expression):
+ def except_op(self, expression: exp.Except) -> str:
return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}"
- def fetch_sql(self, expression):
+ def fetch_sql(self, expression: exp.Fetch) -> str:
direction = expression.args.get("direction")
direction = f" {direction.upper()}" if direction else ""
count = expression.args.get("count")
count = f" {count}" if count else ""
return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY"
- def filter_sql(self, expression):
+ def filter_sql(self, expression: exp.Filter) -> str:
this = self.sql(expression, "this")
where = self.sql(expression, "expression")[1:] # where has a leading space
return f"{this} FILTER({where})"
- def hint_sql(self, expression):
+ def hint_sql(self, expression: exp.Hint) -> str:
if self.sql(expression, "this"):
self.unsupported("Hints are not supported")
return ""
- def index_sql(self, expression):
+ def index_sql(self, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
columns = self.sql(expression, "columns")
return f"{this} ON {table} {columns}"
- def identifier_sql(self, expression):
+ def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
text = text.lower() if self.normalize else text
if expression.args.get("quoted") or self.identify:
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
- def partition_sql(self, expression):
+ def partition_sql(self, expression: exp.Partition) -> str:
keys = csv(
*[
f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
@@ -529,7 +558,7 @@ class Generator:
)
return f"PARTITION({keys})"
- def properties_sql(self, expression):
+ def properties_sql(self, expression: exp.Properties) -> str:
root_properties = []
with_properties = []
@@ -544,21 +573,21 @@ class Generator:
exp.Properties(expressions=root_properties)
) + self.with_properties(exp.Properties(expressions=with_properties))
- def root_properties(self, properties):
+ def root_properties(self, properties: exp.Properties) -> str:
if properties.expressions:
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
- def properties(self, properties, prefix="", sep=", "):
+ def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
- return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
+ return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}"
return ""
- def with_properties(self, properties):
- return self.properties(properties, prefix="WITH")
+ def with_properties(self, properties: exp.Properties) -> str:
+ return self.properties(properties, prefix=self.seg("WITH"))
- def property_sql(self, expression):
+ def property_sql(self, expression: exp.Property) -> str:
property_cls = expression.__class__
if property_cls == exp.Property:
return f"{expression.name}={self.sql(expression, 'value')}"
@@ -569,12 +598,12 @@ class Generator:
return f"{property_name}={self.sql(expression, 'this')}"
- def likeproperty_sql(self, expression):
+ def likeproperty_sql(self, expression: exp.LikeProperty) -> str:
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
options = f" {options}" if options else ""
return f"LIKE {self.sql(expression, 'this')}{options}"
- def insert_sql(self, expression):
+ def insert_sql(self, expression: exp.Insert) -> str:
overwrite = expression.args.get("overwrite")
if isinstance(expression.this, exp.Directory):
@@ -592,19 +621,19 @@ class Generator:
sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
return self.prepend_ctes(expression, sql)
- def intersect_sql(self, expression):
+ def intersect_sql(self, expression: exp.Intersect) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.intersect_op(expression)),
)
- def intersect_op(self, expression):
+ def intersect_op(self, expression: exp.Intersect) -> str:
return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}"
- def introducer_sql(self, expression):
+ def introducer_sql(self, expression: exp.Introducer) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
- def rowformat_sql(self, expression):
+ def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
fields = expression.args.get("fields")
fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
escaped = expression.args.get("escaped")
@@ -619,7 +648,7 @@ class Generator:
null = f" NULL DEFINED AS {null}" if null else ""
return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
- def table_sql(self, expression, sep=" AS "):
+ def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
table = ".".join(
part
for part in [
@@ -642,7 +671,7 @@ class Generator:
return f"{table}{alias}{laterals}{joins}{pivots}"
- def tablesample_sql(self, expression):
+ def tablesample_sql(self, expression: exp.TableSample) -> str:
if self.alias_post_tablesample and expression.this.alias:
this = self.sql(expression.this, "this")
alias = f" AS {self.sql(expression.this, 'alias')}"
@@ -665,7 +694,7 @@ class Generator:
seed = f" SEED ({seed})" if seed else ""
return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}"
- def pivot_sql(self, expression):
+ def pivot_sql(self, expression: exp.Pivot) -> str:
this = self.sql(expression, "this")
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
@@ -673,10 +702,10 @@ class Generator:
field = self.sql(expression, "field")
return f"{this} {direction}({expressions} FOR {field})"
- def tuple_sql(self, expression):
+ def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
- def update_sql(self, expression):
+ def update_sql(self, expression: exp.Update) -> str:
this = self.sql(expression, "this")
set_sql = self.expressions(expression, flat=True)
from_sql = self.sql(expression, "from")
@@ -684,7 +713,7 @@ class Generator:
sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
- def values_sql(self, expression):
+ def values_sql(self, expression: exp.Values) -> str:
alias = self.sql(expression, "alias")
args = self.expressions(expression)
if not alias:
@@ -694,19 +723,19 @@ class Generator:
return f"(VALUES{self.seg('')}{args}){alias}"
return f"VALUES{self.seg('')}{args}{alias}"
- def var_sql(self, expression):
+ def var_sql(self, expression: exp.Var) -> str:
return self.sql(expression, "this")
- def into_sql(self, expression):
+ def into_sql(self, expression: exp.Into) -> str:
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
- def from_sql(self, expression):
+ def from_sql(self, expression: exp.From) -> str:
expressions = self.expressions(expression, flat=True)
return f"{self.seg('FROM')} {expressions}"
- def group_sql(self, expression):
+ def group_sql(self, expression: exp.Group) -> str:
group_by = self.op_expressions("GROUP BY", expression)
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
grouping_sets = (
@@ -718,11 +747,11 @@ class Generator:
rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""
return f"{group_by}{grouping_sets}{cube}{rollup}"
- def having_sql(self, expression):
+ def having_sql(self, expression: exp.Having) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('HAVING')}{self.sep()}{this}"
- def join_sql(self, expression):
+ def join_sql(self, expression: exp.Join) -> str:
op_sql = self.seg(
" ".join(
op
@@ -753,12 +782,12 @@ class Generator:
this_sql = self.sql(expression, "this")
return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
- def lambda_sql(self, expression, arrow_sep="->"):
+ def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
args = self.expressions(expression, flat=True)
args = f"({args})" if len(args.split(",")) > 1 else args
return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
- def lateral_sql(self, expression):
+ def lateral_sql(self, expression: exp.Lateral) -> str:
this = self.sql(expression, "this")
if isinstance(expression.this, exp.Subquery):
@@ -776,15 +805,15 @@ class Generator:
return f"LATERAL {this}{table}{columns}"
- def limit_sql(self, expression):
+ def limit_sql(self, expression: exp.Limit) -> str:
this = self.sql(expression, "this")
return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}"
- def offset_sql(self, expression):
+ def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"
- def literal_sql(self, expression):
+ def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
if self._replace_backslash:
@@ -793,7 +822,7 @@ class Generator:
text = f"{self.quote_start}{text}{self.quote_end}"
return text
- def loaddata_sql(self, expression):
+ def loaddata_sql(self, expression: exp.LoadData) -> str:
local = " LOCAL" if expression.args.get("local") else ""
inpath = f" INPATH {self.sql(expression, 'inpath')}"
overwrite = " OVERWRITE" if expression.args.get("overwrite") else ""
@@ -806,27 +835,27 @@ class Generator:
serde = f" SERDE {serde}" if serde else ""
return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}"
- def null_sql(self, *_):
+ def null_sql(self, *_) -> str:
return "NULL"
- def boolean_sql(self, expression):
+ def boolean_sql(self, expression: exp.Boolean) -> str:
return "TRUE" if expression.this else "FALSE"
- def order_sql(self, expression, flat=False):
+ def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else this
- return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat)
+ return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore
- def cluster_sql(self, expression):
+ def cluster_sql(self, expression: exp.Cluster) -> str:
return self.op_expressions("CLUSTER BY", expression)
- def distribute_sql(self, expression):
+ def distribute_sql(self, expression: exp.Distribute) -> str:
return self.op_expressions("DISTRIBUTE BY", expression)
- def sort_sql(self, expression):
+ def sort_sql(self, expression: exp.Sort) -> str:
return self.op_expressions("SORT BY", expression)
- def ordered_sql(self, expression):
+ def ordered_sql(self, expression: exp.Ordered) -> str:
desc = expression.args.get("desc")
asc = not desc
@@ -857,7 +886,7 @@ class Generator:
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
- def query_modifiers(self, expression, *sqls):
+ def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("laterals", [])],
@@ -876,7 +905,7 @@ class Generator:
sep="",
)
- def select_sql(self, expression):
+ def select_sql(self, expression: exp.Select) -> str:
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
@@ -890,36 +919,36 @@ class Generator:
)
return self.prepend_ctes(expression, sql)
- def schema_sql(self, expression):
+ def schema_sql(self, expression: exp.Schema) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else ""
sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
return f"{this}{sql}"
- def star_sql(self, expression):
+ def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else ""
replace = self.expressions(expression, key="replace", flat=True)
replace = f"{self.seg('REPLACE')} ({replace})" if replace else ""
return f"*{except_}{replace}"
- def structkwarg_sql(self, expression):
+ def structkwarg_sql(self, expression: exp.StructKwarg) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
- def parameter_sql(self, expression):
+ def parameter_sql(self, expression: exp.Parameter) -> str:
return f"@{self.sql(expression, 'this')}"
- def sessionparameter_sql(self, expression):
+ def sessionparameter_sql(self, expression: exp.SessionParameter) -> str:
this = self.sql(expression, "this")
kind = expression.text("kind")
if kind:
kind = f"{kind}."
return f"@@{kind}{this}"
- def placeholder_sql(self, expression):
+ def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f":{expression.name}" if expression.name else "?"
- def subquery_sql(self, expression):
+ def subquery_sql(self, expression: exp.Subquery) -> str:
alias = self.sql(expression, "alias")
sql = self.query_modifiers(
@@ -931,22 +960,22 @@ class Generator:
return self.prepend_ctes(expression, sql)
- def qualify_sql(self, expression):
+ def qualify_sql(self, expression: exp.Qualify) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
- def union_sql(self, expression):
+ def union_sql(self, expression: exp.Union) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.union_op(expression)),
)
- def union_op(self, expression):
+ def union_op(self, expression: exp.Union) -> str:
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
kind = kind if expression.args.get("distinct") else " ALL"
return f"UNION{kind}"
- def unnest_sql(self, expression):
+ def unnest_sql(self, expression: exp.Unnest) -> str:
args = self.expressions(expression, flat=True)
alias = expression.args.get("alias")
if alias and self.unnest_column_only:
@@ -958,11 +987,11 @@ class Generator:
ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
return f"UNNEST({args}){ordinality}{alias}"
- def where_sql(self, expression):
+ def where_sql(self, expression: exp.Where) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('WHERE')}{self.sep()}{this}"
- def window_sql(self, expression):
+ def window_sql(self, expression: exp.Window) -> str:
this = self.sql(expression, "this")
partition = self.expressions(expression, key="partition_by", flat=True)
@@ -988,7 +1017,7 @@ class Generator:
return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})"
- def window_spec_sql(self, expression):
+ def window_spec_sql(self, expression: exp.WindowSpec) -> str:
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
end = (
@@ -997,33 +1026,33 @@ class Generator:
)
return f"{kind} BETWEEN {start} AND {end}"
- def withingroup_sql(self, expression):
+ def withingroup_sql(self, expression: exp.WithinGroup) -> str:
this = self.sql(expression, "this")
- expression = self.sql(expression, "expression")[1:] # order has a leading space
- return f"{this} WITHIN GROUP ({expression})"
+ expression_sql = self.sql(expression, "expression")[1:] # order has a leading space
+ return f"{this} WITHIN GROUP ({expression_sql})"
- def between_sql(self, expression):
+ def between_sql(self, expression: exp.Between) -> str:
this = self.sql(expression, "this")
low = self.sql(expression, "low")
high = self.sql(expression, "high")
return f"{this} BETWEEN {low} AND {high}"
- def bracket_sql(self, expression):
+ def bracket_sql(self, expression: exp.Bracket) -> str:
expressions = apply_index_offset(expression.expressions, self.index_offset)
- expressions = ", ".join(self.sql(e) for e in expressions)
+ expressions_sql = ", ".join(self.sql(e) for e in expressions)
- return f"{self.sql(expression, 'this')}[{expressions}]"
+ return f"{self.sql(expression, 'this')}[{expressions_sql}]"
- def all_sql(self, expression):
+ def all_sql(self, expression: exp.All) -> str:
return f"ALL {self.wrap(expression)}"
- def any_sql(self, expression):
+ def any_sql(self, expression: exp.Any) -> str:
return f"ANY {self.wrap(expression)}"
- def exists_sql(self, expression):
+ def exists_sql(self, expression: exp.Exists) -> str:
return f"EXISTS{self.wrap(expression)}"
- def case_sql(self, expression):
+ def case_sql(self, expression: exp.Case) -> str:
this = self.sql(expression, "this")
statements = [f"CASE {this}" if this else "CASE"]
@@ -1043,17 +1072,17 @@ class Generator:
return " ".join(statements)
- def constraint_sql(self, expression):
+ def constraint_sql(self, expression: exp.Constraint) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"CONSTRAINT {this} {expressions}"
- def extract_sql(self, expression):
+ def extract_sql(self, expression: exp.Extract) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
return f"EXTRACT({this} FROM {expression_sql})"
- def trim_sql(self, expression):
+ def trim_sql(self, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
@@ -1064,16 +1093,16 @@ class Generator:
else:
return f"TRIM({target})"
- def concat_sql(self, expression):
+ def concat_sql(self, expression: exp.Concat) -> str:
if len(expression.expressions) == 1:
return self.sql(expression.expressions[0])
return self.function_fallback_sql(expression)
- def check_sql(self, expression):
+ def check_sql(self, expression: exp.Check) -> str:
this = self.sql(expression, key="this")
return f"CHECK ({this})"
- def foreignkey_sql(self, expression):
+ def foreignkey_sql(self, expression: exp.ForeignKey) -> str:
expressions = self.expressions(expression, flat=True)
reference = self.sql(expression, "reference")
reference = f" {reference}" if reference else ""
@@ -1083,16 +1112,16 @@ class Generator:
update = f" ON UPDATE {update}" if update else ""
return f"FOREIGN KEY ({expressions}){reference}{delete}{update}"
- def unique_sql(self, expression):
+ def unique_sql(self, expression: exp.Unique) -> str:
columns = self.expressions(expression, key="expressions")
return f"UNIQUE ({columns})"
- def if_sql(self, expression):
+ def if_sql(self, expression: exp.If) -> str:
return self.case_sql(
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
- def in_sql(self, expression):
+ def in_sql(self, expression: exp.In) -> str:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
field = expression.args.get("field")
@@ -1106,24 +1135,24 @@ class Generator:
in_sql = f"({self.expressions(expression, flat=True)})"
return f"{self.sql(expression, 'this')} IN {in_sql}"
- def in_unnest_op(self, unnest):
+ def in_unnest_op(self, unnest: exp.Unnest) -> str:
return f"(SELECT {self.sql(unnest)})"
- def interval_sql(self, expression):
+ def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
unit = f" {unit}" if unit else ""
return f"INTERVAL {self.sql(expression, 'this')}{unit}"
- def reference_sql(self, expression):
+ def reference_sql(self, expression: exp.Reference) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"REFERENCES {this}({expressions})"
- def anonymous_sql(self, expression):
+ def anonymous_sql(self, expression: exp.Anonymous) -> str:
args = self.format_args(*expression.expressions)
return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
- def paren_sql(self, expression):
+ def paren_sql(self, expression: exp.Paren) -> str:
if isinstance(expression.unnest(), exp.Select):
sql = self.wrap(expression)
else:
@@ -1132,35 +1161,35 @@ class Generator:
return self.prepend_ctes(expression, sql)
- def neg_sql(self, expression):
+ def neg_sql(self, expression: exp.Neg) -> str:
# This makes sure we don't convert "- - 5" to "--5", which is a comment
this_sql = self.sql(expression, "this")
sep = " " if this_sql[0] == "-" else ""
return f"-{sep}{this_sql}"
- def not_sql(self, expression):
+ def not_sql(self, expression: exp.Not) -> str:
return f"NOT {self.sql(expression, 'this')}"
- def alias_sql(self, expression):
+ def alias_sql(self, expression: exp.Alias) -> str:
to_sql = self.sql(expression, "alias")
to_sql = f" AS {to_sql}" if to_sql else ""
return f"{self.sql(expression, 'this')}{to_sql}"
- def aliases_sql(self, expression):
+ def aliases_sql(self, expression: exp.Aliases) -> str:
return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"
- def attimezone_sql(self, expression):
+ def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
this = self.sql(expression, "this")
zone = self.sql(expression, "zone")
return f"{this} AT TIME ZONE {zone}"
- def add_sql(self, expression):
+ def add_sql(self, expression: exp.Add) -> str:
return self.binary(expression, "+")
- def and_sql(self, expression):
+ def and_sql(self, expression: exp.And) -> str:
return self.connector_sql(expression, "AND")
- def connector_sql(self, expression, op):
+ def connector_sql(self, expression: exp.Connector, op: str) -> str:
if not self.pretty:
return self.binary(expression, op)
@@ -1168,53 +1197,53 @@ class Generator:
sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
return f"{sep}{op} ".join(sqls)
- def bitwiseand_sql(self, expression):
+ def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
return self.binary(expression, "&")
- def bitwiseleftshift_sql(self, expression):
+ def bitwiseleftshift_sql(self, expression: exp.BitwiseLeftShift) -> str:
return self.binary(expression, "<<")
- def bitwisenot_sql(self, expression):
+ def bitwisenot_sql(self, expression: exp.BitwiseNot) -> str:
return f"~{self.sql(expression, 'this')}"
- def bitwiseor_sql(self, expression):
+ def bitwiseor_sql(self, expression: exp.BitwiseOr) -> str:
return self.binary(expression, "|")
- def bitwiserightshift_sql(self, expression):
+ def bitwiserightshift_sql(self, expression: exp.BitwiseRightShift) -> str:
return self.binary(expression, ">>")
- def bitwisexor_sql(self, expression):
+ def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str:
return self.binary(expression, "^")
- def cast_sql(self, expression):
+ def cast_sql(self, expression: exp.Cast) -> str:
return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
- def currentdate_sql(self, expression):
+ def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
- def collate_sql(self, expression):
+ def collate_sql(self, expression: exp.Collate) -> str:
return self.binary(expression, "COLLATE")
- def command_sql(self, expression):
+ def command_sql(self, expression: exp.Command) -> str:
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
- def transaction_sql(self, *_):
+ def transaction_sql(self, *_) -> str:
return "BEGIN"
- def commit_sql(self, expression):
+ def commit_sql(self, expression: exp.Commit) -> str:
chain = expression.args.get("chain")
if chain is not None:
chain = " AND CHAIN" if chain else " AND NO CHAIN"
return f"COMMIT{chain or ''}"
- def rollback_sql(self, expression):
+ def rollback_sql(self, expression: exp.Rollback) -> str:
savepoint = expression.args.get("savepoint")
savepoint = f" TO {savepoint}" if savepoint else ""
return f"ROLLBACK{savepoint}"
- def distinct_sql(self, expression):
+ def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
@@ -1222,13 +1251,13 @@ class Generator:
on = f" ON {on}" if on else ""
return f"DISTINCT{this}{on}"
- def ignorenulls_sql(self, expression):
+ def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
return f"{self.sql(expression, 'this')} IGNORE NULLS"
- def respectnulls_sql(self, expression):
+ def respectnulls_sql(self, expression: exp.RespectNulls) -> str:
return f"{self.sql(expression, 'this')} RESPECT NULLS"
- def intdiv_sql(self, expression):
+ def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
this=exp.Div(this=expression.this, expression=expression.expression),
@@ -1236,79 +1265,79 @@ class Generator:
)
)
- def dpipe_sql(self, expression):
+ def dpipe_sql(self, expression: exp.DPipe) -> str:
return self.binary(expression, "||")
- def div_sql(self, expression):
+ def div_sql(self, expression: exp.Div) -> str:
return self.binary(expression, "/")
- def distance_sql(self, expression):
+ def distance_sql(self, expression: exp.Distance) -> str:
return self.binary(expression, "<->")
- def dot_sql(self, expression):
+ def dot_sql(self, expression: exp.Dot) -> str:
return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}"
- def eq_sql(self, expression):
+ def eq_sql(self, expression: exp.EQ) -> str:
return self.binary(expression, "=")
- def escape_sql(self, expression):
+ def escape_sql(self, expression: exp.Escape) -> str:
return self.binary(expression, "ESCAPE")
- def gt_sql(self, expression):
+ def gt_sql(self, expression: exp.GT) -> str:
return self.binary(expression, ">")
- def gte_sql(self, expression):
+ def gte_sql(self, expression: exp.GTE) -> str:
return self.binary(expression, ">=")
- def ilike_sql(self, expression):
+ def ilike_sql(self, expression: exp.ILike) -> str:
return self.binary(expression, "ILIKE")
- def is_sql(self, expression):
+ def is_sql(self, expression: exp.Is) -> str:
return self.binary(expression, "IS")
- def like_sql(self, expression):
+ def like_sql(self, expression: exp.Like) -> str:
return self.binary(expression, "LIKE")
- def similarto_sql(self, expression):
+ def similarto_sql(self, expression: exp.SimilarTo) -> str:
return self.binary(expression, "SIMILAR TO")
- def lt_sql(self, expression):
+ def lt_sql(self, expression: exp.LT) -> str:
return self.binary(expression, "<")
- def lte_sql(self, expression):
+ def lte_sql(self, expression: exp.LTE) -> str:
return self.binary(expression, "<=")
- def mod_sql(self, expression):
+ def mod_sql(self, expression: exp.Mod) -> str:
return self.binary(expression, "%")
- def mul_sql(self, expression):
+ def mul_sql(self, expression: exp.Mul) -> str:
return self.binary(expression, "*")
- def neq_sql(self, expression):
+ def neq_sql(self, expression: exp.NEQ) -> str:
return self.binary(expression, "<>")
- def nullsafeeq_sql(self, expression):
+ def nullsafeeq_sql(self, expression: exp.NullSafeEQ) -> str:
return self.binary(expression, "IS NOT DISTINCT FROM")
- def nullsafeneq_sql(self, expression):
+ def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str:
return self.binary(expression, "IS DISTINCT FROM")
- def or_sql(self, expression):
+ def or_sql(self, expression: exp.Or) -> str:
return self.connector_sql(expression, "OR")
- def sub_sql(self, expression):
+ def sub_sql(self, expression: exp.Sub) -> str:
return self.binary(expression, "-")
- def trycast_sql(self, expression):
+ def trycast_sql(self, expression: exp.TryCast) -> str:
return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
- def use_sql(self, expression):
+ def use_sql(self, expression: exp.Use) -> str:
return f"USE {self.sql(expression, 'this')}"
- def binary(self, expression, op):
+ def binary(self, expression: exp.Binary, op: str) -> str:
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
- def function_fallback_sql(self, expression):
+ def function_fallback_sql(self, expression: exp.Func) -> str:
args = []
for arg_value in expression.args.values():
if isinstance(arg_value, list):
@@ -1319,19 +1348,26 @@ class Generator:
return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})"
- def format_args(self, *args):
- args = tuple(self.sql(arg) for arg in args if arg is not None)
- if self.pretty and self.text_width(args) > self._max_text_width:
- return self.indent("\n" + f",\n".join(args) + "\n", skip_first=True, skip_last=True)
- return ", ".join(args)
+ def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
+ arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
+ if self.pretty and self.text_width(arg_sqls) > self._max_text_width:
+ return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
+ return ", ".join(arg_sqls)
- def text_width(self, args):
+ def text_width(self, args: t.Iterable) -> int:
return sum(len(arg) for arg in args)
- def format_time(self, expression):
+ def format_time(self, expression: exp.Expression) -> t.Optional[str]:
return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
- def expressions(self, expression, key=None, flat=False, indent=True, sep=", "):
+ def expressions(
+ self,
+ expression: exp.Expression,
+ key: t.Optional[str] = None,
+ flat: bool = False,
+ indent: bool = True,
+ sep: str = ", ",
+ ) -> str:
expressions = expression.args.get(key or "expressions")
if not expressions:
@@ -1359,45 +1395,67 @@ class Generator:
else:
result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
- result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
- return self.indent(result_sqls, skip_first=False) if indent else result_sqls
+ result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
+ return self.indent(result_sql, skip_first=False) if indent else result_sql
- def op_expressions(self, op, expression, flat=False):
+ def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str:
expressions_sql = self.expressions(expression, flat=flat)
if flat:
return f"{op} {expressions_sql}"
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
- def naked_property(self, expression):
+ def naked_property(self, expression: exp.Property) -> str:
property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
if not property_name:
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
return f"{property_name} {self.sql(expression, 'this')}"
- def set_operation(self, expression, op):
+ def set_operation(self, expression: exp.Expression, op: str) -> str:
this = self.sql(expression, "this")
op = self.seg(op)
return self.query_modifiers(
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
)
- def token_sql(self, token_type):
+ def token_sql(self, token_type: TokenType) -> str:
return self.TOKEN_MAPPING.get(token_type, token_type.name)
- def userdefinedfunction_sql(self, expression):
+ def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
this = self.sql(expression, "this")
expressions = self.no_identify(lambda: self.expressions(expression))
return f"{this}({expressions})"
- def userdefinedfunctionkwarg_sql(self, expression):
+ def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
return f"{this} {kind}"
- def joinhint_sql(self, expression):
+ def joinhint_sql(self, expression: exp.JoinHint) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"{this}({expressions})"
- def kwarg_sql(self, expression):
+ def kwarg_sql(self, expression: exp.Kwarg) -> str:
return self.binary(expression, "=>")
+
+ def when_sql(self, expression: exp.When) -> str:
+ this = self.sql(expression, "this")
+ then_expression = expression.args.get("then")
+ if isinstance(then_expression, exp.Insert):
+ then = f"INSERT {self.sql(then_expression, 'this')}"
+ if "expression" in then_expression.args:
+ then += f" VALUES {self.sql(then_expression, 'expression')}"
+ elif isinstance(then_expression, exp.Update):
+ if isinstance(then_expression.args.get("expressions"), exp.Star):
+ then = f"UPDATE {self.sql(then_expression, 'expressions')}"
+ else:
+ then = f"UPDATE SET {self.expressions(then_expression, flat=True)}"
+ else:
+ then = self.sql(then_expression)
+ return f"WHEN {this} THEN {then}"
+
+ def merge_sql(self, expression: exp.Merge) -> str:
+ this = self.sql(expression, "this")
+ using = f"USING {self.sql(expression, 'using')}"
+ on = f"ON {self.sql(expression, 'on')}"
+ return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}"
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 8c5808d..ed37e6c 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -385,3 +385,11 @@ def dict_depth(d: t.Dict) -> int:
except StopIteration:
# d.values() returns an empty sequence
return 1
+
+
+def first(it: t.Iterable[T]) -> T:
+ """Returns the first element from an iterable.
+
+ Useful for sets.
+ """
+ return next(i for i in it)
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 191ea52..be17f15 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -14,7 +14,7 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
- >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola"
+ >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Args:
@@ -41,9 +41,12 @@ class TypeAnnotator:
expr_type: lambda self, expr: self._annotate_binary(expr)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
- exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this),
- exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this),
+ exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
+ exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
+ exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
exp.Alias: lambda self, expr: self._annotate_unary(expr),
+ exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Literal: lambda self, expr: self._annotate_literal(expr),
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
@@ -52,6 +55,9 @@ class TypeAnnotator:
expr, exp.DataType.Type.BIGINT
),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
+ exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
+ exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
@@ -263,10 +269,10 @@ class TypeAnnotator:
}
# First annotate the current scope's column references
for col in scope.columns:
- source = scope.sources[col.table]
+ source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
- else:
+ elif source:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
@@ -280,6 +286,7 @@ class TypeAnnotator:
return expression # We've already inferred the expression's type
annotator = self.annotators.get(expression.__class__)
+
return (
annotator(self, expression)
if annotator
@@ -295,18 +302,23 @@ class TypeAnnotator:
def _maybe_coerce(self, type1, type2):
# We propagate the NULL / UNKNOWN types upwards if found
+ if isinstance(type1, exp.DataType):
+ type1 = type1.this
+ if isinstance(type2, exp.DataType):
+ type2 = type2.this
+
if exp.DataType.Type.NULL in (type1, type2):
return exp.DataType.Type.NULL
if exp.DataType.Type.UNKNOWN in (type1, type2):
return exp.DataType.Type.UNKNOWN
- return type2 if type2 in self.coerces_to[type1] else type1
+ return type2 if type2 in self.coerces_to.get(type1, {}) else type1
def _annotate_binary(self, expression):
self._annotate_args(expression)
- left_type = expression.left.type
- right_type = expression.right.type
+ left_type = expression.left.type.this
+ right_type = expression.right.type.this
if isinstance(expression, (exp.And, exp.Or)):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@@ -348,7 +360,7 @@ class TypeAnnotator:
expression.type = target_type
return self._annotate_args(expression)
- def _annotate_by_args(self, expression, *args):
+ def _annotate_by_args(self, expression, *args, promote=False):
self._annotate_args(expression)
expressions = []
for arg in args:
@@ -360,4 +372,11 @@ class TypeAnnotator:
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
expression.type = last_datatype or exp.DataType.Type.UNKNOWN
+
+ if promote:
+ if expression.type.this in exp.DataType.INTEGER_TYPES:
+ expression.type = exp.DataType.Type.BIGINT
+ elif expression.type.this in exp.DataType.FLOAT_TYPES:
+ expression.type = exp.DataType.Type.DOUBLE
+
return expression
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index 9b3d98a..33529a5 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -13,13 +13,16 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression: The expression to canonicalize.
"""
exp.replace_children(expression, canonicalize)
+
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
+ expression = remove_redundant_casts(expression)
+
return expression
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
- if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
+ if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
node = exp.Concat(this=node.this, expression=node.expression)
return node
@@ -30,14 +33,30 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
elif isinstance(node, exp.Extract):
- if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
+ if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
_replace_cast(node.expression, "datetime")
return node
+def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
+ if (
+ isinstance(expression, exp.Cast)
+ and expression.to.type
+ and expression.this.type
+ and expression.to.type.this == expression.this.type.this
+ ):
+ return expression.this
+ return expression
+
+
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
- if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
+ if (
+ a.type
+ and a.type.this == exp.DataType.Type.DATE
+ and b.type
+ and b.type.this != exp.DataType.Type.DATE
+ ):
_replace_cast(b, "date")
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index c432c59..c0719f2 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -7,7 +7,7 @@ from decimal import Decimal
from sqlglot import exp
from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
-from sqlglot.helper import while_changing
+from sqlglot.helper import first, while_changing
GENERATOR = Generator(normalize=True, identify=True)
@@ -30,6 +30,7 @@ def simplify(expression):
def _simplify(expression, root=True):
node = expression
+ node = rewrite_between(node)
node = uniq_sort(node)
node = absorb_and_eliminate(node)
exp.replace_children(node, lambda e: _simplify(e, False))
@@ -49,6 +50,19 @@ def simplify(expression):
return expression
+def rewrite_between(expression: exp.Expression) -> exp.Expression:
+ """Rewrite x between y and z to x >= y AND x <= z.
+
+ This is done because comparison simplification is only done on lt/lte/gt/gte.
+ """
+ if isinstance(expression, exp.Between):
+ return exp.and_(
+ exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
+ exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
+ )
+ return expression
+
+
def simplify_not(expression):
"""
Demorgan's Law
@@ -57,7 +71,7 @@ def simplify_not(expression):
"""
if isinstance(expression, exp.Not):
if isinstance(expression.this, exp.Null):
- return NULL
+ return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if isinstance(condition, exp.And):
@@ -65,11 +79,11 @@ def simplify_not(expression):
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Null):
- return NULL
+ return exp.null()
if always_true(expression.this):
- return FALSE
+ return exp.false()
if expression.this == FALSE:
- return TRUE
+ return exp.true()
if isinstance(expression.this, exp.Not):
# double negation
# NOT NOT x -> x
@@ -91,40 +105,119 @@ def flatten(expression):
def simplify_connectors(expression):
- if isinstance(expression, exp.Connector):
- left = expression.left
- right = expression.right
-
- if left == right:
- return left
-
- if isinstance(expression, exp.And):
- if FALSE in (left, right):
- return FALSE
- if NULL in (left, right):
- return NULL
- if always_true(left) and always_true(right):
- return TRUE
- if always_true(left):
- return right
- if always_true(right):
- return left
- elif isinstance(expression, exp.Or):
- if always_true(left) or always_true(right):
- return TRUE
- if left == FALSE and right == FALSE:
- return FALSE
- if (
- (left == NULL and right == NULL)
- or (left == NULL and right == FALSE)
- or (left == FALSE and right == NULL)
- ):
- return NULL
- if left == FALSE:
- return right
- if right == FALSE:
+ def _simplify_connectors(expression, left, right):
+ if isinstance(expression, exp.Connector):
+ if left == right:
return left
- return expression
+ if isinstance(expression, exp.And):
+ if FALSE in (left, right):
+ return exp.false()
+ if NULL in (left, right):
+ return exp.null()
+ if always_true(left) and always_true(right):
+ return exp.true()
+ if always_true(left):
+ return right
+ if always_true(right):
+ return left
+ return _simplify_comparison(expression, left, right)
+ elif isinstance(expression, exp.Or):
+ if always_true(left) or always_true(right):
+ return exp.true()
+ if left == FALSE and right == FALSE:
+ return exp.false()
+ if (
+ (left == NULL and right == NULL)
+ or (left == NULL and right == FALSE)
+ or (left == FALSE and right == NULL)
+ ):
+ return exp.null()
+ if left == FALSE:
+ return right
+ if right == FALSE:
+ return left
+ return _simplify_comparison(expression, left, right, or_=True)
+ return None
+
+ return _flat_simplify(expression, _simplify_connectors)
+
+
+LT_LTE = (exp.LT, exp.LTE)
+GT_GTE = (exp.GT, exp.GTE)
+
+COMPARISONS = (
+ *LT_LTE,
+ *GT_GTE,
+ exp.EQ,
+ exp.NEQ,
+)
+
+INVERSE_COMPARISONS = {
+ exp.LT: exp.GT,
+ exp.GT: exp.LT,
+ exp.LTE: exp.GTE,
+ exp.GTE: exp.LTE,
+}
+
+
+def _simplify_comparison(expression, left, right, or_=False):
+ if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
+ ll, lr = left.args.values()
+ rl, rr = right.args.values()
+
+ largs = {ll, lr}
+ rargs = {rl, rr}
+
+ matching = largs & rargs
+ columns = {m for m in matching if isinstance(m, exp.Column)}
+
+ if matching and columns:
+ try:
+ l = first(largs - columns)
+ r = first(rargs - columns)
+ except StopIteration:
+ return expression
+
+ # make sure the comparison is always of the form x > 1 instead of 1 < x
+ if left.__class__ in INVERSE_COMPARISONS and l == ll:
+ left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
+ if right.__class__ in INVERSE_COMPARISONS and r == rl:
+ right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
+
+ if l.is_number and r.is_number:
+ l = float(l.name)
+ r = float(r.name)
+ elif l.is_string and r.is_string:
+ l = l.name
+ r = r.name
+ else:
+ return None
+
+ for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
+ if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
+ return left if (av > bv if or_ else av <= bv) else right
+ if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
+ return left if (av < bv if or_ else av >= bv) else right
+
+ # we can't ever shortcut to true because the column could be null
+ if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
+ if not or_ and av <= bv:
+ return exp.false()
+ elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
+ if not or_ and av >= bv:
+ return exp.false()
+ elif isinstance(a, exp.EQ):
+ if isinstance(b, exp.LT):
+ return exp.false() if av >= bv else a
+ if isinstance(b, exp.LTE):
+ return exp.false() if av > bv else a
+ if isinstance(b, exp.GT):
+ return exp.false() if av <= bv else a
+ if isinstance(b, exp.GTE):
+ return exp.false() if av < bv else a
+ if isinstance(b, exp.NEQ):
+ return exp.false() if av == bv else a
+ return None
def remove_compliments(expression):
@@ -135,7 +228,7 @@ def remove_compliments(expression):
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector):
- compliment = FALSE if isinstance(expression, exp.And) else TRUE
+ compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
if is_complement(a, b):
@@ -211,27 +304,7 @@ def absorb_and_eliminate(expression):
def simplify_literals(expression):
if isinstance(expression, exp.Binary):
- operands = []
- queue = deque(expression.flatten(unnest=False))
- size = len(queue)
-
- while queue:
- a = queue.popleft()
-
- for b in queue:
- result = _simplify_binary(expression, a, b)
-
- if result:
- queue.remove(b)
- queue.append(result)
- break
- else:
- operands.append(a)
-
- if len(operands) < size:
- return functools.reduce(
- lambda a, b: expression.__class__(this=a, expression=b), operands
- )
+ return _flat_simplify(expression, _simplify_binary)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@@ -254,20 +327,13 @@ def _simplify_binary(expression, a, b):
if c == NULL:
if isinstance(a, exp.Literal):
- return TRUE if not_ else FALSE
+ return exp.true() if not_ else exp.false()
if a == NULL:
- return FALSE if not_ else TRUE
- elif isinstance(expression, exp.NullSafeEQ):
- if a == b:
- return TRUE
- elif isinstance(expression, exp.NullSafeNEQ):
- if a == b:
- return FALSE
+ return exp.false() if not_ else exp.true()
+ elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
+ return None
elif NULL in (a, b):
- return NULL
-
- if isinstance(expression, exp.EQ) and a == b:
- return TRUE
+ return exp.null()
if a.is_number and b.is_number:
a = int(a.name) if a.is_int else Decimal(a.name)
@@ -388,4 +454,27 @@ def date_literal(date):
def boolean_literal(condition):
- return TRUE if condition else FALSE
+ return exp.true() if condition else exp.false()
+
+
+def _flat_simplify(expression, simplifier):
+ operands = []
+ queue = deque(expression.flatten(unnest=False))
+ size = len(queue)
+
+ while queue:
+ a = queue.popleft()
+
+ for b in queue:
+ result = simplifier(expression, a, b)
+
+ if result:
+ queue.remove(b)
+ queue.append(result)
+ break
+ else:
+ operands.append(a)
+
+ if len(operands) < size:
+ return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+ return expression
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index bdf0d2d..55ab453 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -185,6 +185,7 @@ class Parser(metaclass=_Parser):
TokenType.LOCAL,
TokenType.LOCATION,
TokenType.MATERIALIZED,
+ TokenType.MERGE,
TokenType.NATURAL,
TokenType.NEXT,
TokenType.ONLY,
@@ -211,7 +212,6 @@ class Parser(metaclass=_Parser):
TokenType.TABLE,
TokenType.TABLE_FORMAT,
TokenType.TEMPORARY,
- TokenType.TRANSIENT,
TokenType.TOP,
TokenType.TRAILING,
TokenType.TRUE,
@@ -229,6 +229,8 @@ class Parser(metaclass=_Parser):
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
+ UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
+
TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
FUNC_TOKENS = {
@@ -241,6 +243,7 @@ class Parser(metaclass=_Parser):
TokenType.FORMAT,
TokenType.IDENTIFIER,
TokenType.ISNULL,
+ TokenType.MERGE,
TokenType.OFFSET,
TokenType.PRIMARY_KEY,
TokenType.REPLACE,
@@ -407,6 +410,7 @@ class Parser(metaclass=_Parser):
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
+ TokenType.MERGE: lambda self: self._parse_merge(),
}
UNARY_PARSERS = {
@@ -474,6 +478,7 @@ class Parser(metaclass=_Parser):
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
TokenType.LIKE: lambda self: self._parse_create_like(),
TokenType.RETURNS: lambda self: self._parse_returns(),
+ TokenType.ROW: lambda self: self._parse_row(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
@@ -495,6 +500,8 @@ class Parser(metaclass=_Parser):
TokenType.VOLATILE: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
),
+ TokenType.WITH: lambda self: self._parse_wrapped_csv(self._parse_property),
+ TokenType.PROPERTIES: lambda self: self._parse_wrapped_csv(self._parse_property),
}
CONSTRAINT_PARSERS = {
@@ -802,7 +809,8 @@ class Parser(metaclass=_Parser):
def _parse_create(self):
replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY)
- transient = self._match(TokenType.TRANSIENT)
+ transient = self._match_text_seq("TRANSIENT")
+ external = self._match_text_seq("EXTERNAL")
unique = self._match(TokenType.UNIQUE)
materialized = self._match(TokenType.MATERIALIZED)
@@ -846,6 +854,7 @@ class Parser(metaclass=_Parser):
properties=properties,
temporary=temporary,
transient=transient,
+ external=external,
replace=replace,
unique=unique,
materialized=materialized,
@@ -861,8 +870,12 @@ class Parser(metaclass=_Parser):
if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
return self._parse_sortkey(compound=True)
- if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
- key = self._parse_var()
+ assignment = self._match_pair(
+ TokenType.VAR, TokenType.EQ, advance=False
+ ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
+
+ if assignment:
+ key = self._parse_var() or self._parse_string()
self._match(TokenType.EQ)
return self.expression(exp.Property, this=key, value=self._parse_column())
@@ -871,7 +884,10 @@ class Parser(metaclass=_Parser):
def _parse_property_assignment(self, exp_class):
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
- return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number())
+ return self.expression(
+ exp_class,
+ this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
+ )
def _parse_partitioned_by(self):
self._match(TokenType.EQ)
@@ -881,7 +897,7 @@ class Parser(metaclass=_Parser):
)
def _parse_distkey(self):
- return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var))
+ return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
def _parse_create_like(self):
table = self._parse_table(schema=True)
@@ -898,7 +914,7 @@ class Parser(metaclass=_Parser):
def _parse_sortkey(self, compound=False):
return self.expression(
- exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound
+ exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound
)
def _parse_character_set(self, default=False):
@@ -929,23 +945,11 @@ class Parser(metaclass=_Parser):
properties = []
while True:
- if self._match(TokenType.WITH):
- properties.extend(self._parse_wrapped_csv(self._parse_property))
- elif self._match(TokenType.PROPERTIES):
- properties.extend(
- self._parse_wrapped_csv(
- lambda: self.expression(
- exp.Property,
- this=self._parse_string(),
- value=self._match(TokenType.EQ) and self._parse_string(),
- )
- )
- )
- else:
- identified_property = self._parse_property()
- if not identified_property:
- break
- properties.append(identified_property)
+ identified_property = self._parse_property()
+ if not identified_property:
+ break
+ for p in ensure_collection(identified_property):
+ properties.append(p)
if properties:
return self.expression(exp.Properties, expressions=properties)
@@ -963,7 +967,7 @@ class Parser(metaclass=_Parser):
exp.Directory,
this=self._parse_var_or_string(),
local=local,
- row_format=self._parse_row_format(),
+ row_format=self._parse_row_format(match_row=True),
)
else:
self._match(TokenType.INTO)
@@ -978,10 +982,18 @@ class Parser(metaclass=_Parser):
overwrite=overwrite,
)
- def _parse_row_format(self):
- if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
+ def _parse_row(self):
+ if not self._match(TokenType.FORMAT):
+ return None
+ return self._parse_row_format()
+
+ def _parse_row_format(self, match_row=False):
+ if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None
+ if self._match_text_seq("SERDE"):
+ return self.expression(exp.RowFormatSerdeProperty, this=self._parse_string())
+
self._match_text_seq("DELIMITED")
kwargs = {}
@@ -998,7 +1010,7 @@ class Parser(metaclass=_Parser):
kwargs["lines"] = self._parse_string()
if self._match_text_seq("NULL", "DEFINED", "AS"):
kwargs["null"] = self._parse_string()
- return self.expression(exp.RowFormat, **kwargs)
+ return self.expression(exp.RowFormatDelimitedProperty, **kwargs)
def _parse_load_data(self):
local = self._match(TokenType.LOCAL)
@@ -1032,7 +1044,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Update,
**{
- "this": self._parse_table(schema=True),
+ "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
"from": self._parse_from(),
"where": self._parse_where(),
@@ -1183,9 +1195,11 @@ class Parser(metaclass=_Parser):
alias=alias,
)
- def _parse_table_alias(self):
+ def _parse_table_alias(self, alias_tokens=None):
any_token = self._match(TokenType.ALIAS)
- alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS)
+ alias = self._parse_id_var(
+ any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
+ )
columns = None
if self._match(TokenType.L_PAREN):
@@ -1337,7 +1351,7 @@ class Parser(metaclass=_Parser):
columns=self._parse_expression(),
)
- def _parse_table(self, schema=False):
+ def _parse_table(self, schema=False, alias_tokens=None):
lateral = self._parse_lateral()
if lateral:
@@ -1372,7 +1386,7 @@ class Parser(metaclass=_Parser):
table = self._parse_id_var()
if not table:
- self.raise_error("Expected table name")
+ self.raise_error(f"Expected table name but got {self._curr}")
this = self.expression(
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
@@ -1384,7 +1398,7 @@ class Parser(metaclass=_Parser):
if self.alias_post_tablesample:
table_sample = self._parse_table_sample()
- alias = self._parse_table_alias()
+ alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
if alias:
this.set("alias", alias)
@@ -2092,10 +2106,14 @@ class Parser(metaclass=_Parser):
kind = self.expression(exp.CheckColumnConstraint, this=constraint)
elif self._match(TokenType.COLLATE):
kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
+ elif self._match(TokenType.ENCODE):
+ kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT):
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint()
+ elif self._match(TokenType.NULL):
+ kind = exp.NotNullColumnConstraint(allow_null=True)
elif self._match(TokenType.SCHEMA_COMMENT):
kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
elif self._match(TokenType.PRIMARY_KEY):
@@ -2234,7 +2252,7 @@ class Parser(metaclass=_Parser):
return self._parse_window(this)
def _parse_extract(self):
- this = self._parse_var() or self._parse_type()
+ this = self._parse_function() or self._parse_var() or self._parse_type()
if self._match(TokenType.FROM):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
@@ -2635,6 +2653,54 @@ class Parser(metaclass=_Parser):
parser = self._find_parser(self.SET_PARSERS, self._set_trie)
return parser(self) if parser else self._default_parse_set_item()
+ def _parse_merge(self):
+ self._match(TokenType.INTO)
+ target = self._parse_table(schema=True)
+
+ self._match(TokenType.USING)
+ using = self._parse_table()
+
+ self._match(TokenType.ON)
+ on = self._parse_conjunction()
+
+ whens = []
+ while self._match(TokenType.WHEN):
+ this = self._parse_conjunction()
+ self._match(TokenType.THEN)
+
+ if self._match(TokenType.INSERT):
+ _this = self._parse_star()
+ if _this:
+ then = self.expression(exp.Insert, this=_this)
+ else:
+ then = self.expression(
+ exp.Insert,
+ this=self._parse_value(),
+ expression=self._match(TokenType.VALUES) and self._parse_value(),
+ )
+ elif self._match(TokenType.UPDATE):
+ expressions = self._parse_star()
+ if expressions:
+ then = self.expression(exp.Update, expressions=expressions)
+ else:
+ then = self.expression(
+ exp.Update,
+ expressions=self._match(TokenType.SET)
+ and self._parse_csv(self._parse_equality),
+ )
+ elif self._match(TokenType.DELETE):
+ then = self.expression(exp.Var, this=self._prev.text)
+
+ whens.append(self.expression(exp.When, this=this, then=then))
+
+ return self.expression(
+ exp.Merge,
+ this=target,
+ using=using,
+ on=on,
+ expressions=whens,
+ )
+
def _parse_set(self):
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index f6f303b..8a264a2 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -47,7 +47,7 @@ class Schema(abc.ABC):
"""
@abc.abstractmethod
- def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type:
+ def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
"""
Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
@@ -160,8 +160,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
super().__init__(schema)
self.visible = visible or {}
self.dialect = dialect
- self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {
- "STR": exp.DataType.Type.TEXT,
+ self._type_mapping_cache: t.Dict[str, exp.DataType] = {
+ "STR": exp.DataType.build("text"),
}
@classmethod
@@ -231,18 +231,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore
- def get_column_type(
- self, table: exp.Table | str, column: exp.Column | str
- ) -> exp.DataType.Type:
+ def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table)
if table_:
- table_schema = self.find(table_)
- schema_type = table_schema.get(column_name).upper() # type: ignore
- return self._convert_type(schema_type)
+ table_schema = self.find(table_, raise_on_missing=False)
+ if table_schema:
+ schema_type = table_schema.get(column_name).upper() # type: ignore
+ return self._convert_type(schema_type)
+ return exp.DataType(this=exp.DataType.Type.UNKNOWN)
raise SchemaError(f"Could not convert table '{table}'")
- def _convert_type(self, schema_type: str) -> exp.DataType.Type:
+ def _convert_type(self, schema_type: str) -> exp.DataType:
"""
Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
@@ -257,7 +257,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
if expression is None:
raise ValueError(f"Could not parse {schema_type}")
- self._type_mapping_cache[schema_type] = expression.this
+ self._type_mapping_cache[schema_type] = expression # type: ignore
except AttributeError:
raise SchemaError(f"Failed to convert type {schema_type}")
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 8a7a38e..b25ef8d 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -49,6 +49,9 @@ class TokenType(AutoName):
PARAMETER = auto()
SESSION_PARAMETER = auto()
+ BLOCK_START = auto()
+ BLOCK_END = auto()
+
SPACE = auto()
BREAK = auto()
@@ -156,6 +159,7 @@ class TokenType(AutoName):
DIV = auto()
DROP = auto()
ELSE = auto()
+ ENCODE = auto()
END = auto()
ENGINE = auto()
ESCAPE = auto()
@@ -207,6 +211,7 @@ class TokenType(AutoName):
LOCATION = auto()
MAP = auto()
MATERIALIZED = auto()
+ MERGE = auto()
MOD = auto()
NATURAL = auto()
NEXT = auto()
@@ -255,6 +260,7 @@ class TokenType(AutoName):
SELECT = auto()
SEMI = auto()
SEPARATOR = auto()
+ SERDE_PROPERTIES = auto()
SET = auto()
SHOW = auto()
SIMILAR_TO = auto()
@@ -267,7 +273,6 @@ class TokenType(AutoName):
TABLE_FORMAT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
- TRANSIENT = auto()
TOP = auto()
THEN = auto()
TRAILING = auto()
@@ -420,6 +425,16 @@ class Tokenizer(metaclass=_Tokenizer):
ESCAPES = ["'"]
KEYWORDS = {
+ **{
+ f"{key}{postfix}": TokenType.BLOCK_START
+ for key in ("{{", "{%", "{#")
+ for postfix in ("", "+", "-")
+ },
+ **{
+ f"{prefix}{key}": TokenType.BLOCK_END
+ for key in ("}}", "%}", "#}")
+ for prefix in ("", "+", "-")
+ },
"/*+": TokenType.HINT,
"==": TokenType.EQ,
"::": TokenType.DCOLON,
@@ -523,6 +538,7 @@ class Tokenizer(metaclass=_Tokenizer):
"LOCAL": TokenType.LOCAL,
"LOCATION": TokenType.LOCATION,
"MATERIALIZED": TokenType.MATERIALIZED,
+ "MERGE": TokenType.MERGE,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
"NO ACTION": TokenType.NO_ACTION,
@@ -582,7 +598,6 @@ class Tokenizer(metaclass=_Tokenizer):
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
"TEMPORARY": TokenType.TEMPORARY,
- "TRANSIENT": TokenType.TRANSIENT,
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"TRAILING": TokenType.TRAILING,
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py
index 32ff8f2..2dcdb39 100644
--- a/tests/dataframe/unit/dataframe_sql_validator.py
+++ b/tests/dataframe/unit/dataframe_sql_validator.py
@@ -4,6 +4,7 @@ import unittest
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.helper import ensure_list
class DataFrameSQLValidator(unittest.TestCase):
@@ -33,9 +34,7 @@ class DataFrameSQLValidator(unittest.TestCase):
self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
):
actual_sqls = df.sql(pretty=pretty)
- expected_statements = (
- [expected_statements] if isinstance(expected_statements, str) else expected_statements
- )
+ expected_statements = ensure_list(expected_statements)
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)
diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py
index 7c646f5..042b915 100644
--- a/tests/dataframe/unit/test_dataframe_writer.py
+++ b/tests/dataframe/unit/test_dataframe_writer.py
@@ -10,37 +10,37 @@ class TestDataFrameWriter(DataFrameSQLValidator):
def test_insertInto_full_path(self):
df = self.df_employee.write.insertInto("catalog.db.table_name")
- expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO catalog.db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_db_table(self):
df = self.df_employee.write.insertInto("db.table_name")
- expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_table(self):
df = self.df_employee.write.insertInto("table_name")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_overwrite(self):
df = self.df_employee.write.insertInto("table_name", overwrite=True)
- expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT OVERWRITE TABLE table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_insertInto_byName(self):
sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
df = self.df_employee.write.byName.insertInto("table_name")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_cache(self):
df = self.df_employee.cache().write.insertInto("table_name")
expected_statements = [
- "DROP VIEW IF EXISTS t37164",
- "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
+ "DROP VIEW IF EXISTS t12441",
+ "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "INSERT INTO table_name SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
]
self.compare_sql(df, expected_statements)
@@ -50,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator):
def test_saveAsTable_append(self):
df = self.df_employee.write.saveAsTable("table_name", mode="append")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_overwrite(self):
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
- expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_error(self):
df = self.df_employee.write.saveAsTable("table_name", mode="error")
- expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_ignore(self):
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
- expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_standalone(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
- expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_override(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
- expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_cache(self):
df = self.df_employee.cache().write.saveAsTable("table_name")
expected_statements = [
- "DROP VIEW IF EXISTS t37164",
- "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
+ "DROP VIEW IF EXISTS t12441",
+ "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "CREATE TABLE table_name AS SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
]
self.compare_sql(df, expected_statements)
diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py
index 55aa547..5213667 100644
--- a/tests/dataframe/unit/test_session.py
+++ b/tests/dataframe/unit/test_session.py
@@ -36,7 +36,7 @@ class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_str_schema(self):
df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
- expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
+ expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_basic(self):
@@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator):
]
)
df = self.spark.createDataFrame([[1, "test"]], schema)
- expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
+ expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_nested(self):
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index cc44311..1d60ec6 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -7,6 +7,11 @@ class TestBigQuery(Validator):
def test_bigquery(self):
self.validate_all(
+ "REGEXP_CONTAINS('foo', '.*')",
+ read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"},
+ write={"mysql": "REGEXP_LIKE('foo', '.*')"},
+ ),
+ self.validate_all(
'"""x"""',
write={
"bigquery": "'x'",
@@ -94,6 +99,20 @@ class TestBigQuery(Validator):
"spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
},
)
+ self.validate_all(
+ "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)",
+ write={"bigquery": "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)"},
+ )
+ self.validate_all(
+ "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers",
+ write={
+ "bigquery": "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers"
+ },
+ )
+ self.validate_all(
+ "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)",
+ write={"bigquery": "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)"},
+ )
self.validate_all(
"x IS unknown",
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 6033570..ee67bf1 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1318,3 +1318,39 @@ SELECT
"BEGIN IMMEDIATE TRANSACTION",
write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"},
)
+
+ def test_merge(self):
+ self.validate_all(
+ """
+ MERGE INTO target USING source ON target.id = source.id
+ WHEN NOT MATCHED THEN INSERT (id) values (source.id)
+ """,
+ write={
+ "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+ "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+ "spark": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+ },
+ )
+ self.validate_all(
+ """
+ MERGE INTO target USING source ON target.id = source.id
+ WHEN MATCHED AND source.is_deleted = 1 THEN DELETE
+ WHEN MATCHED THEN UPDATE SET val = source.val
+ WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)
+ """,
+ write={
+ "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+ "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+ "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+ },
+ )
+ self.validate_all(
+ """
+ MERGE INTO target USING source ON target.id = source.id
+ WHEN MATCHED THEN UPDATE *
+ WHEN NOT MATCHED THEN INSERT *
+ """,
+ write={
+ "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED THEN UPDATE * WHEN NOT MATCHED THEN INSERT *",
+ },
+ )
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index 22d7bce..5ac8714 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -145,6 +145,10 @@ class TestHive(Validator):
},
)
+ self.validate_identity(
+ """CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')""",
+ )
+
def test_lateral_view(self):
self.validate_all(
"SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index cd6117c..962b28b 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -256,3 +256,7 @@ class TestPostgres(Validator):
"SELECT $$Dianne's horse$$",
write={"postgres": "SELECT 'Dianne''s horse'"},
)
+ self.validate_all(
+ "UPDATE MYTABLE T1 SET T1.COL = 13",
+ write={"postgres": "UPDATE MYTABLE AS T1 SET T1.COL = 13"},
+ )
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index 1943ee3..3034df5 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -56,8 +56,27 @@ class TestRedshift(Validator):
"redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
},
)
+ self.validate_all(
+ "DECODE(x, a, b, c, d)",
+ write={
+ "": "MATCHES(x, a, b, c, d)",
+ "oracle": "DECODE(x, a, b, c, d)",
+ "snowflake": "DECODE(x, a, b, c, d)",
+ },
+ )
+ self.validate_all(
+ "NVL(a, b, c, d)",
+ write={
+ "redshift": "COALESCE(a, b, c, d)",
+ "mysql": "COALESCE(a, b, c, d)",
+ "postgres": "COALESCE(a, b, c, d)",
+ },
+ )
def test_identity(self):
+ self.validate_identity(
+ "SELECT DECODE(COL1, 'replace_this', 'with_this', 'replace_that', 'with_that')"
+ )
self.validate_identity("CAST('bla' AS SUPER)")
self.validate_identity("CREATE TABLE real1 (realcol REAL)")
self.validate_identity("CAST('foo' AS HLLSKETCH)")
@@ -70,9 +89,9 @@ class TestRedshift(Validator):
self.validate_identity(
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
)
- self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE AUTO")
+ self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL")
self.validate_identity(
- "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid)"
+ "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO"
)
self.validate_identity(
"COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
@@ -80,3 +99,6 @@ class TestRedshift(Validator):
self.validate_identity(
"UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
)
+ self.validate_identity(
+ "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)"
+ )
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index baca269..bca5aaa 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -500,3 +500,12 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL F
},
pretty=True,
)
+
+ def test_minus(self):
+ self.validate_all(
+ "SELECT 1 EXCEPT SELECT 1",
+ read={
+ "oracle": "SELECT 1 MINUS SELECT 1",
+ "snowflake": "SELECT 1 MINUS SELECT 1",
+ },
+ )
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 06ab96d..e12b673 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -75,6 +75,7 @@ ARRAY(1, 2)
ARRAY_CONTAINS(x, 1)
EXTRACT(x FROM y)
EXTRACT(DATE FROM y)
+EXTRACT(WEEK(monday) FROM created_at)
CONCAT_WS('-', 'a', 'b')
CONCAT_WS('-', 'a', 'b', 'c')
POSEXPLODE("x") AS ("a", "b")
diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql
index 7fcdbb8..8880881 100644
--- a/tests/fixtures/optimizer/canonicalize.sql
+++ b/tests/fixtures/optimizer/canonicalize.sql
@@ -3,3 +3,9 @@ SELECT CONCAT(w.d, w.e) AS c FROM w AS w;
SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w;
+
+SELECT CAST(1 AS VARCHAR) AS a FROM w AS w;
+SELECT CAST(1 AS VARCHAR) AS a FROM w AS w;
+
+SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w;
+SELECT 1 + 3.2 AS a FROM w AS w;
diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql
index d9c7779..cf4195d 100644
--- a/tests/fixtures/optimizer/simplify.sql
+++ b/tests/fixtures/optimizer/simplify.sql
@@ -79,14 +79,16 @@ NULL;
NULL = NULL;
NULL;
+-- Can't optimize this because different engines do different things
+-- mysql converts to 0 and 1 but tsql does true and false
NULL <=> NULL;
-TRUE;
+NULL IS NOT DISTINCT FROM NULL;
a IS NOT DISTINCT FROM a;
-TRUE;
+a IS NOT DISTINCT FROM a;
NULL IS DISTINCT FROM NULL;
-FALSE;
+NULL IS DISTINCT FROM NULL;
NOT (NOT TRUE);
TRUE;
@@ -239,10 +241,10 @@ TRUE;
FALSE;
((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3);
-TRUE;
+x = x;
((NOT FALSE) AND (x = x)) AND (FALSE OR 1 <> 2);
-TRUE;
+x = x;
(('a' = 'a') AND TRUE and NOT FALSE);
TRUE;
@@ -372,3 +374,171 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo;
date '1998-12-01' + interval '90' foo;
CAST('1998-12-01' AS DATE) + INTERVAL '90' foo;
+
+--------------------------------------
+-- Comparisons
+--------------------------------------
+x < 0 OR x > 1;
+x < 0 OR x > 1;
+
+x < 0 OR x > 0;
+x < 0 OR x > 0;
+
+x < 1 OR x > 0;
+x < 1 OR x > 0;
+
+x < 1 OR x >= 0;
+x < 1 OR x >= 0;
+
+x <= 1 OR x > 0;
+x <= 1 OR x > 0;
+
+x <= 1 OR x >= 0;
+x <= 1 OR x >= 0;
+
+x <= 1 AND x <= 0;
+x <= 0;
+
+x <= 1 AND x > 0;
+x <= 1 AND x > 0;
+
+x <= 1 OR x > 0;
+x <= 1 OR x > 0;
+
+x <= 0 OR x < 0;
+x <= 0;
+
+x >= 0 OR x > 0;
+x >= 0;
+
+x >= 0 OR x > 1;
+x >= 0;
+
+x <= 0 OR x >= 0;
+x <= 0 OR x >= 0;
+
+x <= 0 AND x >= 0;
+x <= 0 AND x >= 0;
+
+x < 1 AND x < 2;
+x < 1;
+
+x < 1 OR x < 2;
+x < 2;
+
+x < 2 AND x < 1;
+x < 1;
+
+x < 2 OR x < 1;
+x < 2;
+
+x < 1 AND x < 1;
+x < 1;
+
+x < 1 OR x < 1;
+x < 1;
+
+x <= 1 AND x < 1;
+x < 1;
+
+x <= 1 OR x < 1;
+x <= 1;
+
+x < 1 AND x <= 1;
+x < 1;
+
+x < 1 OR x <= 1;
+x <= 1;
+
+x > 1 AND x > 2;
+x > 2;
+
+x > 1 OR x > 2;
+x > 1;
+
+x > 2 AND x > 1;
+x > 2;
+
+x > 2 OR x > 1;
+x > 1;
+
+x > 1 AND x > 1;
+x > 1;
+
+x > 1 OR x > 1;
+x > 1;
+
+x >= 1 AND x > 1;
+x > 1;
+
+x >= 1 OR x > 1;
+x >= 1;
+
+x > 1 AND x >= 1;
+x > 1;
+
+x > 1 OR x >= 1;
+x >= 1;
+
+x > 1 AND x >= 2;
+x >= 2;
+
+x > 1 OR x >= 2;
+x > 1;
+
+x > 1 AND x >= 2 AND x > 3 AND x > 0;
+x > 3;
+
+(x > 1 AND x >= 2 AND x > 3 AND x > 0) OR x > 0;
+x > 0;
+
+x > 1 AND x < 2 AND x > 3;
+FALSE;
+
+x > 1 AND x < 1;
+FALSE;
+
+x < 2 AND x > 1;
+x < 2 AND x > 1;
+
+x = 1 AND x < 1;
+FALSE;
+
+x = 1 AND x < 1.1;
+x = 1;
+
+x = 1 AND x <= 1;
+x = 1;
+
+x = 1 AND x <= 0.9;
+FALSE;
+
+x = 1 AND x > 0.9;
+x = 1;
+
+x = 1 AND x > 1;
+FALSE;
+
+x = 1 AND x >= 1;
+x = 1;
+
+x = 1 AND x >= 2;
+FALSE;
+
+x = 1 AND x <> 2;
+x = 1;
+
+x <> 1 AND x = 1;
+FALSE;
+
+x BETWEEN 0 AND 5 AND x > 3;
+x <= 5 AND x > 3;
+
+x > 3 AND 5 > x AND x BETWEEN 0 AND 10;
+x < 5 AND x > 3;
+
+x > 3 AND 5 < x AND x BETWEEN 9 AND 10;
+x <= 10 AND x >= 9;
+
+1 < x AND 3 < x;
+x > 3;
diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
index 4893743..9c1f138 100644
--- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql
+++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
@@ -190,7 +190,7 @@ SELECT
SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "revenue",
- CAST("orders"."o_orderdate" AS TEXT) AS "o_orderdate",
+ "orders"."o_orderdate" AS "o_orderdate",
"orders"."o_shippriority" AS "o_shippriority"
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
@@ -326,7 +326,8 @@ SELECT
SUM("lineitem"."l_extendedprice" * "lineitem"."l_discount") AS "revenue"
FROM "lineitem" AS "lineitem"
WHERE
- "lineitem"."l_discount" BETWEEN 0.05 AND 0.07
+ "lineitem"."l_discount" <= 0.07
+ AND "lineitem"."l_discount" >= 0.05
AND "lineitem"."l_quantity" < 24
AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE);
@@ -344,7 +345,7 @@ from
select
n1.n_name as supp_nation,
n2.n_name as cust_nation,
- extract(year from l_shipdate) as l_year,
+ extract(year from cast(l_shipdate as date)) as l_year,
l_extendedprice * (1 - l_discount) as volume
from
supplier,
@@ -384,13 +385,14 @@ WITH "n1" AS (
SELECT
"n1"."n_name" AS "supp_nation",
"n2"."n_name" AS "cust_nation",
- EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year",
+ EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE)) AS "l_year",
SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "revenue"
FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem"
- ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+ ON CAST("lineitem"."l_shipdate" AS DATE) <= CAST('1996-12-31' AS DATE)
+ AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-01-01' AS DATE)
AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
@@ -409,7 +411,7 @@ JOIN "n1" AS "n2"
GROUP BY
"n1"."n_name",
"n2"."n_name",
- EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME))
+ EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE))
ORDER BY
"supp_nation",
"cust_nation",
@@ -427,7 +429,7 @@ select
from
(
select
- extract(year from o_orderdate) as o_year,
+ extract(year from cast(o_orderdate as date)) as o_year,
l_extendedprice * (1 - l_discount) as volume,
n2.n_name as nation
from
@@ -456,7 +458,7 @@ group by
order by
o_year;
SELECT
- EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year",
SUM(
CASE
WHEN "nation_2"."n_name" = 'BRAZIL'
@@ -477,7 +479,8 @@ JOIN "customer" AS "customer"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
JOIN "orders" AS "orders"
ON "orders"."o_custkey" = "customer"."c_custkey"
- AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) <= CAST('1996-12-31' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1995-01-01' AS DATE)
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "part"."p_partkey" = "lineitem"."l_partkey"
@@ -488,7 +491,7 @@ JOIN "nation" AS "nation_2"
WHERE
"part"."p_type" = 'ECONOMY ANODIZED STEEL'
GROUP BY
- EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE))
ORDER BY
"o_year";
@@ -503,7 +506,7 @@ from
(
select
n_name as nation,
- extract(year from o_orderdate) as o_year,
+ extract(year from cast(o_orderdate as date)) as o_year,
l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount
from
part,
@@ -529,7 +532,7 @@ order by
o_year desc;
SELECT
"nation"."n_name" AS "nation",
- EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year",
SUM(
"lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
@@ -551,7 +554,7 @@ WHERE
"part"."p_name" LIKE '%green%'
GROUP BY
"nation"."n_name",
- EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE))
ORDER BY
"nation",
"o_year" DESC;
@@ -1016,7 +1019,7 @@ select
o_orderkey,
o_orderdate,
o_totalprice,
- sum(l_quantity)
+ sum(l_quantity) total_quantity
from
customer,
orders,
@@ -1060,7 +1063,7 @@ SELECT
"orders"."o_orderkey" AS "o_orderkey",
"orders"."o_orderdate" AS "o_orderdate",
"orders"."o_totalprice" AS "o_totalprice",
- SUM("lineitem"."l_quantity") AS "_col_5"
+ SUM("lineitem"."l_quantity") AS "total_quantity"
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
@@ -1129,19 +1132,22 @@ JOIN "part" AS "part"
"part"."p_brand" = 'Brand#12'
AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 5
+ AND "part"."p_size" <= 5
+ AND "part"."p_size" >= 1
)
OR (
"part"."p_brand" = 'Brand#23'
AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 10
+ AND "part"."p_size" <= 10
+ AND "part"."p_size" >= 1
)
OR (
"part"."p_brand" = 'Brand#34'
AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 15
+ AND "part"."p_size" <= 15
+ AND "part"."p_size" >= 1
)
WHERE
(
@@ -1152,7 +1158,8 @@ WHERE
AND "part"."p_brand" = 'Brand#12'
AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 5
+ AND "part"."p_size" <= 5
+ AND "part"."p_size" >= 1
)
OR (
"lineitem"."l_quantity" <= 20
@@ -1162,7 +1169,8 @@ WHERE
AND "part"."p_brand" = 'Brand#23'
AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 10
+ AND "part"."p_size" <= 10
+ AND "part"."p_size" >= 1
)
OR (
"lineitem"."l_quantity" <= 30
@@ -1172,7 +1180,8 @@ WHERE
AND "part"."p_brand" = 'Brand#34'
AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 15
+ AND "part"."p_size" <= 15
+ AND "part"."p_size" >= 1
);
--------------------------------------
diff --git a/tests/test_executor.py b/tests/test_executor.py
index 9d452e4..4fe6399 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -26,12 +26,12 @@ class TestExecutor(unittest.TestCase):
def setUpClass(cls):
cls.conn = duckdb.connect()
- for table in TPCH_SCHEMA:
+ for table, columns in TPCH_SCHEMA.items():
cls.conn.execute(
f"""
CREATE VIEW {table} AS
SELECT *
- FROM READ_CSV_AUTO('{DIR}{table}.csv.gz')
+ FROM READ_CSV('{DIR}{table}.csv.gz', delim='|', header=True, columns={columns})
"""
)
@@ -74,13 +74,13 @@ class TestExecutor(unittest.TestCase):
)
return expression
- for i, (sql, _) in enumerate(self.sqls[0:16]):
+ for i, (sql, _) in enumerate(self.sqls[0:18]):
with self.subTest(f"tpch-h {i + 1}"):
a = self.cached_execute(sql)
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
table = execute(sql, TPCH_SCHEMA)
b = pd.DataFrame(table.rows, columns=table.columns)
- assert_frame_equal(a, b, check_dtype=False)
+ assert_frame_equal(a, b, check_dtype=False, check_index_type=False)
def test_execute_callable(self):
tables = {
@@ -456,11 +456,16 @@ class TestExecutor(unittest.TestCase):
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
- ("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]),
+ (
+ "SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)",
+ ["_col_0", "_col_1"],
+ [(None, 0)],
+ ),
]:
- result = execute(sql)
- self.assertEqual(result.columns, tuple(cols))
- self.assertEqual(result.rows, rows)
+ with self.subTest(sql):
+ result = execute(sql)
+ self.assertEqual(result.columns, tuple(cols))
+ self.assertEqual(result.rows, rows)
def test_aggregate_without_group_by(self):
result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]})
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index ecf581d..0c5f6cd 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -333,7 +333,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
for sql, target_type in tests.items():
expression = annotate_types(parse_one(sql))
- self.assertEqual(expression.find(exp.Literal).type, target_type)
+ self.assertEqual(expression.find(exp.Literal).type.this, target_type)
def test_boolean_type_annotation(self):
tests = {
@@ -343,31 +343,33 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
for sql, target_type in tests.items():
expression = annotate_types(parse_one(sql))
- self.assertEqual(expression.find(exp.Boolean).type, target_type)
+ self.assertEqual(expression.find(exp.Boolean).type.this, target_type)
def test_cast_type_annotation(self):
expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ)
+ self.assertEqual(expression.this.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(expression.args["to"].type.this, exp.DataType.Type.TIMESTAMPTZ)
+ self.assertEqual(expression.args["to"].expressions[0].type.this, exp.DataType.Type.INT)
- self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
- self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ)
- self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
+ expression = annotate_types(parse_one("ARRAY(1)::ARRAY<INT>"))
+ self.assertEqual(expression.type, parse_one("ARRAY<INT>", into=exp.DataType))
def test_cache_annotation(self):
expression = annotate_types(
parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
)
- self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)
+ self.assertEqual(expression.expression.expressions[0].type.this, exp.DataType.Type.INT)
def test_binary_annotation(self):
expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0]
- self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
- self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
- self.assertEqual(expression.right.type, exp.DataType.Type.INT)
- self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
- self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
- self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)
+ self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE)
+ self.assertEqual(expression.left.type.this, exp.DataType.Type.DOUBLE)
+ self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.type.this, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT)
def test_derived_tables_column_annotation(self):
schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}}
@@ -387,128 +389,169 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
"""
expression = annotate_types(parse_one(sql), schema=schema)
- self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola
+ self.assertEqual(
+ expression.expressions[0].type.this, exp.DataType.Type.FLOAT
+ ) # a.cola AS cola
addition_alias = expression.args["from"].expressions[0].this.expressions[0]
- self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola
+ self.assertEqual(
+ addition_alias.type.this, exp.DataType.Type.FLOAT
+ ) # x.cola + y.cola AS cola
addition = addition_alias.this
- self.assertEqual(addition.type, exp.DataType.Type.FLOAT)
- self.assertEqual(addition.this.type, exp.DataType.Type.INT)
- self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT)
+ self.assertEqual(addition.type.this, exp.DataType.Type.FLOAT)
+ self.assertEqual(addition.this.type.this, exp.DataType.Type.INT)
+ self.assertEqual(addition.expression.type.this, exp.DataType.Type.FLOAT)
def test_cte_column_annotation(self):
- schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}}
+ schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT", "colc": "BOOLEAN"}}
sql = """
WITH tbl AS (
- SELECT x.cola + 'bla' AS cola, y.colb AS colb
+ SELECT x.cola + 'bla' AS cola, y.colb AS colb, y.colc AS colc
FROM (
SELECT x.cola AS cola
FROM x AS x
) AS x
JOIN (
- SELECT y.colb AS colb
+ SELECT y.colb AS colb, y.colc AS colc
FROM y AS y
) AS y
)
SELECT tbl.cola + tbl.colb + 'foo' AS col
FROM tbl AS tbl
+ WHERE tbl.colc = True
"""
expression = annotate_types(parse_one(sql), schema=schema)
self.assertEqual(
- expression.expressions[0].type, exp.DataType.Type.TEXT
+ expression.expressions[0].type.this, exp.DataType.Type.TEXT
) # tbl.cola + tbl.colb + 'foo' AS col
outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo'
- self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
- self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT)
- self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(outer_addition.type.this, exp.DataType.Type.TEXT)
+ self.assertEqual(outer_addition.left.type.this, exp.DataType.Type.TEXT)
+ self.assertEqual(outer_addition.right.type.this, exp.DataType.Type.VARCHAR)
inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb
- self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)
+ self.assertEqual(inner_addition.left.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(inner_addition.right.type.this, exp.DataType.Type.TEXT)
+
+ # WHERE tbl.colc = True
+ self.assertEqual(expression.args["where"].this.type.this, exp.DataType.Type.BOOLEAN)
cte_select = expression.args["with"].expressions[0].this
self.assertEqual(
- cte_select.expressions[0].type, exp.DataType.Type.VARCHAR
+ cte_select.expressions[0].type.this, exp.DataType.Type.VARCHAR
) # x.cola + 'bla' AS cola
- self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb
+ self.assertEqual(
+ cte_select.expressions[1].type.this, exp.DataType.Type.TEXT
+ ) # y.colb AS colb
+ self.assertEqual(
+ cte_select.expressions[2].type.this, exp.DataType.Type.BOOLEAN
+ ) # y.colc AS colc
cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla'
- self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR)
- self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(cte_select_addition.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(cte_select_addition.left.type.this, exp.DataType.Type.CHAR)
+ self.assertEqual(cte_select_addition.right.type.this, exp.DataType.Type.VARCHAR)
# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
for d, t in zip(
cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
):
- self.assertEqual(d.this.expressions[0].this.type, t)
+ self.assertEqual(d.this.expressions[0].this.type.this, t)
def test_function_annotation(self):
schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"
concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
- self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.VARCHAR)
concat_expr = concat_expr_alias.this
- self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
- self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
- self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb
+ self.assertEqual(concat_expr.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola
+ self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
+ self.assertEqual(concat_expr.right.this.type.this, exp.DataType.Type.CHAR) # x.colb
sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"
case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
- self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_expr_alias.type.this, exp.DataType.Type.VARCHAR)
case_expr = case_expr_alias.this
- self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR)
+ self.assertEqual(case_expr.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_expr.args["default"].type.this, exp.DataType.Type.CHAR)
case_ifs_expr = case_expr.args["ifs"][0]
- self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_ifs_expr.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_ifs_expr.args["true"].type.this, exp.DataType.Type.VARCHAR)
def test_unknown_annotation(self):
schema = {"x": {"cola": "VARCHAR"}}
sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
- self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.UNKNOWN)
concat_expr = concat_expr_alias.this
- self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
- self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
+ self.assertEqual(concat_expr.type.this, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola
self.assertEqual(
- concat_expr.right.type, exp.DataType.Type.UNKNOWN
+ concat_expr.right.type.this, exp.DataType.Type.UNKNOWN
) # SOME_ANONYMOUS_FUNC(x.cola)
self.assertEqual(
- concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR
+ concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR
) # x.cola (arg)
def test_null_annotation(self):
expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
- self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
- self.assertEqual(expression.right.type, exp.DataType.Type.INT)
+ self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
# NULL <op> UNKNOWN should yield NULL
sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result"
concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
- self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL)
concat_expr = concat_expr_alias.this
- self.assertEqual(concat_expr.type, exp.DataType.Type.NULL)
- self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL)
- self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(concat_expr.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.UNKNOWN)
def test_nullable_annotation(self):
nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
expression = annotate_types(parse_one("NULL AND FALSE"))
self.assertEqual(expression.type, nullable)
- self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
- self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN)
+ self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(expression.right.type.this, exp.DataType.Type.BOOLEAN)
+
+ def test_predicate_annotation(self):
+ expression = annotate_types(parse_one("x BETWEEN a AND b"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)
+
+ expression = annotate_types(parse_one("x IN (a, b, c, d)"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)
+
+ def test_aggfunc_annotation(self):
+ schema = {"x": {"cola": "SMALLINT", "colb": "FLOAT", "colc": "TEXT", "cold": "DATE"}}
+
+ tests = {
+ ("AVG", "cola"): exp.DataType.Type.DOUBLE,
+ ("SUM", "cola"): exp.DataType.Type.BIGINT,
+ ("SUM", "colb"): exp.DataType.Type.DOUBLE,
+ ("MIN", "cola"): exp.DataType.Type.SMALLINT,
+ ("MIN", "colb"): exp.DataType.Type.FLOAT,
+ ("MAX", "colc"): exp.DataType.Type.TEXT,
+ ("MAX", "cold"): exp.DataType.Type.DATE,
+ ("COUNT", "colb"): exp.DataType.Type.BIGINT,
+ ("STDDEV", "cola"): exp.DataType.Type.DOUBLE,
+ }
+
+ for (func, col), target_type in tests.items():
+ expression = annotate_types(
+ parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
+ )
+ self.assertEqual(expression.expressions[0].type.this, target_type)
diff --git a/tests/test_schema.py b/tests/test_schema.py
index cc0e3d1..f1e12a2 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -151,31 +151,33 @@ class TestSchema(unittest.TestCase):
def test_schema_get_column_type(self):
schema = MappingSchema({"a": {"b": "varchar"}})
- self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR)
+ self.assertEqual(schema.get_column_type("a", "b").this, exp.DataType.Type.VARCHAR)
self.assertEqual(
- schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")),
+ schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")).this,
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
- schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR
+ schema.get_column_type("a", exp.Column(this="b")).this, exp.DataType.Type.VARCHAR
)
self.assertEqual(
- schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR
+ schema.get_column_type(exp.Table(this="a"), "b").this, exp.DataType.Type.VARCHAR
)
schema = MappingSchema({"a": {"b": {"c": "varchar"}}})
self.assertEqual(
- schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")),
+ schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")).this,
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
- schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR
+ schema.get_column_type(exp.Table(this="b", db="a"), "c").this, exp.DataType.Type.VARCHAR
)
schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}})
self.assertEqual(
- schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")),
+ schema.get_column_type(
+ exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")
+ ).this,
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
- schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"),
+ schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d").this,
exp.DataType.Type.VARCHAR,
)
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
index 1d1b966..1376849 100644
--- a/tests/test_tokens.py
+++ b/tests/test_tokens.py
@@ -1,6 +1,6 @@
import unittest
-from sqlglot.tokens import Tokenizer
+from sqlglot.tokens import Tokenizer, TokenType
class TestTokens(unittest.TestCase):
@@ -17,3 +17,48 @@ class TestTokens(unittest.TestCase):
for sql, comment in sql_comment:
self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)
+
+ def test_jinja(self):
+ tokenizer = Tokenizer()
+
+ tokens = tokenizer.tokenize(
+ """
+ SELECT
+ {{ x }},
+ {{- x -}},
+ {% for x in y -%}
+ a {{+ b }}
+ {% endfor %};
+ """
+ )
+
+ tokens = [(token.token_type, token.text) for token in tokens]
+
+ self.assertEqual(
+ tokens,
+ [
+ (TokenType.SELECT, "SELECT"),
+ (TokenType.BLOCK_START, "{{"),
+ (TokenType.VAR, "x"),
+ (TokenType.BLOCK_END, "}}"),
+ (TokenType.COMMA, ","),
+ (TokenType.BLOCK_START, "{{-"),
+ (TokenType.VAR, "x"),
+ (TokenType.BLOCK_END, "-}}"),
+ (TokenType.COMMA, ","),
+ (TokenType.BLOCK_START, "{%"),
+ (TokenType.FOR, "for"),
+ (TokenType.VAR, "x"),
+ (TokenType.IN, "in"),
+ (TokenType.VAR, "y"),
+ (TokenType.BLOCK_END, "-%}"),
+ (TokenType.VAR, "a"),
+ (TokenType.BLOCK_START, "{{+"),
+ (TokenType.VAR, "b"),
+ (TokenType.BLOCK_END, "}}"),
+ (TokenType.BLOCK_START, "{%"),
+ (TokenType.VAR, "endfor"),
+ (TokenType.BLOCK_END, "%}"),
+ (TokenType.SEMICOLON, ";"),
+ ],
+ )