diff options
Diffstat (limited to '')
36 files changed, 1281 insertions, 493 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index a439c2c..7dfca94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,43 @@ Changelog ========= +v10.2.0 +------ + +Changes: + +- Breaking: types inferred from annotate_types are now DataType objects, instead of DataType.Type. + +- New: the optimizer can now simplify [BETWEEN expressions expressed as explicit comparisons](https://github.com/tobymao/sqlglot/commit/e24d0317dfa644104ff21d009b790224bf84d698). + +- New: the optimizer now removes redundant casts. + +- New: added support for Redshift's ENCODE/DECODE. + +- New: the optimizer now [treats identifiers as case-insensitive](https://github.com/tobymao/sqlglot/commit/638ed265f195219d7226f4fbae128f1805ae8988). + +- New: the optimizer now [handles nested CTEs](https://github.com/tobymao/sqlglot/commit/1bdd652792889a8aaffb1c6d2c8aa1fe4a066281). + +- New: the executor can now execute SELECT DISTINCT expressions. + +- New: added support for Redshift's COPY and UNLOAD commands. + +- New: added ability to parse LIKE in CREATE TABLE statement. + +- New: the optimizer now [unnests scalar subqueries as cross joins](https://github.com/tobymao/sqlglot/commit/4373ad8518ede4ef1fda8b247b648c680a93d12d). + +- Improvement: fixed Bigquery's ARRAY function parsing, so that it can now handle a SELECT expression as an argument. + +- Improvement: improved Snowflake's [ARRAY and MAP constructs](https://github.com/tobymao/sqlglot/commit/0506657dba55fe71d004c81c907e23cdd2b37d82). + +- Improvement: fixed transpilation between STRING_AGG and GROUP_CONCAT. + +- Improvement: the INTO clause can now be parsed in SELECT expressions. + +- Improvement: improve executor; it currently executes all TPC-H queries up to TPC-H 17 (inclusive). + +- Improvement: DISTINCT ON is now transpiled to a SELECT expression from a subquery for Redshift. + v10.1.0 ------ diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index b027ac7..3733b20 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -30,7 +30,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.1.3" +__version__ = "10.2.6" pretty = False diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 548c322..3c45741 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -317,7 +317,7 @@ class DataFrame: sqlglot.schema.add_table( cache_table_name, { - expression.alias_or_name: expression.type.name + expression.alias_or_name: expression.type.sql("spark") for expression in select_expression.expressions }, ) diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 5b44912..6be68ac 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -110,17 +110,17 @@ class BigQuery(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "BEGIN": TokenType.COMMAND, + "BEGIN TRANSACTION": TokenType.BEGIN, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, "CURRENT_TIME": TokenType.CURRENT_TIME, "GEOGRAPHY": TokenType.GEOGRAPHY, - "INT64": TokenType.BIGINT, "FLOAT64": TokenType.DOUBLE, + "INT64": TokenType.BIGINT, + "NOT DETERMINISTIC": TokenType.VOLATILE, "QUALIFY": TokenType.QUALIFY, "UNKNOWN": TokenType.NULL, "WINDOW": TokenType.WINDOW, - "NOT DETERMINISTIC": TokenType.VOLATILE, - "BEGIN": TokenType.COMMAND, - "BEGIN TRANSACTION": TokenType.BEGIN, } KEYWORDS.pop("DIV") @@ -131,6 +131,7 @@ class BigQuery(Dialect): "DATE_ADD": _date_add(exp.DateAdd), "DATETIME_ADD": _date_add(exp.DatetimeAdd), "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)), + "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, "TIME_ADD": _date_add(exp.TimeAdd), "TIMESTAMP_ADD": _date_add(exp.TimestampAdd), "DATE_SUB": _date_add(exp.DateSub), @@ -144,6 +145,7 @@ class BigQuery(Dialect): FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, + "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]), } FUNCTION_PARSERS.pop("TRIM") @@ -161,7 +163,6 @@ class BigQuery(Dialect): class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, - exp.Array: inline_array_sql, exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateSub: _date_add_sql("DATE", "SUB"), @@ -183,6 +184,7 @@ class BigQuery(Dialect): exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", + exp.RegexpLike: rename_func("REGEXP_CONTAINS"), } TYPE_MAPPING = { @@ -210,24 +212,31 @@ class BigQuery(Dialect): EXPLICIT_UNION = True - def transaction_sql(self, *_): + def array_sql(self, expression: exp.Array) -> str: + first_arg = seq_get(expression.expressions, 0) + if isinstance(first_arg, exp.Subqueryable): + return f"ARRAY{self.wrap(self.sql(first_arg))}" + + return inline_array_sql(self, expression) + + def transaction_sql(self, *_) -> str: return "BEGIN TRANSACTION" - def commit_sql(self, *_): + def commit_sql(self, *_) -> str: return "COMMIT TRANSACTION" - def rollback_sql(self, *_): + def rollback_sql(self, *_) -> str: return "ROLLBACK TRANSACTION" - def in_unnest_op(self, unnest): - return self.sql(unnest) + def in_unnest_op(self, expression: exp.Unnest) -> str: + return self.sql(expression) - def except_op(self, expression): + def except_op(self, expression: exp.Except) -> str: if not expression.args.get("distinct", False): self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery") return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" - def intersect_op(self, expression): + def intersect_op(self, expression: exp.Intersect) -> str: if not expression.args.get("distinct", False): self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery") return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index cbb39c2..70c1c6c 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -190,6 +190,7 @@ class Hive(Dialect): "ADD FILES": TokenType.COMMAND, "ADD JAR": TokenType.COMMAND, "ADD JARS": TokenType.COMMAND, + "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, } class Parser(parser.Parser): @@ -238,6 +239,13 @@ class Hive(Dialect): "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), } + PROPERTY_PARSERS = { + **parser.Parser.PROPERTY_PARSERS, + TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties( + expressions=self._parse_wrapped_csv(self._parse_property) + ), + } + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -297,6 +305,8 @@ class Hive(Dialect): exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}", + exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}", + exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), } @@ -308,12 +318,15 @@ class Hive(Dialect): exp.SchemaCommentProperty, exp.LocationProperty, exp.TableFormatProperty, + exp.RowFormatDelimitedProperty, + exp.RowFormatSerdeProperty, + exp.SerdeProperties, } def with_properties(self, properties): return self.properties( properties, - prefix="TBLPROPERTIES", + prefix=self.seg("TBLPROPERTIES"), ) def datatype_sql(self, expression): diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index ceaf9ba..f507513 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -98,6 +98,7 @@ class Oracle(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "MINUS": TokenType.EXCEPT, "START": TokenType.BEGIN, "TOP": TokenType.TOP, "VARCHAR2": TokenType.VARCHAR, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index cd50979..55ed0a6 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -1,6 +1,7 @@ from __future__ import annotations from sqlglot import exp, transforms +from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.postgres import Postgres from sqlglot.tokens import TokenType @@ -13,12 +14,20 @@ class Redshift(Postgres): "HH": "%H", } + class Parser(Postgres.Parser): + FUNCTIONS = { + **Postgres.Parser.FUNCTIONS, # type: ignore + "DECODE": exp.Matches.from_arg_list, + "NVL": exp.Coalesce.from_arg_list, + } + class Tokenizer(Postgres.Tokenizer): ESCAPES = ["\\"] KEYWORDS = { **Postgres.Tokenizer.KEYWORDS, # type: ignore "COPY": TokenType.COMMAND, + "ENCODE": TokenType.ENCODE, "GEOMETRY": TokenType.GEOMETRY, "GEOGRAPHY": TokenType.GEOGRAPHY, "HLLSKETCH": TokenType.HLLSKETCH, @@ -50,4 +59,5 @@ class Redshift(Postgres): exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), + exp.Matches: rename_func("DECODE"), } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 46155ff..75dc9dc 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -198,6 +198,7 @@ class Snowflake(Dialect): "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TIMESTAMPNTZ": TokenType.TIMESTAMP, + "MINUS": TokenType.EXCEPT, "SAMPLE": TokenType.TABLE_SAMPLE, } diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index e6cfcdd..ad9397e 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -19,10 +19,13 @@ class reverse_key: return other.obj < self.obj -def filter_nulls(func): +def filter_nulls(func, empty_null=True): @wraps(func) def _func(values): - return func(v for v in values if v is not None) + filtered = tuple(v for v in values if v is not None) + if not filtered and empty_null: + return None + return func(filtered) return _func @@ -126,7 +129,7 @@ ENV = { # aggs "SUM": filter_nulls(sum), "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore - "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)), + "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False), "MAX": filter_nulls(max), "MIN": filter_nulls(min), # scalar functions diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 908b80a..9f22c45 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -310,9 +310,9 @@ class PythonExecutor: if i == length - 1: context.set_range(start, end - 1) add_row() - elif step.limit > 0: + elif step.limit > 0 and not group_by: context.set_range(0, 0) - table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations)) + table.append(context.eval_tuple(aggregations)) context = self.context({step.name: table, **{name: table for name in context.tables}}) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 96b32f1..7249574 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -43,14 +43,14 @@ class Expression(metaclass=_Expression): key = "Expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "type", "comments") + __slots__ = ("args", "parent", "arg_key", "comments", "_type") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None - self.type = None self.comments = None + self._type: t.Optional[DataType] = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) @@ -122,6 +122,16 @@ class Expression(metaclass=_Expression): return "NULL" return self.alias or self.name + @property + def type(self) -> t.Optional[DataType]: + return self._type + + @type.setter + def type(self, dtype: t.Optional[DataType | DataType.Type | str]) -> None: + if dtype and not isinstance(dtype, DataType): + dtype = DataType.build(dtype) + self._type = dtype # type: ignore + def __deepcopy__(self, memo): copy = self.__class__(**deepcopy(self.args)) copy.comments = self.comments @@ -348,7 +358,7 @@ class Expression(metaclass=_Expression): indent += "".join([" "] * level) left = f"({self.key.upper()} " - args = { + args: t.Dict[str, t.Any] = { k: ", ".join( v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) for v in ensure_collection(vs) @@ -612,6 +622,7 @@ class Create(Expression): "properties": False, "temporary": False, "transient": False, + "external": False, "replace": False, "unique": False, "materialized": False, @@ -744,13 +755,17 @@ class DefaultColumnConstraint(ColumnConstraintKind): pass +class EncodeColumnConstraint(ColumnConstraintKind): + pass + + class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): # this: True -> ALWAYS, this: False -> BY DEFAULT arg_types = {"this": True, "expression": False} class NotNullColumnConstraint(ColumnConstraintKind): - pass + arg_types = {"allow_null": False} class PrimaryKeyColumnConstraint(ColumnConstraintKind): @@ -766,7 +781,7 @@ class Constraint(Expression): class Delete(Expression): - arg_types = {"with": False, "this": True, "using": False, "where": False} + arg_types = {"with": False, "this": False, "using": False, "where": False} class Drop(Expression): @@ -850,7 +865,7 @@ class Insert(Expression): arg_types = { "with": False, "this": True, - "expression": True, + "expression": False, "overwrite": False, "exists": False, "partition": False, @@ -1125,6 +1140,27 @@ class VolatilityProperty(Property): arg_types = {"this": True} +class RowFormatDelimitedProperty(Property): + # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml + arg_types = { + "fields": False, + "escaped": False, + "collection_items": False, + "map_keys": False, + "lines": False, + "null": False, + "serde": False, + } + + +class RowFormatSerdeProperty(Property): + arg_types = {"this": True} + + +class SerdeProperties(Property): + arg_types = {"expressions": True} + + class Properties(Expression): arg_types = {"expressions": True} @@ -1169,18 +1205,6 @@ class Reference(Expression): arg_types = {"this": True, "expressions": True} -class RowFormat(Expression): - # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml - arg_types = { - "fields": False, - "escaped": False, - "collection_items": False, - "map_keys": False, - "lines": False, - "null": False, - } - - class Tuple(Expression): arg_types = {"expressions": False} @@ -1208,6 +1232,9 @@ class Subqueryable(Unionable): alias=TableAlias(this=to_identifier(alias)), ) + def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + raise NotImplementedError + @property def ctes(self): with_ = self.args.get("with") @@ -1320,6 +1347,32 @@ class Union(Subqueryable): **QUERY_MODIFIERS, } + def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + """ + Set the LIMIT expression. + + Example: + >>> select("1").union(select("1")).limit(1).sql() + 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1' + + Args: + expression (str | int | Expression): the SQL code string to parse. + This can also be an integer. + If a `Limit` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Limit`. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: The limited subqueryable. + """ + return ( + select("*") + .from_(self.subquery(alias="_l_0", copy=copy)) + .limit(expression, dialect=dialect, copy=False, **opts) + ) + @property def named_selects(self): return self.this.unnest().named_selects @@ -1356,7 +1409,7 @@ class Unnest(UDTF): class Update(Expression): arg_types = { "with": False, - "this": True, + "this": False, "expressions": True, "from": False, "where": False, @@ -2057,15 +2110,20 @@ class DataType(Expression): Type.TEXT, } - NUMERIC_TYPES = { + INTEGER_TYPES = { Type.INT, Type.TINYINT, Type.SMALLINT, Type.BIGINT, + } + + FLOAT_TYPES = { Type.FLOAT, Type.DOUBLE, } + NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES} + TEMPORAL_TYPES = { Type.TIMESTAMP, Type.TIMESTAMPTZ, @@ -2968,6 +3026,14 @@ class Use(Expression): pass +class Merge(Expression): + arg_types = {"this": True, "using": True, "on": True, "expressions": True} + + +class When(Func): + arg_types = {"this": True, "then": True} + + def _norm_args(expression): args = {} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 47774fc..beffb91 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -189,12 +189,12 @@ class Generator: self._max_text_width = max_text_width self._comments = comments - def generate(self, expression): + def generate(self, expression: t.Optional[exp.Expression]) -> str: """ Generates a SQL string by interpreting the given syntax tree. Args - expression (Expression): the syntax tree. + expression: the syntax tree. Returns the SQL string. @@ -213,23 +213,23 @@ class Generator: return sql - def unsupported(self, message): + def unsupported(self, message: str) -> None: if self.unsupported_level == ErrorLevel.IMMEDIATE: raise UnsupportedError(message) self.unsupported_messages.append(message) - def sep(self, sep=" "): + def sep(self, sep: str = " ") -> str: return f"{sep.strip()}\n" if self.pretty else sep - def seg(self, sql, sep=" "): + def seg(self, sql: str, sep: str = " ") -> str: return f"{self.sep(sep)}{sql}" - def pad_comment(self, comment): + def pad_comment(self, comment: str) -> str: comment = " " + comment if comment[0].strip() else comment comment = comment + " " if comment[-1].strip() else comment return comment - def maybe_comment(self, sql, expression): + def maybe_comment(self, sql: str, expression: exp.Expression) -> str: comments = expression.comments if self._comments else None if not comments: @@ -243,7 +243,7 @@ class Generator: return f"{sql} {comments}" - def wrap(self, expression): + def wrap(self, expression: exp.Expression | str) -> str: this_sql = self.indent( self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) @@ -253,21 +253,28 @@ class Generator: ) return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}" - def no_identify(self, func): + def no_identify(self, func: t.Callable[[], str]) -> str: original = self.identify self.identify = False result = func() self.identify = original return result - def normalize_func(self, name): + def normalize_func(self, name: str) -> str: if self.normalize_functions == "upper": return name.upper() if self.normalize_functions == "lower": return name.lower() return name - def indent(self, sql, level=0, pad=None, skip_first=False, skip_last=False): + def indent( + self, + sql: str, + level: int = 0, + pad: t.Optional[int] = None, + skip_first: bool = False, + skip_last: bool = False, + ) -> str: if not self.pretty: return sql @@ -281,7 +288,12 @@ class Generator: for i, line in enumerate(lines) ) - def sql(self, expression, key=None, comment=True): + def sql( + self, + expression: t.Optional[str | exp.Expression], + key: t.Optional[str] = None, + comment: bool = True, + ) -> str: if not expression: return "" @@ -313,12 +325,12 @@ class Generator: return self.maybe_comment(sql, expression) if self._comments and comment else sql - def uncache_sql(self, expression): + def uncache_sql(self, expression: exp.Uncache) -> str: table = self.sql(expression, "this") exists_sql = " IF EXISTS" if expression.args.get("exists") else "" return f"UNCACHE TABLE{exists_sql} {table}" - def cache_sql(self, expression): + def cache_sql(self, expression: exp.Cache) -> str: lazy = " LAZY" if expression.args.get("lazy") else "" table = self.sql(expression, "this") options = expression.args.get("options") @@ -328,13 +340,13 @@ class Generator: sql = f"CACHE{lazy} TABLE {table}{options}{sql}" return self.prepend_ctes(expression, sql) - def characterset_sql(self, expression): + def characterset_sql(self, expression: exp.CharacterSet) -> str: if isinstance(expression.parent, exp.Cast): return f"CHAR CHARACTER SET {self.sql(expression, 'this')}" default = "DEFAULT " if expression.args.get("default") else "" return f"{default}CHARACTER SET={self.sql(expression, 'this')}" - def column_sql(self, expression): + def column_sql(self, expression: exp.Column) -> str: return ".".join( part for part in [ @@ -345,7 +357,7 @@ class Generator: if part ) - def columndef_sql(self, expression): + def columndef_sql(self, expression: exp.ColumnDef) -> str: column = self.sql(expression, "this") kind = self.sql(expression, "kind") constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) @@ -354,46 +366,52 @@ class Generator: return f"{column} {kind}" return f"{column} {kind} {constraints}" - def columnconstraint_sql(self, expression): + def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: this = self.sql(expression, "this") kind_sql = self.sql(expression, "kind") return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql - def autoincrementcolumnconstraint_sql(self, _): + def autoincrementcolumnconstraint_sql(self, _) -> str: return self.token_sql(TokenType.AUTO_INCREMENT) - def checkcolumnconstraint_sql(self, expression): + def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str: this = self.sql(expression, "this") return f"CHECK ({this})" - def commentcolumnconstraint_sql(self, expression): + def commentcolumnconstraint_sql(self, expression: exp.CommentColumnConstraint) -> str: comment = self.sql(expression, "this") return f"COMMENT {comment}" - def collatecolumnconstraint_sql(self, expression): + def collatecolumnconstraint_sql(self, expression: exp.CollateColumnConstraint) -> str: collate = self.sql(expression, "this") return f"COLLATE {collate}" - def defaultcolumnconstraint_sql(self, expression): + def encodecolumnconstraint_sql(self, expression: exp.EncodeColumnConstraint) -> str: + encode = self.sql(expression, "this") + return f"ENCODE {encode}" + + def defaultcolumnconstraint_sql(self, expression: exp.DefaultColumnConstraint) -> str: default = self.sql(expression, "this") return f"DEFAULT {default}" - def generatedasidentitycolumnconstraint_sql(self, expression): + def generatedasidentitycolumnconstraint_sql( + self, expression: exp.GeneratedAsIdentityColumnConstraint + ) -> str: return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY" - def notnullcolumnconstraint_sql(self, _): - return "NOT NULL" + def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: + return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" - def primarykeycolumnconstraint_sql(self, expression): + def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str: desc = expression.args.get("desc") if desc is not None: return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" return f"PRIMARY KEY" - def uniquecolumnconstraint_sql(self, _): + def uniquecolumnconstraint_sql(self, _) -> str: return "UNIQUE" - def create_sql(self, expression): + def create_sql(self, expression: exp.Create) -> str: this = self.sql(expression, "this") kind = self.sql(expression, "kind").upper() expression_sql = self.sql(expression, "expression") @@ -402,47 +420,58 @@ class Generator: transient = ( " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" ) + external = " EXTERNAL" if expression.args.get("external") else "" replace = " OR REPLACE" if expression.args.get("replace") else "" exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" unique = " UNIQUE" if expression.args.get("unique") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else "" properties = self.sql(expression, "properties") - expression_sql = f"CREATE{replace}{temporary}{transient}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}" + modifiers = "".join( + ( + replace, + temporary, + transient, + external, + unique, + materialized, + ) + ) + expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}" return self.prepend_ctes(expression, expression_sql) - def describe_sql(self, expression): + def describe_sql(self, expression: exp.Describe) -> str: return f"DESCRIBE {self.sql(expression, 'this')}" - def prepend_ctes(self, expression, sql): + def prepend_ctes(self, expression: exp.Expression, sql: str) -> str: with_ = self.sql(expression, "with") if with_: sql = f"{with_}{self.sep()}{sql}" return sql - def with_sql(self, expression): + def with_sql(self, expression: exp.With) -> str: sql = self.expressions(expression, flat=True) recursive = "RECURSIVE " if expression.args.get("recursive") else "" return f"WITH {recursive}{sql}" - def cte_sql(self, expression): + def cte_sql(self, expression: exp.CTE) -> str: alias = self.sql(expression, "alias") return f"{alias} AS {self.wrap(expression)}" - def tablealias_sql(self, expression): + def tablealias_sql(self, expression: exp.TableAlias) -> str: alias = self.sql(expression, "this") columns = self.expressions(expression, key="columns", flat=True) columns = f"({columns})" if columns else "" return f"{alias}{columns}" - def bitstring_sql(self, expression): + def bitstring_sql(self, expression: exp.BitString) -> str: return self.sql(expression, "this") - def hexstring_sql(self, expression): + def hexstring_sql(self, expression: exp.HexString) -> str: return self.sql(expression, "this") - def datatype_sql(self, expression): + def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) nested = "" @@ -455,13 +484,13 @@ class Generator: ) return f"{type_sql}{nested}" - def directory_sql(self, expression): + def directory_sql(self, expression: exp.Directory) -> str: local = "LOCAL " if expression.args.get("local") else "" row_format = self.sql(expression, "row_format") row_format = f" {row_format}" if row_format else "" return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}" - def delete_sql(self, expression): + def delete_sql(self, expression: exp.Delete) -> str: this = self.sql(expression, "this") using_sql = ( f" USING {self.expressions(expression, 'using', sep=', USING ')}" @@ -472,7 +501,7 @@ class Generator: sql = f"DELETE FROM {this}{using_sql}{where_sql}" return self.prepend_ctes(expression, sql) - def drop_sql(self, expression): + def drop_sql(self, expression: exp.Drop) -> str: this = self.sql(expression, "this") kind = expression.args["kind"] exists_sql = " IF EXISTS " if expression.args.get("exists") else " " @@ -481,46 +510,46 @@ class Generator: cascade = " CASCADE" if expression.args.get("cascade") else "" return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}" - def except_sql(self, expression): + def except_sql(self, expression: exp.Except) -> str: return self.prepend_ctes( expression, self.set_operation(expression, self.except_op(expression)), ) - def except_op(self, expression): + def except_op(self, expression: exp.Except) -> str: return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}" - def fetch_sql(self, expression): + def fetch_sql(self, expression: exp.Fetch) -> str: direction = expression.args.get("direction") direction = f" {direction.upper()}" if direction else "" count = expression.args.get("count") count = f" {count}" if count else "" return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY" - def filter_sql(self, expression): + def filter_sql(self, expression: exp.Filter) -> str: this = self.sql(expression, "this") where = self.sql(expression, "expression")[1:] # where has a leading space return f"{this} FILTER({where})" - def hint_sql(self, expression): + def hint_sql(self, expression: exp.Hint) -> str: if self.sql(expression, "this"): self.unsupported("Hints are not supported") return "" - def index_sql(self, expression): + def index_sql(self, expression: exp.Index) -> str: this = self.sql(expression, "this") table = self.sql(expression, "table") columns = self.sql(expression, "columns") return f"{this} ON {table} {columns}" - def identifier_sql(self, expression): + def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name text = text.lower() if self.normalize else text if expression.args.get("quoted") or self.identify: text = f"{self.identifier_start}{text}{self.identifier_end}" return text - def partition_sql(self, expression): + def partition_sql(self, expression: exp.Partition) -> str: keys = csv( *[ f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name @@ -529,7 +558,7 @@ class Generator: ) return f"PARTITION({keys})" - def properties_sql(self, expression): + def properties_sql(self, expression: exp.Properties) -> str: root_properties = [] with_properties = [] @@ -544,21 +573,21 @@ class Generator: exp.Properties(expressions=root_properties) ) + self.with_properties(exp.Properties(expressions=with_properties)) - def root_properties(self, properties): + def root_properties(self, properties: exp.Properties) -> str: if properties.expressions: return self.sep() + self.expressions(properties, indent=False, sep=" ") return "" - def properties(self, properties, prefix="", sep=", "): + def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str: if properties.expressions: expressions = self.expressions(properties, sep=sep, indent=False) - return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}" + return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}" return "" - def with_properties(self, properties): - return self.properties(properties, prefix="WITH") + def with_properties(self, properties: exp.Properties) -> str: + return self.properties(properties, prefix=self.seg("WITH")) - def property_sql(self, expression): + def property_sql(self, expression: exp.Property) -> str: property_cls = expression.__class__ if property_cls == exp.Property: return f"{expression.name}={self.sql(expression, 'value')}" @@ -569,12 +598,12 @@ class Generator: return f"{property_name}={self.sql(expression, 'this')}" - def likeproperty_sql(self, expression): + def likeproperty_sql(self, expression: exp.LikeProperty) -> str: options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions) options = f" {options}" if options else "" return f"LIKE {self.sql(expression, 'this')}{options}" - def insert_sql(self, expression): + def insert_sql(self, expression: exp.Insert) -> str: overwrite = expression.args.get("overwrite") if isinstance(expression.this, exp.Directory): @@ -592,19 +621,19 @@ class Generator: sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}" return self.prepend_ctes(expression, sql) - def intersect_sql(self, expression): + def intersect_sql(self, expression: exp.Intersect) -> str: return self.prepend_ctes( expression, self.set_operation(expression, self.intersect_op(expression)), ) - def intersect_op(self, expression): + def intersect_op(self, expression: exp.Intersect) -> str: return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}" - def introducer_sql(self, expression): + def introducer_sql(self, expression: exp.Introducer) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" - def rowformat_sql(self, expression): + def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str: fields = expression.args.get("fields") fields = f" FIELDS TERMINATED BY {fields}" if fields else "" escaped = expression.args.get("escaped") @@ -619,7 +648,7 @@ class Generator: null = f" NULL DEFINED AS {null}" if null else "" return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}" - def table_sql(self, expression, sep=" AS "): + def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str: table = ".".join( part for part in [ @@ -642,7 +671,7 @@ class Generator: return f"{table}{alias}{laterals}{joins}{pivots}" - def tablesample_sql(self, expression): + def tablesample_sql(self, expression: exp.TableSample) -> str: if self.alias_post_tablesample and expression.this.alias: this = self.sql(expression.this, "this") alias = f" AS {self.sql(expression.this, 'alias')}" @@ -665,7 +694,7 @@ class Generator: seed = f" SEED ({seed})" if seed else "" return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}" - def pivot_sql(self, expression): + def pivot_sql(self, expression: exp.Pivot) -> str: this = self.sql(expression, "this") unpivot = expression.args.get("unpivot") direction = "UNPIVOT" if unpivot else "PIVOT" @@ -673,10 +702,10 @@ class Generator: field = self.sql(expression, "field") return f"{this} {direction}({expressions} FOR {field})" - def tuple_sql(self, expression): + def tuple_sql(self, expression: exp.Tuple) -> str: return f"({self.expressions(expression, flat=True)})" - def update_sql(self, expression): + def update_sql(self, expression: exp.Update) -> str: this = self.sql(expression, "this") set_sql = self.expressions(expression, flat=True) from_sql = self.sql(expression, "from") @@ -684,7 +713,7 @@ class Generator: sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}" return self.prepend_ctes(expression, sql) - def values_sql(self, expression): + def values_sql(self, expression: exp.Values) -> str: alias = self.sql(expression, "alias") args = self.expressions(expression) if not alias: @@ -694,19 +723,19 @@ class Generator: return f"(VALUES{self.seg('')}{args}){alias}" return f"VALUES{self.seg('')}{args}{alias}" - def var_sql(self, expression): + def var_sql(self, expression: exp.Var) -> str: return self.sql(expression, "this") - def into_sql(self, expression): + def into_sql(self, expression: exp.Into) -> str: temporary = " TEMPORARY" if expression.args.get("temporary") else "" unlogged = " UNLOGGED" if expression.args.get("unlogged") else "" return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}" - def from_sql(self, expression): + def from_sql(self, expression: exp.From) -> str: expressions = self.expressions(expression, flat=True) return f"{self.seg('FROM')} {expressions}" - def group_sql(self, expression): + def group_sql(self, expression: exp.Group) -> str: group_by = self.op_expressions("GROUP BY", expression) grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) grouping_sets = ( @@ -718,11 +747,11 @@ class Generator: rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" return f"{group_by}{grouping_sets}{cube}{rollup}" - def having_sql(self, expression): + def having_sql(self, expression: exp.Having) -> str: this = self.indent(self.sql(expression, "this")) return f"{self.seg('HAVING')}{self.sep()}{this}" - def join_sql(self, expression): + def join_sql(self, expression: exp.Join) -> str: op_sql = self.seg( " ".join( op @@ -753,12 +782,12 @@ class Generator: this_sql = self.sql(expression, "this") return f"{expression_sql}{op_sql} {this_sql}{on_sql}" - def lambda_sql(self, expression, arrow_sep="->"): + def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str: args = self.expressions(expression, flat=True) args = f"({args})" if len(args.split(",")) > 1 else args return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}") - def lateral_sql(self, expression): + def lateral_sql(self, expression: exp.Lateral) -> str: this = self.sql(expression, "this") if isinstance(expression.this, exp.Subquery): @@ -776,15 +805,15 @@ class Generator: return f"LATERAL {this}{table}{columns}" - def limit_sql(self, expression): + def limit_sql(self, expression: exp.Limit) -> str: this = self.sql(expression, "this") return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}" - def offset_sql(self, expression): + def offset_sql(self, expression: exp.Offset) -> str: this = self.sql(expression, "this") return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" - def literal_sql(self, expression): + def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" if expression.is_string: if self._replace_backslash: @@ -793,7 +822,7 @@ class Generator: text = f"{self.quote_start}{text}{self.quote_end}" return text - def loaddata_sql(self, expression): + def loaddata_sql(self, expression: exp.LoadData) -> str: local = " LOCAL" if expression.args.get("local") else "" inpath = f" INPATH {self.sql(expression, 'inpath')}" overwrite = " OVERWRITE" if expression.args.get("overwrite") else "" @@ -806,27 +835,27 @@ class Generator: serde = f" SERDE {serde}" if serde else "" return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}" - def null_sql(self, *_): + def null_sql(self, *_) -> str: return "NULL" - def boolean_sql(self, expression): + def boolean_sql(self, expression: exp.Boolean) -> str: return "TRUE" if expression.this else "FALSE" - def order_sql(self, expression, flat=False): + def order_sql(self, expression: exp.Order, flat: bool = False) -> str: this = self.sql(expression, "this") this = f"{this} " if this else this - return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) + return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore - def cluster_sql(self, expression): + def cluster_sql(self, expression: exp.Cluster) -> str: return self.op_expressions("CLUSTER BY", expression) - def distribute_sql(self, expression): + def distribute_sql(self, expression: exp.Distribute) -> str: return self.op_expressions("DISTRIBUTE BY", expression) - def sort_sql(self, expression): + def sort_sql(self, expression: exp.Sort) -> str: return self.op_expressions("SORT BY", expression) - def ordered_sql(self, expression): + def ordered_sql(self, expression: exp.Ordered) -> str: desc = expression.args.get("desc") asc = not desc @@ -857,7 +886,7 @@ class Generator: return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" - def query_modifiers(self, expression, *sqls): + def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: return csv( *sqls, *[self.sql(sql) for sql in expression.args.get("laterals", [])], @@ -876,7 +905,7 @@ class Generator: sep="", ) - def select_sql(self, expression): + def select_sql(self, expression: exp.Select) -> str: hint = self.sql(expression, "hint") distinct = self.sql(expression, "distinct") distinct = f" {distinct}" if distinct else "" @@ -890,36 +919,36 @@ class Generator: ) return self.prepend_ctes(expression, sql) - def schema_sql(self, expression): + def schema_sql(self, expression: exp.Schema) -> str: this = self.sql(expression, "this") this = f"{this} " if this else "" sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}" return f"{this}{sql}" - def star_sql(self, expression): + def star_sql(self, expression: exp.Star) -> str: except_ = self.expressions(expression, key="except", flat=True) except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else "" replace = self.expressions(expression, key="replace", flat=True) replace = f"{self.seg('REPLACE')} ({replace})" if replace else "" return f"*{except_}{replace}" - def structkwarg_sql(self, expression): + def structkwarg_sql(self, expression: exp.StructKwarg) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" - def parameter_sql(self, expression): + def parameter_sql(self, expression: exp.Parameter) -> str: return f"@{self.sql(expression, 'this')}" - def sessionparameter_sql(self, expression): + def sessionparameter_sql(self, expression: exp.SessionParameter) -> str: this = self.sql(expression, "this") kind = expression.text("kind") if kind: kind = f"{kind}." return f"@@{kind}{this}" - def placeholder_sql(self, expression): + def placeholder_sql(self, expression: exp.Placeholder) -> str: return f":{expression.name}" if expression.name else "?" - def subquery_sql(self, expression): + def subquery_sql(self, expression: exp.Subquery) -> str: alias = self.sql(expression, "alias") sql = self.query_modifiers( @@ -931,22 +960,22 @@ class Generator: return self.prepend_ctes(expression, sql) - def qualify_sql(self, expression): + def qualify_sql(self, expression: exp.Qualify) -> str: this = self.indent(self.sql(expression, "this")) return f"{self.seg('QUALIFY')}{self.sep()}{this}" - def union_sql(self, expression): + def union_sql(self, expression: exp.Union) -> str: return self.prepend_ctes( expression, self.set_operation(expression, self.union_op(expression)), ) - def union_op(self, expression): + def union_op(self, expression: exp.Union) -> str: kind = " DISTINCT" if self.EXPLICIT_UNION else "" kind = kind if expression.args.get("distinct") else " ALL" return f"UNION{kind}" - def unnest_sql(self, expression): + def unnest_sql(self, expression: exp.Unnest) -> str: args = self.expressions(expression, flat=True) alias = expression.args.get("alias") if alias and self.unnest_column_only: @@ -958,11 +987,11 @@ class Generator: ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else "" return f"UNNEST({args}){ordinality}{alias}" - def where_sql(self, expression): + def where_sql(self, expression: exp.Where) -> str: this = self.indent(self.sql(expression, "this")) return f"{self.seg('WHERE')}{self.sep()}{this}" - def window_sql(self, expression): + def window_sql(self, expression: exp.Window) -> str: this = self.sql(expression, "this") partition = self.expressions(expression, key="partition_by", flat=True) @@ -988,7 +1017,7 @@ class Generator: return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})" - def window_spec_sql(self, expression): + def window_spec_sql(self, expression: exp.WindowSpec) -> str: kind = self.sql(expression, "kind") start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") end = ( @@ -997,33 +1026,33 @@ class Generator: ) return f"{kind} BETWEEN {start} AND {end}" - def withingroup_sql(self, expression): + def withingroup_sql(self, expression: exp.WithinGroup) -> str: this = self.sql(expression, "this") - expression = self.sql(expression, "expression")[1:] # order has a leading space - return f"{this} WITHIN GROUP ({expression})" + expression_sql = self.sql(expression, "expression")[1:] # order has a leading space + return f"{this} WITHIN GROUP ({expression_sql})" - def between_sql(self, expression): + def between_sql(self, expression: exp.Between) -> str: this = self.sql(expression, "this") low = self.sql(expression, "low") high = self.sql(expression, "high") return f"{this} BETWEEN {low} AND {high}" - def bracket_sql(self, expression): + def bracket_sql(self, expression: exp.Bracket) -> str: expressions = apply_index_offset(expression.expressions, self.index_offset) - expressions = ", ".join(self.sql(e) for e in expressions) + expressions_sql = ", ".join(self.sql(e) for e in expressions) - return f"{self.sql(expression, 'this')}[{expressions}]" + return f"{self.sql(expression, 'this')}[{expressions_sql}]" - def all_sql(self, expression): + def all_sql(self, expression: exp.All) -> str: return f"ALL {self.wrap(expression)}" - def any_sql(self, expression): + def any_sql(self, expression: exp.Any) -> str: return f"ANY {self.wrap(expression)}" - def exists_sql(self, expression): + def exists_sql(self, expression: exp.Exists) -> str: return f"EXISTS{self.wrap(expression)}" - def case_sql(self, expression): + def case_sql(self, expression: exp.Case) -> str: this = self.sql(expression, "this") statements = [f"CASE {this}" if this else "CASE"] @@ -1043,17 +1072,17 @@ class Generator: return " ".join(statements) - def constraint_sql(self, expression): + def constraint_sql(self, expression: exp.Constraint) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) return f"CONSTRAINT {this} {expressions}" - def extract_sql(self, expression): + def extract_sql(self, expression: exp.Extract) -> str: this = self.sql(expression, "this") expression_sql = self.sql(expression, "expression") return f"EXTRACT({this} FROM {expression_sql})" - def trim_sql(self, expression): + def trim_sql(self, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") @@ -1064,16 +1093,16 @@ class Generator: else: return f"TRIM({target})" - def concat_sql(self, expression): + def concat_sql(self, expression: exp.Concat) -> str: if len(expression.expressions) == 1: return self.sql(expression.expressions[0]) return self.function_fallback_sql(expression) - def check_sql(self, expression): + def check_sql(self, expression: exp.Check) -> str: this = self.sql(expression, key="this") return f"CHECK ({this})" - def foreignkey_sql(self, expression): + def foreignkey_sql(self, expression: exp.ForeignKey) -> str: expressions = self.expressions(expression, flat=True) reference = self.sql(expression, "reference") reference = f" {reference}" if reference else "" @@ -1083,16 +1112,16 @@ class Generator: update = f" ON UPDATE {update}" if update else "" return f"FOREIGN KEY ({expressions}){reference}{delete}{update}" - def unique_sql(self, expression): + def unique_sql(self, expression: exp.Unique) -> str: columns = self.expressions(expression, key="expressions") return f"UNIQUE ({columns})" - def if_sql(self, expression): + def if_sql(self, expression: exp.If) -> str: return self.case_sql( exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) ) - def in_sql(self, expression): + def in_sql(self, expression: exp.In) -> str: query = expression.args.get("query") unnest = expression.args.get("unnest") field = expression.args.get("field") @@ -1106,24 +1135,24 @@ class Generator: in_sql = f"({self.expressions(expression, flat=True)})" return f"{self.sql(expression, 'this')} IN {in_sql}" - def in_unnest_op(self, unnest): + def in_unnest_op(self, unnest: exp.Unnest) -> str: return f"(SELECT {self.sql(unnest)})" - def interval_sql(self, expression): + def interval_sql(self, expression: exp.Interval) -> str: unit = self.sql(expression, "unit") unit = f" {unit}" if unit else "" return f"INTERVAL {self.sql(expression, 'this')}{unit}" - def reference_sql(self, expression): + def reference_sql(self, expression: exp.Reference) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) return f"REFERENCES {this}({expressions})" - def anonymous_sql(self, expression): + def anonymous_sql(self, expression: exp.Anonymous) -> str: args = self.format_args(*expression.expressions) return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" - def paren_sql(self, expression): + def paren_sql(self, expression: exp.Paren) -> str: if isinstance(expression.unnest(), exp.Select): sql = self.wrap(expression) else: @@ -1132,35 +1161,35 @@ class Generator: return self.prepend_ctes(expression, sql) - def neg_sql(self, expression): + def neg_sql(self, expression: exp.Neg) -> str: # This makes sure we don't convert "- - 5" to "--5", which is a comment this_sql = self.sql(expression, "this") sep = " " if this_sql[0] == "-" else "" return f"-{sep}{this_sql}" - def not_sql(self, expression): + def not_sql(self, expression: exp.Not) -> str: return f"NOT {self.sql(expression, 'this')}" - def alias_sql(self, expression): + def alias_sql(self, expression: exp.Alias) -> str: to_sql = self.sql(expression, "alias") to_sql = f" AS {to_sql}" if to_sql else "" return f"{self.sql(expression, 'this')}{to_sql}" - def aliases_sql(self, expression): + def aliases_sql(self, expression: exp.Aliases) -> str: return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})" - def attimezone_sql(self, expression): + def attimezone_sql(self, expression: exp.AtTimeZone) -> str: this = self.sql(expression, "this") zone = self.sql(expression, "zone") return f"{this} AT TIME ZONE {zone}" - def add_sql(self, expression): + def add_sql(self, expression: exp.Add) -> str: return self.binary(expression, "+") - def and_sql(self, expression): + def and_sql(self, expression: exp.And) -> str: return self.connector_sql(expression, "AND") - def connector_sql(self, expression, op): + def connector_sql(self, expression: exp.Connector, op: str) -> str: if not self.pretty: return self.binary(expression, op) @@ -1168,53 +1197,53 @@ class Generator: sep = "\n" if self.text_width(sqls) > self._max_text_width else " " return f"{sep}{op} ".join(sqls) - def bitwiseand_sql(self, expression): + def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str: return self.binary(expression, "&") - def bitwiseleftshift_sql(self, expression): + def bitwiseleftshift_sql(self, expression: exp.BitwiseLeftShift) -> str: return self.binary(expression, "<<") - def bitwisenot_sql(self, expression): + def bitwisenot_sql(self, expression: exp.BitwiseNot) -> str: return f"~{self.sql(expression, 'this')}" - def bitwiseor_sql(self, expression): + def bitwiseor_sql(self, expression: exp.BitwiseOr) -> str: return self.binary(expression, "|") - def bitwiserightshift_sql(self, expression): + def bitwiserightshift_sql(self, expression: exp.BitwiseRightShift) -> str: return self.binary(expression, ">>") - def bitwisexor_sql(self, expression): + def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str: return self.binary(expression, "^") - def cast_sql(self, expression): + def cast_sql(self, expression: exp.Cast) -> str: return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" - def currentdate_sql(self, expression): + def currentdate_sql(self, expression: exp.CurrentDate) -> str: zone = self.sql(expression, "this") return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE" - def collate_sql(self, expression): + def collate_sql(self, expression: exp.Collate) -> str: return self.binary(expression, "COLLATE") - def command_sql(self, expression): + def command_sql(self, expression: exp.Command) -> str: return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" - def transaction_sql(self, *_): + def transaction_sql(self, *_) -> str: return "BEGIN" - def commit_sql(self, expression): + def commit_sql(self, expression: exp.Commit) -> str: chain = expression.args.get("chain") if chain is not None: chain = " AND CHAIN" if chain else " AND NO CHAIN" return f"COMMIT{chain or ''}" - def rollback_sql(self, expression): + def rollback_sql(self, expression: exp.Rollback) -> str: savepoint = expression.args.get("savepoint") savepoint = f" TO {savepoint}" if savepoint else "" return f"ROLLBACK{savepoint}" - def distinct_sql(self, expression): + def distinct_sql(self, expression: exp.Distinct) -> str: this = self.expressions(expression, flat=True) this = f" {this}" if this else "" @@ -1222,13 +1251,13 @@ class Generator: on = f" ON {on}" if on else "" return f"DISTINCT{this}{on}" - def ignorenulls_sql(self, expression): + def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str: return f"{self.sql(expression, 'this')} IGNORE NULLS" - def respectnulls_sql(self, expression): + def respectnulls_sql(self, expression: exp.RespectNulls) -> str: return f"{self.sql(expression, 'this')} RESPECT NULLS" - def intdiv_sql(self, expression): + def intdiv_sql(self, expression: exp.IntDiv) -> str: return self.sql( exp.Cast( this=exp.Div(this=expression.this, expression=expression.expression), @@ -1236,79 +1265,79 @@ class Generator: ) ) - def dpipe_sql(self, expression): + def dpipe_sql(self, expression: exp.DPipe) -> str: return self.binary(expression, "||") - def div_sql(self, expression): + def div_sql(self, expression: exp.Div) -> str: return self.binary(expression, "/") - def distance_sql(self, expression): + def distance_sql(self, expression: exp.Distance) -> str: return self.binary(expression, "<->") - def dot_sql(self, expression): + def dot_sql(self, expression: exp.Dot) -> str: return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}" - def eq_sql(self, expression): + def eq_sql(self, expression: exp.EQ) -> str: return self.binary(expression, "=") - def escape_sql(self, expression): + def escape_sql(self, expression: exp.Escape) -> str: return self.binary(expression, "ESCAPE") - def gt_sql(self, expression): + def gt_sql(self, expression: exp.GT) -> str: return self.binary(expression, ">") - def gte_sql(self, expression): + def gte_sql(self, expression: exp.GTE) -> str: return self.binary(expression, ">=") - def ilike_sql(self, expression): + def ilike_sql(self, expression: exp.ILike) -> str: return self.binary(expression, "ILIKE") - def is_sql(self, expression): + def is_sql(self, expression: exp.Is) -> str: return self.binary(expression, "IS") - def like_sql(self, expression): + def like_sql(self, expression: exp.Like) -> str: return self.binary(expression, "LIKE") - def similarto_sql(self, expression): + def similarto_sql(self, expression: exp.SimilarTo) -> str: return self.binary(expression, "SIMILAR TO") - def lt_sql(self, expression): + def lt_sql(self, expression: exp.LT) -> str: return self.binary(expression, "<") - def lte_sql(self, expression): + def lte_sql(self, expression: exp.LTE) -> str: return self.binary(expression, "<=") - def mod_sql(self, expression): + def mod_sql(self, expression: exp.Mod) -> str: return self.binary(expression, "%") - def mul_sql(self, expression): + def mul_sql(self, expression: exp.Mul) -> str: return self.binary(expression, "*") - def neq_sql(self, expression): + def neq_sql(self, expression: exp.NEQ) -> str: return self.binary(expression, "<>") - def nullsafeeq_sql(self, expression): + def nullsafeeq_sql(self, expression: exp.NullSafeEQ) -> str: return self.binary(expression, "IS NOT DISTINCT FROM") - def nullsafeneq_sql(self, expression): + def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str: return self.binary(expression, "IS DISTINCT FROM") - def or_sql(self, expression): + def or_sql(self, expression: exp.Or) -> str: return self.connector_sql(expression, "OR") - def sub_sql(self, expression): + def sub_sql(self, expression: exp.Sub) -> str: return self.binary(expression, "-") - def trycast_sql(self, expression): + def trycast_sql(self, expression: exp.TryCast) -> str: return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" - def use_sql(self, expression): + def use_sql(self, expression: exp.Use) -> str: return f"USE {self.sql(expression, 'this')}" - def binary(self, expression, op): + def binary(self, expression: exp.Binary, op: str) -> str: return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" - def function_fallback_sql(self, expression): + def function_fallback_sql(self, expression: exp.Func) -> str: args = [] for arg_value in expression.args.values(): if isinstance(arg_value, list): @@ -1319,19 +1348,26 @@ class Generator: return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})" - def format_args(self, *args): - args = tuple(self.sql(arg) for arg in args if arg is not None) - if self.pretty and self.text_width(args) > self._max_text_width: - return self.indent("\n" + f",\n".join(args) + "\n", skip_first=True, skip_last=True) - return ", ".join(args) + def format_args(self, *args: t.Optional[str | exp.Expression]) -> str: + arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None) + if self.pretty and self.text_width(arg_sqls) > self._max_text_width: + return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True) + return ", ".join(arg_sqls) - def text_width(self, args): + def text_width(self, args: t.Iterable) -> int: return sum(len(arg) for arg in args) - def format_time(self, expression): + def format_time(self, expression: exp.Expression) -> t.Optional[str]: return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie) - def expressions(self, expression, key=None, flat=False, indent=True, sep=", "): + def expressions( + self, + expression: exp.Expression, + key: t.Optional[str] = None, + flat: bool = False, + indent: bool = True, + sep: str = ", ", + ) -> str: expressions = expression.args.get(key or "expressions") if not expressions: @@ -1359,45 +1395,67 @@ class Generator: else: result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}") - result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) - return self.indent(result_sqls, skip_first=False) if indent else result_sqls + result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) + return self.indent(result_sql, skip_first=False) if indent else result_sql - def op_expressions(self, op, expression, flat=False): + def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str: expressions_sql = self.expressions(expression, flat=flat) if flat: return f"{op} {expressions_sql}" return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}" - def naked_property(self, expression): + def naked_property(self, expression: exp.Property) -> str: property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__) if not property_name: self.unsupported(f"Unsupported property {expression.__class__.__name__}") return f"{property_name} {self.sql(expression, 'this')}" - def set_operation(self, expression, op): + def set_operation(self, expression: exp.Expression, op: str) -> str: this = self.sql(expression, "this") op = self.seg(op) return self.query_modifiers( expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" ) - def token_sql(self, token_type): + def token_sql(self, token_type: TokenType) -> str: return self.TOKEN_MAPPING.get(token_type, token_type.name) - def userdefinedfunction_sql(self, expression): + def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: this = self.sql(expression, "this") expressions = self.no_identify(lambda: self.expressions(expression)) return f"{this}({expressions})" - def userdefinedfunctionkwarg_sql(self, expression): + def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str: this = self.sql(expression, "this") kind = self.sql(expression, "kind") return f"{this} {kind}" - def joinhint_sql(self, expression): + def joinhint_sql(self, expression: exp.JoinHint) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) return f"{this}({expressions})" - def kwarg_sql(self, expression): + def kwarg_sql(self, expression: exp.Kwarg) -> str: return self.binary(expression, "=>") + + def when_sql(self, expression: exp.When) -> str: + this = self.sql(expression, "this") + then_expression = expression.args.get("then") + if isinstance(then_expression, exp.Insert): + then = f"INSERT {self.sql(then_expression, 'this')}" + if "expression" in then_expression.args: + then += f" VALUES {self.sql(then_expression, 'expression')}" + elif isinstance(then_expression, exp.Update): + if isinstance(then_expression.args.get("expressions"), exp.Star): + then = f"UPDATE {self.sql(then_expression, 'expressions')}" + else: + then = f"UPDATE SET {self.expressions(then_expression, flat=True)}" + else: + then = self.sql(then_expression) + return f"WHEN {this} THEN {then}" + + def merge_sql(self, expression: exp.Merge) -> str: + this = self.sql(expression, "this") + using = f"USING {self.sql(expression, 'using')}" + on = f"ON {self.sql(expression, 'on')}" + return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}" diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 8c5808d..ed37e6c 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -385,3 +385,11 @@ def dict_depth(d: t.Dict) -> int: except StopIteration: # d.values() returns an empty sequence return 1 + + +def first(it: t.Iterable[T]) -> T: + """Returns the first element from an iterable. + + Useful for sets. + """ + return next(i for i in it) diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 191ea52..be17f15 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -14,7 +14,7 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): >>> schema = {"y": {"cola": "SMALLINT"}} >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) - >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola" + >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" <Type.DOUBLE: 'DOUBLE'> Args: @@ -41,9 +41,12 @@ class TypeAnnotator: expr_type: lambda self, expr: self._annotate_binary(expr) for expr_type in subclasses(exp.__name__, exp.Binary) }, - exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this), - exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this), + exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), + exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), + exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr), exp.Alias: lambda self, expr: self._annotate_unary(expr), + exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Literal: lambda self, expr: self._annotate_literal(expr), exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), @@ -52,6 +55,9 @@ class TypeAnnotator: expr, exp.DataType.Type.BIGINT ), exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"), + exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"), + exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True), exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), @@ -263,10 +269,10 @@ class TypeAnnotator: } # First annotate the current scope's column references for col in scope.columns: - source = scope.sources[col.table] + source = scope.sources.get(col.table) if isinstance(source, exp.Table): col.type = self.schema.get_column_type(source, col) - else: + elif source: col.type = selects[col.table][col.name].type # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) @@ -280,6 +286,7 @@ class TypeAnnotator: return expression # We've already inferred the expression's type annotator = self.annotators.get(expression.__class__) + return ( annotator(self, expression) if annotator @@ -295,18 +302,23 @@ class TypeAnnotator: def _maybe_coerce(self, type1, type2): # We propagate the NULL / UNKNOWN types upwards if found + if isinstance(type1, exp.DataType): + type1 = type1.this + if isinstance(type2, exp.DataType): + type2 = type2.this + if exp.DataType.Type.NULL in (type1, type2): return exp.DataType.Type.NULL if exp.DataType.Type.UNKNOWN in (type1, type2): return exp.DataType.Type.UNKNOWN - return type2 if type2 in self.coerces_to[type1] else type1 + return type2 if type2 in self.coerces_to.get(type1, {}) else type1 def _annotate_binary(self, expression): self._annotate_args(expression) - left_type = expression.left.type - right_type = expression.right.type + left_type = expression.left.type.this + right_type = expression.right.type.this if isinstance(expression, (exp.And, exp.Or)): if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: @@ -348,7 +360,7 @@ class TypeAnnotator: expression.type = target_type return self._annotate_args(expression) - def _annotate_by_args(self, expression, *args): + def _annotate_by_args(self, expression, *args, promote=False): self._annotate_args(expression) expressions = [] for arg in args: @@ -360,4 +372,11 @@ class TypeAnnotator: last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) expression.type = last_datatype or exp.DataType.Type.UNKNOWN + + if promote: + if expression.type.this in exp.DataType.INTEGER_TYPES: + expression.type = exp.DataType.Type.BIGINT + elif expression.type.this in exp.DataType.FLOAT_TYPES: + expression.type = exp.DataType.Type.DOUBLE + return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index 9b3d98a..33529a5 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -13,13 +13,16 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: expression: The expression to canonicalize. """ exp.replace_children(expression, canonicalize) + expression = add_text_to_concat(expression) expression = coerce_type(expression) + expression = remove_redundant_casts(expression) + return expression def add_text_to_concat(node: exp.Expression) -> exp.Expression: - if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES: + if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: node = exp.Concat(this=node.this, expression=node.expression) return node @@ -30,14 +33,30 @@ def coerce_type(node: exp.Expression) -> exp.Expression: elif isinstance(node, exp.Between): _coerce_date(node.this, node.args["low"]) elif isinstance(node, exp.Extract): - if node.expression.type not in exp.DataType.TEMPORAL_TYPES: + if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES: _replace_cast(node.expression, "datetime") return node +def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: + if ( + isinstance(expression, exp.Cast) + and expression.to.type + and expression.this.type + and expression.to.type.this == expression.this.type.this + ): + return expression.this + return expression + + def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: for a, b in itertools.permutations([a, b]): - if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE: + if ( + a.type + and a.type.this == exp.DataType.Type.DATE + and b.type + and b.type.this != exp.DataType.Type.DATE + ): _replace_cast(b, "date") diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index c432c59..c0719f2 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -7,7 +7,7 @@ from decimal import Decimal from sqlglot import exp from sqlglot.expressions import FALSE, NULL, TRUE from sqlglot.generator import Generator -from sqlglot.helper import while_changing +from sqlglot.helper import first, while_changing GENERATOR = Generator(normalize=True, identify=True) @@ -30,6 +30,7 @@ def simplify(expression): def _simplify(expression, root=True): node = expression + node = rewrite_between(node) node = uniq_sort(node) node = absorb_and_eliminate(node) exp.replace_children(node, lambda e: _simplify(e, False)) @@ -49,6 +50,19 @@ def simplify(expression): return expression +def rewrite_between(expression: exp.Expression) -> exp.Expression: + """Rewrite x between y and z to x >= y AND x <= z. + + This is done because comparison simplification is only done on lt/lte/gt/gte. + """ + if isinstance(expression, exp.Between): + return exp.and_( + exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), + exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), + ) + return expression + + def simplify_not(expression): """ Demorgan's Law @@ -57,7 +71,7 @@ def simplify_not(expression): """ if isinstance(expression, exp.Not): if isinstance(expression.this, exp.Null): - return NULL + return exp.null() if isinstance(expression.this, exp.Paren): condition = expression.this.unnest() if isinstance(condition, exp.And): @@ -65,11 +79,11 @@ def simplify_not(expression): if isinstance(condition, exp.Or): return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) if isinstance(condition, exp.Null): - return NULL + return exp.null() if always_true(expression.this): - return FALSE + return exp.false() if expression.this == FALSE: - return TRUE + return exp.true() if isinstance(expression.this, exp.Not): # double negation # NOT NOT x -> x @@ -91,40 +105,119 @@ def flatten(expression): def simplify_connectors(expression): - if isinstance(expression, exp.Connector): - left = expression.left - right = expression.right - - if left == right: - return left - - if isinstance(expression, exp.And): - if FALSE in (left, right): - return FALSE - if NULL in (left, right): - return NULL - if always_true(left) and always_true(right): - return TRUE - if always_true(left): - return right - if always_true(right): - return left - elif isinstance(expression, exp.Or): - if always_true(left) or always_true(right): - return TRUE - if left == FALSE and right == FALSE: - return FALSE - if ( - (left == NULL and right == NULL) - or (left == NULL and right == FALSE) - or (left == FALSE and right == NULL) - ): - return NULL - if left == FALSE: - return right - if right == FALSE: + def _simplify_connectors(expression, left, right): + if isinstance(expression, exp.Connector): + if left == right: return left - return expression + if isinstance(expression, exp.And): + if FALSE in (left, right): + return exp.false() + if NULL in (left, right): + return exp.null() + if always_true(left) and always_true(right): + return exp.true() + if always_true(left): + return right + if always_true(right): + return left + return _simplify_comparison(expression, left, right) + elif isinstance(expression, exp.Or): + if always_true(left) or always_true(right): + return exp.true() + if left == FALSE and right == FALSE: + return exp.false() + if ( + (left == NULL and right == NULL) + or (left == NULL and right == FALSE) + or (left == FALSE and right == NULL) + ): + return exp.null() + if left == FALSE: + return right + if right == FALSE: + return left + return _simplify_comparison(expression, left, right, or_=True) + return None + + return _flat_simplify(expression, _simplify_connectors) + + +LT_LTE = (exp.LT, exp.LTE) +GT_GTE = (exp.GT, exp.GTE) + +COMPARISONS = ( + *LT_LTE, + *GT_GTE, + exp.EQ, + exp.NEQ, +) + +INVERSE_COMPARISONS = { + exp.LT: exp.GT, + exp.GT: exp.LT, + exp.LTE: exp.GTE, + exp.GTE: exp.LTE, +} + + +def _simplify_comparison(expression, left, right, or_=False): + if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): + ll, lr = left.args.values() + rl, rr = right.args.values() + + largs = {ll, lr} + rargs = {rl, rr} + + matching = largs & rargs + columns = {m for m in matching if isinstance(m, exp.Column)} + + if matching and columns: + try: + l = first(largs - columns) + r = first(rargs - columns) + except StopIteration: + return expression + + # make sure the comparison is always of the form x > 1 instead of 1 < x + if left.__class__ in INVERSE_COMPARISONS and l == ll: + left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll) + if right.__class__ in INVERSE_COMPARISONS and r == rl: + right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl) + + if l.is_number and r.is_number: + l = float(l.name) + r = float(r.name) + elif l.is_string and r.is_string: + l = l.name + r = r.name + else: + return None + + for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): + if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): + return left if (av > bv if or_ else av <= bv) else right + if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): + return left if (av < bv if or_ else av >= bv) else right + + # we can't ever shortcut to true because the column could be null + if isinstance(a, exp.LT) and isinstance(b, GT_GTE): + if not or_ and av <= bv: + return exp.false() + elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): + if not or_ and av >= bv: + return exp.false() + elif isinstance(a, exp.EQ): + if isinstance(b, exp.LT): + return exp.false() if av >= bv else a + if isinstance(b, exp.LTE): + return exp.false() if av > bv else a + if isinstance(b, exp.GT): + return exp.false() if av <= bv else a + if isinstance(b, exp.GTE): + return exp.false() if av < bv else a + if isinstance(b, exp.NEQ): + return exp.false() if av == bv else a + return None def remove_compliments(expression): @@ -135,7 +228,7 @@ def remove_compliments(expression): A OR NOT A -> TRUE """ if isinstance(expression, exp.Connector): - compliment = FALSE if isinstance(expression, exp.And) else TRUE + compliment = exp.false() if isinstance(expression, exp.And) else exp.true() for a, b in itertools.permutations(expression.flatten(), 2): if is_complement(a, b): @@ -211,27 +304,7 @@ def absorb_and_eliminate(expression): def simplify_literals(expression): if isinstance(expression, exp.Binary): - operands = [] - queue = deque(expression.flatten(unnest=False)) - size = len(queue) - - while queue: - a = queue.popleft() - - for b in queue: - result = _simplify_binary(expression, a, b) - - if result: - queue.remove(b) - queue.append(result) - break - else: - operands.append(a) - - if len(operands) < size: - return functools.reduce( - lambda a, b: expression.__class__(this=a, expression=b), operands - ) + return _flat_simplify(expression, _simplify_binary) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: @@ -254,20 +327,13 @@ def _simplify_binary(expression, a, b): if c == NULL: if isinstance(a, exp.Literal): - return TRUE if not_ else FALSE + return exp.true() if not_ else exp.false() if a == NULL: - return FALSE if not_ else TRUE - elif isinstance(expression, exp.NullSafeEQ): - if a == b: - return TRUE - elif isinstance(expression, exp.NullSafeNEQ): - if a == b: - return FALSE + return exp.false() if not_ else exp.true() + elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)): + return None elif NULL in (a, b): - return NULL - - if isinstance(expression, exp.EQ) and a == b: - return TRUE + return exp.null() if a.is_number and b.is_number: a = int(a.name) if a.is_int else Decimal(a.name) @@ -388,4 +454,27 @@ def date_literal(date): def boolean_literal(condition): - return TRUE if condition else FALSE + return exp.true() if condition else exp.false() + + +def _flat_simplify(expression, simplifier): + operands = [] + queue = deque(expression.flatten(unnest=False)) + size = len(queue) + + while queue: + a = queue.popleft() + + for b in queue: + result = simplifier(expression, a, b) + + if result: + queue.remove(b) + queue.append(result) + break + else: + operands.append(a) + + if len(operands) < size: + return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) + return expression diff --git a/sqlglot/parser.py b/sqlglot/parser.py index bdf0d2d..55ab453 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -185,6 +185,7 @@ class Parser(metaclass=_Parser): TokenType.LOCAL, TokenType.LOCATION, TokenType.MATERIALIZED, + TokenType.MERGE, TokenType.NATURAL, TokenType.NEXT, TokenType.ONLY, @@ -211,7 +212,6 @@ class Parser(metaclass=_Parser): TokenType.TABLE, TokenType.TABLE_FORMAT, TokenType.TEMPORARY, - TokenType.TRANSIENT, TokenType.TOP, TokenType.TRAILING, TokenType.TRUE, @@ -229,6 +229,8 @@ class Parser(metaclass=_Parser): TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY} + UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET} + TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} FUNC_TOKENS = { @@ -241,6 +243,7 @@ class Parser(metaclass=_Parser): TokenType.FORMAT, TokenType.IDENTIFIER, TokenType.ISNULL, + TokenType.MERGE, TokenType.OFFSET, TokenType.PRIMARY_KEY, TokenType.REPLACE, @@ -407,6 +410,7 @@ class Parser(metaclass=_Parser): TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), TokenType.END: lambda self: self._parse_commit_or_rollback(), TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), + TokenType.MERGE: lambda self: self._parse_merge(), } UNARY_PARSERS = { @@ -474,6 +478,7 @@ class Parser(metaclass=_Parser): TokenType.SORTKEY: lambda self: self._parse_sortkey(), TokenType.LIKE: lambda self: self._parse_create_like(), TokenType.RETURNS: lambda self: self._parse_returns(), + TokenType.ROW: lambda self: self._parse_row(), TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty), @@ -495,6 +500,8 @@ class Parser(metaclass=_Parser): TokenType.VOLATILE: lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") ), + TokenType.WITH: lambda self: self._parse_wrapped_csv(self._parse_property), + TokenType.PROPERTIES: lambda self: self._parse_wrapped_csv(self._parse_property), } CONSTRAINT_PARSERS = { @@ -802,7 +809,8 @@ class Parser(metaclass=_Parser): def _parse_create(self): replace = self._match_pair(TokenType.OR, TokenType.REPLACE) temporary = self._match(TokenType.TEMPORARY) - transient = self._match(TokenType.TRANSIENT) + transient = self._match_text_seq("TRANSIENT") + external = self._match_text_seq("EXTERNAL") unique = self._match(TokenType.UNIQUE) materialized = self._match(TokenType.MATERIALIZED) @@ -846,6 +854,7 @@ class Parser(metaclass=_Parser): properties=properties, temporary=temporary, transient=transient, + external=external, replace=replace, unique=unique, materialized=materialized, @@ -861,8 +870,12 @@ class Parser(metaclass=_Parser): if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY): return self._parse_sortkey(compound=True) - if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False): - key = self._parse_var() + assignment = self._match_pair( + TokenType.VAR, TokenType.EQ, advance=False + ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False) + + if assignment: + key = self._parse_var() or self._parse_string() self._match(TokenType.EQ) return self.expression(exp.Property, this=key, value=self._parse_column()) @@ -871,7 +884,10 @@ class Parser(metaclass=_Parser): def _parse_property_assignment(self, exp_class): self._match(TokenType.EQ) self._match(TokenType.ALIAS) - return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number()) + return self.expression( + exp_class, + this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), + ) def _parse_partitioned_by(self): self._match(TokenType.EQ) @@ -881,7 +897,7 @@ class Parser(metaclass=_Parser): ) def _parse_distkey(self): - return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var)) + return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) def _parse_create_like(self): table = self._parse_table(schema=True) @@ -898,7 +914,7 @@ class Parser(metaclass=_Parser): def _parse_sortkey(self, compound=False): return self.expression( - exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound + exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound ) def _parse_character_set(self, default=False): @@ -929,23 +945,11 @@ class Parser(metaclass=_Parser): properties = [] while True: - if self._match(TokenType.WITH): - properties.extend(self._parse_wrapped_csv(self._parse_property)) - elif self._match(TokenType.PROPERTIES): - properties.extend( - self._parse_wrapped_csv( - lambda: self.expression( - exp.Property, - this=self._parse_string(), - value=self._match(TokenType.EQ) and self._parse_string(), - ) - ) - ) - else: - identified_property = self._parse_property() - if not identified_property: - break - properties.append(identified_property) + identified_property = self._parse_property() + if not identified_property: + break + for p in ensure_collection(identified_property): + properties.append(p) if properties: return self.expression(exp.Properties, expressions=properties) @@ -963,7 +967,7 @@ class Parser(metaclass=_Parser): exp.Directory, this=self._parse_var_or_string(), local=local, - row_format=self._parse_row_format(), + row_format=self._parse_row_format(match_row=True), ) else: self._match(TokenType.INTO) @@ -978,10 +982,18 @@ class Parser(metaclass=_Parser): overwrite=overwrite, ) - def _parse_row_format(self): - if not self._match_pair(TokenType.ROW, TokenType.FORMAT): + def _parse_row(self): + if not self._match(TokenType.FORMAT): + return None + return self._parse_row_format() + + def _parse_row_format(self, match_row=False): + if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT): return None + if self._match_text_seq("SERDE"): + return self.expression(exp.RowFormatSerdeProperty, this=self._parse_string()) + self._match_text_seq("DELIMITED") kwargs = {} @@ -998,7 +1010,7 @@ class Parser(metaclass=_Parser): kwargs["lines"] = self._parse_string() if self._match_text_seq("NULL", "DEFINED", "AS"): kwargs["null"] = self._parse_string() - return self.expression(exp.RowFormat, **kwargs) + return self.expression(exp.RowFormatDelimitedProperty, **kwargs) def _parse_load_data(self): local = self._match(TokenType.LOCAL) @@ -1032,7 +1044,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.Update, **{ - "this": self._parse_table(schema=True), + "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "from": self._parse_from(), "where": self._parse_where(), @@ -1183,9 +1195,11 @@ class Parser(metaclass=_Parser): alias=alias, ) - def _parse_table_alias(self): + def _parse_table_alias(self, alias_tokens=None): any_token = self._match(TokenType.ALIAS) - alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS) + alias = self._parse_id_var( + any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS + ) columns = None if self._match(TokenType.L_PAREN): @@ -1337,7 +1351,7 @@ class Parser(metaclass=_Parser): columns=self._parse_expression(), ) - def _parse_table(self, schema=False): + def _parse_table(self, schema=False, alias_tokens=None): lateral = self._parse_lateral() if lateral: @@ -1372,7 +1386,7 @@ class Parser(metaclass=_Parser): table = self._parse_id_var() if not table: - self.raise_error("Expected table name") + self.raise_error(f"Expected table name but got {self._curr}") this = self.expression( exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots() @@ -1384,7 +1398,7 @@ class Parser(metaclass=_Parser): if self.alias_post_tablesample: table_sample = self._parse_table_sample() - alias = self._parse_table_alias() + alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) if alias: this.set("alias", alias) @@ -2092,10 +2106,14 @@ class Parser(metaclass=_Parser): kind = self.expression(exp.CheckColumnConstraint, this=constraint) elif self._match(TokenType.COLLATE): kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var()) + elif self._match(TokenType.ENCODE): + kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var()) elif self._match(TokenType.DEFAULT): kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction()) elif self._match_pair(TokenType.NOT, TokenType.NULL): kind = exp.NotNullColumnConstraint() + elif self._match(TokenType.NULL): + kind = exp.NotNullColumnConstraint(allow_null=True) elif self._match(TokenType.SCHEMA_COMMENT): kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string()) elif self._match(TokenType.PRIMARY_KEY): @@ -2234,7 +2252,7 @@ class Parser(metaclass=_Parser): return self._parse_window(this) def _parse_extract(self): - this = self._parse_var() or self._parse_type() + this = self._parse_function() or self._parse_var() or self._parse_type() if self._match(TokenType.FROM): return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) @@ -2635,6 +2653,54 @@ class Parser(metaclass=_Parser): parser = self._find_parser(self.SET_PARSERS, self._set_trie) return parser(self) if parser else self._default_parse_set_item() + def _parse_merge(self): + self._match(TokenType.INTO) + target = self._parse_table(schema=True) + + self._match(TokenType.USING) + using = self._parse_table() + + self._match(TokenType.ON) + on = self._parse_conjunction() + + whens = [] + while self._match(TokenType.WHEN): + this = self._parse_conjunction() + self._match(TokenType.THEN) + + if self._match(TokenType.INSERT): + _this = self._parse_star() + if _this: + then = self.expression(exp.Insert, this=_this) + else: + then = self.expression( + exp.Insert, + this=self._parse_value(), + expression=self._match(TokenType.VALUES) and self._parse_value(), + ) + elif self._match(TokenType.UPDATE): + expressions = self._parse_star() + if expressions: + then = self.expression(exp.Update, expressions=expressions) + else: + then = self.expression( + exp.Update, + expressions=self._match(TokenType.SET) + and self._parse_csv(self._parse_equality), + ) + elif self._match(TokenType.DELETE): + then = self.expression(exp.Var, this=self._prev.text) + + whens.append(self.expression(exp.When, this=this, then=then)) + + return self.expression( + exp.Merge, + this=target, + using=using, + on=on, + expressions=whens, + ) + def _parse_set(self): return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) diff --git a/sqlglot/schema.py b/sqlglot/schema.py index f6f303b..8a264a2 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -47,7 +47,7 @@ class Schema(abc.ABC): """ @abc.abstractmethod - def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type: + def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType: """ Get the :class:`sqlglot.exp.DataType` type of a column in the schema. @@ -160,8 +160,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): super().__init__(schema) self.visible = visible or {} self.dialect = dialect - self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = { - "STR": exp.DataType.Type.TEXT, + self._type_mapping_cache: t.Dict[str, exp.DataType] = { + "STR": exp.DataType.build("text"), } @classmethod @@ -231,18 +231,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): visible = self._nested_get(self.table_parts(table_), self.visible) return [col for col in schema if col in visible] # type: ignore - def get_column_type( - self, table: exp.Table | str, column: exp.Column | str - ) -> exp.DataType.Type: + def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType: column_name = column if isinstance(column, str) else column.name table_ = exp.to_table(table) if table_: - table_schema = self.find(table_) - schema_type = table_schema.get(column_name).upper() # type: ignore - return self._convert_type(schema_type) + table_schema = self.find(table_, raise_on_missing=False) + if table_schema: + schema_type = table_schema.get(column_name).upper() # type: ignore + return self._convert_type(schema_type) + return exp.DataType(this=exp.DataType.Type.UNKNOWN) raise SchemaError(f"Could not convert table '{table}'") - def _convert_type(self, schema_type: str) -> exp.DataType.Type: + def _convert_type(self, schema_type: str) -> exp.DataType: """ Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object. @@ -257,7 +257,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect) if expression is None: raise ValueError(f"Could not parse {schema_type}") - self._type_mapping_cache[schema_type] = expression.this + self._type_mapping_cache[schema_type] = expression # type: ignore except AttributeError: raise SchemaError(f"Failed to convert type {schema_type}") diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 8a7a38e..b25ef8d 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -49,6 +49,9 @@ class TokenType(AutoName): PARAMETER = auto() SESSION_PARAMETER = auto() + BLOCK_START = auto() + BLOCK_END = auto() + SPACE = auto() BREAK = auto() @@ -156,6 +159,7 @@ class TokenType(AutoName): DIV = auto() DROP = auto() ELSE = auto() + ENCODE = auto() END = auto() ENGINE = auto() ESCAPE = auto() @@ -207,6 +211,7 @@ class TokenType(AutoName): LOCATION = auto() MAP = auto() MATERIALIZED = auto() + MERGE = auto() MOD = auto() NATURAL = auto() NEXT = auto() @@ -255,6 +260,7 @@ class TokenType(AutoName): SELECT = auto() SEMI = auto() SEPARATOR = auto() + SERDE_PROPERTIES = auto() SET = auto() SHOW = auto() SIMILAR_TO = auto() @@ -267,7 +273,6 @@ class TokenType(AutoName): TABLE_FORMAT = auto() TABLE_SAMPLE = auto() TEMPORARY = auto() - TRANSIENT = auto() TOP = auto() THEN = auto() TRAILING = auto() @@ -420,6 +425,16 @@ class Tokenizer(metaclass=_Tokenizer): ESCAPES = ["'"] KEYWORDS = { + **{ + f"{key}{postfix}": TokenType.BLOCK_START + for key in ("{{", "{%", "{#") + for postfix in ("", "+", "-") + }, + **{ + f"{prefix}{key}": TokenType.BLOCK_END + for key in ("}}", "%}", "#}") + for prefix in ("", "+", "-") + }, "/*+": TokenType.HINT, "==": TokenType.EQ, "::": TokenType.DCOLON, @@ -523,6 +538,7 @@ class Tokenizer(metaclass=_Tokenizer): "LOCAL": TokenType.LOCAL, "LOCATION": TokenType.LOCATION, "MATERIALIZED": TokenType.MATERIALIZED, + "MERGE": TokenType.MERGE, "NATURAL": TokenType.NATURAL, "NEXT": TokenType.NEXT, "NO ACTION": TokenType.NO_ACTION, @@ -582,7 +598,6 @@ class Tokenizer(metaclass=_Tokenizer): "TABLESAMPLE": TokenType.TABLE_SAMPLE, "TEMP": TokenType.TEMPORARY, "TEMPORARY": TokenType.TEMPORARY, - "TRANSIENT": TokenType.TRANSIENT, "THEN": TokenType.THEN, "TRUE": TokenType.TRUE, "TRAILING": TokenType.TRAILING, diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py index 32ff8f2..2dcdb39 100644 --- a/tests/dataframe/unit/dataframe_sql_validator.py +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -4,6 +4,7 @@ import unittest from sqlglot.dataframe.sql import types from sqlglot.dataframe.sql.dataframe import DataFrame from sqlglot.dataframe.sql.session import SparkSession +from sqlglot.helper import ensure_list class DataFrameSQLValidator(unittest.TestCase): @@ -33,9 +34,7 @@ class DataFrameSQLValidator(unittest.TestCase): self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False ): actual_sqls = df.sql(pretty=pretty) - expected_statements = ( - [expected_statements] if isinstance(expected_statements, str) else expected_statements - ) + expected_statements = ensure_list(expected_statements) self.assertEqual(len(expected_statements), len(actual_sqls)) for expected, actual in zip(expected_statements, actual_sqls): self.assertEqual(expected, actual) diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py index 7c646f5..042b915 100644 --- a/tests/dataframe/unit/test_dataframe_writer.py +++ b/tests/dataframe/unit/test_dataframe_writer.py @@ -10,37 +10,37 @@ class TestDataFrameWriter(DataFrameSQLValidator): def test_insertInto_full_path(self): df = self.df_employee.write.insertInto("catalog.db.table_name") - expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO catalog.db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_db_table(self): df = self.df_employee.write.insertInto("db.table_name") - expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_table(self): df = self.df_employee.write.insertInto("table_name") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_overwrite(self): df = self.df_employee.write.insertInto("table_name", overwrite=True) - expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT OVERWRITE TABLE table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) @mock.patch("sqlglot.schema", MappingSchema()) def test_insertInto_byName(self): sqlglot.schema.add_table("table_name", {"employee_id": "INT"}) df = self.df_employee.write.byName.insertInto("table_name") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_cache(self): df = self.df_employee.cache().write.insertInto("table_name") expected_statements = [ - "DROP VIEW IF EXISTS t37164", - "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`", + "DROP VIEW IF EXISTS t12441", + "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "INSERT INTO table_name SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`", ] self.compare_sql(df, expected_statements) @@ -50,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator): def test_saveAsTable_append(self): df = self.df_employee.write.saveAsTable("table_name", mode="append") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_overwrite(self): df = self.df_employee.write.saveAsTable("table_name", mode="overwrite") - expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_error(self): df = self.df_employee.write.saveAsTable("table_name", mode="error") - expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_ignore(self): df = self.df_employee.write.saveAsTable("table_name", mode="ignore") - expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_mode_standalone(self): df = self.df_employee.write.mode("ignore").saveAsTable("table_name") - expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_mode_override(self): df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite") - expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_cache(self): df = self.df_employee.cache().write.saveAsTable("table_name") expected_statements = [ - "DROP VIEW IF EXISTS t37164", - "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`", + "DROP VIEW IF EXISTS t12441", + "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "CREATE TABLE table_name AS SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`", ] self.compare_sql(df, expected_statements) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index 55aa547..5213667 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -36,7 +36,7 @@ class TestDataframeSession(DataFrameSQLValidator): def test_cdf_str_schema(self): df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING") - expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" + expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_typed_schema_basic(self): @@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator): ] ) df = self.spark.createDataFrame([[1, "test"]], schema) - expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" + expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_typed_schema_nested(self): diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index cc44311..1d60ec6 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -7,6 +7,11 @@ class TestBigQuery(Validator): def test_bigquery(self): self.validate_all( + "REGEXP_CONTAINS('foo', '.*')", + read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"}, + write={"mysql": "REGEXP_LIKE('foo', '.*')"}, + ), + self.validate_all( '"""x"""', write={ "bigquery": "'x'", @@ -94,6 +99,20 @@ class TestBigQuery(Validator): "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)", }, ) + self.validate_all( + "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)", + write={"bigquery": "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)"}, + ) + self.validate_all( + "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers", + write={ + "bigquery": "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers" + }, + ) + self.validate_all( + "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)", + write={"bigquery": "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)"}, + ) self.validate_all( "x IS unknown", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 6033570..ee67bf1 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1318,3 +1318,39 @@ SELECT "BEGIN IMMEDIATE TRANSACTION", write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"}, ) + + def test_merge(self): + self.validate_all( + """ + MERGE INTO target USING source ON target.id = source.id + WHEN NOT MATCHED THEN INSERT (id) values (source.id) + """, + write={ + "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)", + "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)", + "spark": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)", + }, + ) + self.validate_all( + """ + MERGE INTO target USING source ON target.id = source.id + WHEN MATCHED AND source.is_deleted = 1 THEN DELETE + WHEN MATCHED THEN UPDATE SET val = source.val + WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val) + """, + write={ + "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)", + "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)", + "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)", + }, + ) + self.validate_all( + """ + MERGE INTO target USING source ON target.id = source.id + WHEN MATCHED THEN UPDATE * + WHEN NOT MATCHED THEN INSERT * + """, + write={ + "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED THEN UPDATE * WHEN NOT MATCHED THEN INSERT *", + }, + ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 22d7bce..5ac8714 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -145,6 +145,10 @@ class TestHive(Validator): }, ) + self.validate_identity( + """CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')""", + ) + def test_lateral_view(self): self.validate_all( "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index cd6117c..962b28b 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -256,3 +256,7 @@ class TestPostgres(Validator): "SELECT $$Dianne's horse$$", write={"postgres": "SELECT 'Dianne''s horse'"}, ) + self.validate_all( + "UPDATE MYTABLE T1 SET T1.COL = 13", + write={"postgres": "UPDATE MYTABLE AS T1 SET T1.COL = 13"}, + ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 1943ee3..3034df5 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -56,8 +56,27 @@ class TestRedshift(Validator): "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', }, ) + self.validate_all( + "DECODE(x, a, b, c, d)", + write={ + "": "MATCHES(x, a, b, c, d)", + "oracle": "DECODE(x, a, b, c, d)", + "snowflake": "DECODE(x, a, b, c, d)", + }, + ) + self.validate_all( + "NVL(a, b, c, d)", + write={ + "redshift": "COALESCE(a, b, c, d)", + "mysql": "COALESCE(a, b, c, d)", + "postgres": "COALESCE(a, b, c, d)", + }, + ) def test_identity(self): + self.validate_identity( + "SELECT DECODE(COL1, 'replace_this', 'with_this', 'replace_that', 'with_that')" + ) self.validate_identity("CAST('bla' AS SUPER)") self.validate_identity("CREATE TABLE real1 (realcol REAL)") self.validate_identity("CAST('foo' AS HLLSKETCH)") @@ -70,9 +89,9 @@ class TestRedshift(Validator): self.validate_identity( "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'" ) - self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE AUTO") + self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL") self.validate_identity( - "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid)" + "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO" ) self.validate_identity( "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'" @@ -80,3 +99,6 @@ class TestRedshift(Validator): self.validate_identity( "UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'" ) + self.validate_identity( + "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)" + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index baca269..bca5aaa 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -500,3 +500,12 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL F }, pretty=True, ) + + def test_minus(self): + self.validate_all( + "SELECT 1 EXCEPT SELECT 1", + read={ + "oracle": "SELECT 1 MINUS SELECT 1", + "snowflake": "SELECT 1 MINUS SELECT 1", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 06ab96d..e12b673 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -75,6 +75,7 @@ ARRAY(1, 2) ARRAY_CONTAINS(x, 1) EXTRACT(x FROM y) EXTRACT(DATE FROM y) +EXTRACT(WEEK(monday) FROM created_at) CONCAT_WS('-', 'a', 'b') CONCAT_WS('-', 'a', 'b', 'c') POSEXPLODE("x") AS ("a", "b") diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index 7fcdbb8..8880881 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -3,3 +3,9 @@ SELECT CONCAT(w.d, w.e) AS c FROM w AS w; SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w; SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w; + +SELECT CAST(1 AS VARCHAR) AS a FROM w AS w; +SELECT CAST(1 AS VARCHAR) AS a FROM w AS w; + +SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w; +SELECT 1 + 3.2 AS a FROM w AS w; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index d9c7779..cf4195d 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -79,14 +79,16 @@ NULL; NULL = NULL; NULL; +-- Can't optimize this because different engines do different things +-- mysql converts to 0 and 1 but tsql does true and false NULL <=> NULL; -TRUE; +NULL IS NOT DISTINCT FROM NULL; a IS NOT DISTINCT FROM a; -TRUE; +a IS NOT DISTINCT FROM a; NULL IS DISTINCT FROM NULL; -FALSE; +NULL IS DISTINCT FROM NULL; NOT (NOT TRUE); TRUE; @@ -239,10 +241,10 @@ TRUE; FALSE; ((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3); -TRUE; +x = x; ((NOT FALSE) AND (x = x)) AND (FALSE OR 1 <> 2); -TRUE; +x = x; (('a' = 'a') AND TRUE and NOT FALSE); TRUE; @@ -372,3 +374,171 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo; date '1998-12-01' + interval '90' foo; CAST('1998-12-01' AS DATE) + INTERVAL '90' foo; + +-------------------------------------- +-- Comparisons +-------------------------------------- +x < 0 OR x > 1; +x < 0 OR x > 1; + +x < 0 OR x > 0; +x < 0 OR x > 0; + +x < 1 OR x > 0; +x < 1 OR x > 0; + +x < 1 OR x >= 0; +x < 1 OR x >= 0; + +x <= 1 OR x > 0; +x <= 1 OR x > 0; + +x <= 1 OR x >= 0; +x <= 1 OR x >= 0; + +x <= 1 AND x <= 0; +x <= 0; + +x <= 1 AND x > 0; +x <= 1 AND x > 0; + +x <= 1 OR x > 0; +x <= 1 OR x > 0; + +x <= 0 OR x < 0; +x <= 0; + +x >= 0 OR x > 0; +x >= 0; + +x >= 0 OR x > 1; +x >= 0; + +x <= 0 OR x >= 0; +x <= 0 OR x >= 0; + +x <= 0 AND x >= 0; +x <= 0 AND x >= 0; + +x < 1 AND x < 2; +x < 1; + +x < 1 OR x < 2; +x < 2; + +x < 2 AND x < 1; +x < 1; + +x < 2 OR x < 1; +x < 2; + +x < 1 AND x < 1; +x < 1; + +x < 1 OR x < 1; +x < 1; + +x <= 1 AND x < 1; +x < 1; + +x <= 1 OR x < 1; +x <= 1; + +x < 1 AND x <= 1; +x < 1; + +x < 1 OR x <= 1; +x <= 1; + +x > 1 AND x > 2; +x > 2; + +x > 1 OR x > 2; +x > 1; + +x > 2 AND x > 1; +x > 2; + +x > 2 OR x > 1; +x > 1; + +x > 1 AND x > 1; +x > 1; + +x > 1 OR x > 1; +x > 1; + +x >= 1 AND x > 1; +x > 1; + +x >= 1 OR x > 1; +x >= 1; + +x > 1 AND x >= 1; +x > 1; + +x > 1 OR x >= 1; +x >= 1; + +x > 1 AND x >= 2; +x >= 2; + +x > 1 OR x >= 2; +x > 1; + +x > 1 AND x >= 2 AND x > 3 AND x > 0; +x > 3; + +(x > 1 AND x >= 2 AND x > 3 AND x > 0) OR x > 0; +x > 0; + +x > 1 AND x < 2 AND x > 3; +FALSE; + +x > 1 AND x < 1; +FALSE; + +x < 2 AND x > 1; +x < 2 AND x > 1; + +x = 1 AND x < 1; +FALSE; + +x = 1 AND x < 1.1; +x = 1; + +x = 1 AND x <= 1; +x = 1; + +x = 1 AND x <= 0.9; +FALSE; + +x = 1 AND x > 0.9; +x = 1; + +x = 1 AND x > 1; +FALSE; + +x = 1 AND x >= 1; +x = 1; + +x = 1 AND x >= 2; +FALSE; + +x = 1 AND x <> 2; +x = 1; + +x <> 1 AND x = 1; +FALSE; + +x BETWEEN 0 AND 5 AND x > 3; +x <= 5 AND x > 3; + +x > 3 AND 5 > x AND x BETWEEN 0 AND 10; +x < 5 AND x > 3; + +x > 3 AND 5 < x AND x BETWEEN 9 AND 10; +x <= 10 AND x >= 9; + +1 < x AND 3 < x; +x > 3; diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index 4893743..9c1f138 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -190,7 +190,7 @@ SELECT SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" )) AS "revenue", - CAST("orders"."o_orderdate" AS TEXT) AS "o_orderdate", + "orders"."o_orderdate" AS "o_orderdate", "orders"."o_shippriority" AS "o_shippriority" FROM "customer" AS "customer" JOIN "orders" AS "orders" @@ -326,7 +326,8 @@ SELECT SUM("lineitem"."l_extendedprice" * "lineitem"."l_discount") AS "revenue" FROM "lineitem" AS "lineitem" WHERE - "lineitem"."l_discount" BETWEEN 0.05 AND 0.07 + "lineitem"."l_discount" <= 0.07 + AND "lineitem"."l_discount" >= 0.05 AND "lineitem"."l_quantity" < 24 AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE) AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE); @@ -344,7 +345,7 @@ from select n1.n_name as supp_nation, n2.n_name as cust_nation, - extract(year from l_shipdate) as l_year, + extract(year from cast(l_shipdate as date)) as l_year, l_extendedprice * (1 - l_discount) as volume from supplier, @@ -384,13 +385,14 @@ WITH "n1" AS ( SELECT "n1"."n_name" AS "supp_nation", "n2"."n_name" AS "cust_nation", - EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year", + EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE)) AS "l_year", SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" )) AS "revenue" FROM "supplier" AS "supplier" JOIN "lineitem" AS "lineitem" - ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + ON CAST("lineitem"."l_shipdate" AS DATE) <= CAST('1996-12-31' AS DATE) + AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-01-01' AS DATE) AND "supplier"."s_suppkey" = "lineitem"."l_suppkey" JOIN "orders" AS "orders" ON "orders"."o_orderkey" = "lineitem"."l_orderkey" @@ -409,7 +411,7 @@ JOIN "n1" AS "n2" GROUP BY "n1"."n_name", "n2"."n_name", - EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) + EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE)) ORDER BY "supp_nation", "cust_nation", @@ -427,7 +429,7 @@ select from ( select - extract(year from o_orderdate) as o_year, + extract(year from cast(o_orderdate as date)) as o_year, l_extendedprice * (1 - l_discount) as volume, n2.n_name as nation from @@ -456,7 +458,7 @@ group by order by o_year; SELECT - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year", + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year", SUM( CASE WHEN "nation_2"."n_name" = 'BRAZIL' @@ -477,7 +479,8 @@ JOIN "customer" AS "customer" ON "customer"."c_nationkey" = "nation"."n_nationkey" JOIN "orders" AS "orders" ON "orders"."o_custkey" = "customer"."c_custkey" - AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) <= CAST('1996-12-31' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1995-01-01' AS DATE) JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "part"."p_partkey" = "lineitem"."l_partkey" @@ -488,7 +491,7 @@ JOIN "nation" AS "nation_2" WHERE "part"."p_type" = 'ECONOMY ANODIZED STEEL' GROUP BY - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) ORDER BY "o_year"; @@ -503,7 +506,7 @@ from ( select n_name as nation, - extract(year from o_orderdate) as o_year, + extract(year from cast(o_orderdate as date)) as o_year, l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount from part, @@ -529,7 +532,7 @@ order by o_year desc; SELECT "nation"."n_name" AS "nation", - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year", + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year", SUM( "lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" @@ -551,7 +554,7 @@ WHERE "part"."p_name" LIKE '%green%' GROUP BY "nation"."n_name", - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) ORDER BY "nation", "o_year" DESC; @@ -1016,7 +1019,7 @@ select o_orderkey, o_orderdate, o_totalprice, - sum(l_quantity) + sum(l_quantity) total_quantity from customer, orders, @@ -1060,7 +1063,7 @@ SELECT "orders"."o_orderkey" AS "o_orderkey", "orders"."o_orderdate" AS "o_orderdate", "orders"."o_totalprice" AS "o_totalprice", - SUM("lineitem"."l_quantity") AS "_col_5" + SUM("lineitem"."l_quantity") AS "total_quantity" FROM "customer" AS "customer" JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" @@ -1129,19 +1132,22 @@ JOIN "part" AS "part" "part"."p_brand" = 'Brand#12' AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 5 + AND "part"."p_size" <= 5 + AND "part"."p_size" >= 1 ) OR ( "part"."p_brand" = 'Brand#23' AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 10 + AND "part"."p_size" <= 10 + AND "part"."p_size" >= 1 ) OR ( "part"."p_brand" = 'Brand#34' AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 15 + AND "part"."p_size" <= 15 + AND "part"."p_size" >= 1 ) WHERE ( @@ -1152,7 +1158,8 @@ WHERE AND "part"."p_brand" = 'Brand#12' AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 5 + AND "part"."p_size" <= 5 + AND "part"."p_size" >= 1 ) OR ( "lineitem"."l_quantity" <= 20 @@ -1162,7 +1169,8 @@ WHERE AND "part"."p_brand" = 'Brand#23' AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 10 + AND "part"."p_size" <= 10 + AND "part"."p_size" >= 1 ) OR ( "lineitem"."l_quantity" <= 30 @@ -1172,7 +1180,8 @@ WHERE AND "part"."p_brand" = 'Brand#34' AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 15 + AND "part"."p_size" <= 15 + AND "part"."p_size" >= 1 ); -------------------------------------- diff --git a/tests/test_executor.py b/tests/test_executor.py index 9d452e4..4fe6399 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -26,12 +26,12 @@ class TestExecutor(unittest.TestCase): def setUpClass(cls): cls.conn = duckdb.connect() - for table in TPCH_SCHEMA: + for table, columns in TPCH_SCHEMA.items(): cls.conn.execute( f""" CREATE VIEW {table} AS SELECT * - FROM READ_CSV_AUTO('{DIR}{table}.csv.gz') + FROM READ_CSV('{DIR}{table}.csv.gz', delim='|', header=True, columns={columns}) """ ) @@ -74,13 +74,13 @@ class TestExecutor(unittest.TestCase): ) return expression - for i, (sql, _) in enumerate(self.sqls[0:16]): + for i, (sql, _) in enumerate(self.sqls[0:18]): with self.subTest(f"tpch-h {i + 1}"): a = self.cached_execute(sql) sql = parse_one(sql).transform(to_csv).sql(pretty=True) table = execute(sql, TPCH_SCHEMA) b = pd.DataFrame(table.rows, columns=table.columns) - assert_frame_equal(a, b, check_dtype=False) + assert_frame_equal(a, b, check_dtype=False, check_index_type=False) def test_execute_callable(self): tables = { @@ -456,11 +456,16 @@ class TestExecutor(unittest.TestCase): ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]), ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]), ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]), - ("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]), + ( + "SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)", + ["_col_0", "_col_1"], + [(None, 0)], + ), ]: - result = execute(sql) - self.assertEqual(result.columns, tuple(cols)) - self.assertEqual(result.rows, rows) + with self.subTest(sql): + result = execute(sql) + self.assertEqual(result.columns, tuple(cols)) + self.assertEqual(result.rows, rows) def test_aggregate_without_group_by(self): result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]}) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index ecf581d..0c5f6cd 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -333,7 +333,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') for sql, target_type in tests.items(): expression = annotate_types(parse_one(sql)) - self.assertEqual(expression.find(exp.Literal).type, target_type) + self.assertEqual(expression.find(exp.Literal).type.this, target_type) def test_boolean_type_annotation(self): tests = { @@ -343,31 +343,33 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') for sql, target_type in tests.items(): expression = annotate_types(parse_one(sql)) - self.assertEqual(expression.find(exp.Boolean).type, target_type) + self.assertEqual(expression.find(exp.Boolean).type.this, target_type) def test_cast_type_annotation(self): expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")) + self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ) + self.assertEqual(expression.this.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(expression.args["to"].type.this, exp.DataType.Type.TIMESTAMPTZ) + self.assertEqual(expression.args["to"].expressions[0].type.this, exp.DataType.Type.INT) - self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ) - self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR) - self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ) - self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT) + expression = annotate_types(parse_one("ARRAY(1)::ARRAY<INT>")) + self.assertEqual(expression.type, parse_one("ARRAY<INT>", into=exp.DataType)) def test_cache_annotation(self): expression = annotate_types( parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") ) - self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT) + self.assertEqual(expression.expression.expressions[0].type.this, exp.DataType.Type.INT) def test_binary_annotation(self): expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0] - self.assertEqual(expression.type, exp.DataType.Type.DOUBLE) - self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE) - self.assertEqual(expression.right.type, exp.DataType.Type.INT) - self.assertEqual(expression.right.this.type, exp.DataType.Type.INT) - self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT) - self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT) + self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE) + self.assertEqual(expression.left.type.this, exp.DataType.Type.DOUBLE) + self.assertEqual(expression.right.type.this, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.type.this, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT) def test_derived_tables_column_annotation(self): schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}} @@ -387,128 +389,169 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') """ expression = annotate_types(parse_one(sql), schema=schema) - self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola + self.assertEqual( + expression.expressions[0].type.this, exp.DataType.Type.FLOAT + ) # a.cola AS cola addition_alias = expression.args["from"].expressions[0].this.expressions[0] - self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola + self.assertEqual( + addition_alias.type.this, exp.DataType.Type.FLOAT + ) # x.cola + y.cola AS cola addition = addition_alias.this - self.assertEqual(addition.type, exp.DataType.Type.FLOAT) - self.assertEqual(addition.this.type, exp.DataType.Type.INT) - self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT) + self.assertEqual(addition.type.this, exp.DataType.Type.FLOAT) + self.assertEqual(addition.this.type.this, exp.DataType.Type.INT) + self.assertEqual(addition.expression.type.this, exp.DataType.Type.FLOAT) def test_cte_column_annotation(self): - schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}} + schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT", "colc": "BOOLEAN"}} sql = """ WITH tbl AS ( - SELECT x.cola + 'bla' AS cola, y.colb AS colb + SELECT x.cola + 'bla' AS cola, y.colb AS colb, y.colc AS colc FROM ( SELECT x.cola AS cola FROM x AS x ) AS x JOIN ( - SELECT y.colb AS colb + SELECT y.colb AS colb, y.colc AS colc FROM y AS y ) AS y ) SELECT tbl.cola + tbl.colb + 'foo' AS col FROM tbl AS tbl + WHERE tbl.colc = True """ expression = annotate_types(parse_one(sql), schema=schema) self.assertEqual( - expression.expressions[0].type, exp.DataType.Type.TEXT + expression.expressions[0].type.this, exp.DataType.Type.TEXT ) # tbl.cola + tbl.colb + 'foo' AS col outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo' - self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT) - self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT) - self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR) + self.assertEqual(outer_addition.type.this, exp.DataType.Type.TEXT) + self.assertEqual(outer_addition.left.type.this, exp.DataType.Type.TEXT) + self.assertEqual(outer_addition.right.type.this, exp.DataType.Type.VARCHAR) inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb - self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR) - self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT) + self.assertEqual(inner_addition.left.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(inner_addition.right.type.this, exp.DataType.Type.TEXT) + + # WHERE tbl.colc = True + self.assertEqual(expression.args["where"].this.type.this, exp.DataType.Type.BOOLEAN) cte_select = expression.args["with"].expressions[0].this self.assertEqual( - cte_select.expressions[0].type, exp.DataType.Type.VARCHAR + cte_select.expressions[0].type.this, exp.DataType.Type.VARCHAR ) # x.cola + 'bla' AS cola - self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb + self.assertEqual( + cte_select.expressions[1].type.this, exp.DataType.Type.TEXT + ) # y.colb AS colb + self.assertEqual( + cte_select.expressions[2].type.this, exp.DataType.Type.BOOLEAN + ) # y.colc AS colc cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla' - self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR) - self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR) - self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR) + self.assertEqual(cte_select_addition.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(cte_select_addition.left.type.this, exp.DataType.Type.CHAR) + self.assertEqual(cte_select_addition.right.type.this, exp.DataType.Type.VARCHAR) # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively for d, t in zip( cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT] ): - self.assertEqual(d.this.expressions[0].this.type, t) + self.assertEqual(d.this.expressions[0].this.type.this, t) def test_function_annotation(self): schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}} sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x" concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] - self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR) + self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.VARCHAR) concat_expr = concat_expr_alias.this - self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR) - self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola - self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb) - self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb + self.assertEqual(concat_expr.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola + self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.VARCHAR) # TRIM(x.colb) + self.assertEqual(concat_expr.right.this.type.this, exp.DataType.Type.CHAR) # x.colb sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x" case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] - self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_expr_alias.type.this, exp.DataType.Type.VARCHAR) case_expr = case_expr_alias.this - self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR) - self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR) + self.assertEqual(case_expr.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(case_expr.args["default"].type.this, exp.DataType.Type.CHAR) case_ifs_expr = case_expr.args["ifs"][0] - self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR) - self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_ifs_expr.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(case_ifs_expr.args["true"].type.this, exp.DataType.Type.VARCHAR) def test_unknown_annotation(self): schema = {"x": {"cola": "VARCHAR"}} sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x" concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] - self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN) + self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.UNKNOWN) concat_expr = concat_expr_alias.this - self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN) - self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola + self.assertEqual(concat_expr.type.this, exp.DataType.Type.UNKNOWN) + self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola self.assertEqual( - concat_expr.right.type, exp.DataType.Type.UNKNOWN + concat_expr.right.type.this, exp.DataType.Type.UNKNOWN ) # SOME_ANONYMOUS_FUNC(x.cola) self.assertEqual( - concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR + concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR ) # x.cola (arg) def test_null_annotation(self): expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this - self.assertEqual(expression.left.type, exp.DataType.Type.NULL) - self.assertEqual(expression.right.type, exp.DataType.Type.INT) + self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL) + self.assertEqual(expression.right.type.this, exp.DataType.Type.INT) # NULL <op> UNKNOWN should yield NULL sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result" concat_expr_alias = annotate_types(parse_one(sql)).expressions[0] - self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL) + self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL) concat_expr = concat_expr_alias.this - self.assertEqual(concat_expr.type, exp.DataType.Type.NULL) - self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL) - self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) + self.assertEqual(concat_expr.type.this, exp.DataType.Type.NULL) + self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.NULL) + self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.UNKNOWN) def test_nullable_annotation(self): nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) expression = annotate_types(parse_one("NULL AND FALSE")) self.assertEqual(expression.type, nullable) - self.assertEqual(expression.left.type, exp.DataType.Type.NULL) - self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN) + self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL) + self.assertEqual(expression.right.type.this, exp.DataType.Type.BOOLEAN) + + def test_predicate_annotation(self): + expression = annotate_types(parse_one("x BETWEEN a AND b")) + self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN) + + expression = annotate_types(parse_one("x IN (a, b, c, d)")) + self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN) + + def test_aggfunc_annotation(self): + schema = {"x": {"cola": "SMALLINT", "colb": "FLOAT", "colc": "TEXT", "cold": "DATE"}} + + tests = { + ("AVG", "cola"): exp.DataType.Type.DOUBLE, + ("SUM", "cola"): exp.DataType.Type.BIGINT, + ("SUM", "colb"): exp.DataType.Type.DOUBLE, + ("MIN", "cola"): exp.DataType.Type.SMALLINT, + ("MIN", "colb"): exp.DataType.Type.FLOAT, + ("MAX", "colc"): exp.DataType.Type.TEXT, + ("MAX", "cold"): exp.DataType.Type.DATE, + ("COUNT", "colb"): exp.DataType.Type.BIGINT, + ("STDDEV", "cola"): exp.DataType.Type.DOUBLE, + } + + for (func, col), target_type in tests.items(): + expression = annotate_types( + parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema + ) + self.assertEqual(expression.expressions[0].type.this, target_type) diff --git a/tests/test_schema.py b/tests/test_schema.py index cc0e3d1..f1e12a2 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -151,31 +151,33 @@ class TestSchema(unittest.TestCase): def test_schema_get_column_type(self): schema = MappingSchema({"a": {"b": "varchar"}}) - self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR) + self.assertEqual(schema.get_column_type("a", "b").this, exp.DataType.Type.VARCHAR) self.assertEqual( - schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")), + schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")).this, exp.DataType.Type.VARCHAR, ) self.assertEqual( - schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR + schema.get_column_type("a", exp.Column(this="b")).this, exp.DataType.Type.VARCHAR ) self.assertEqual( - schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR + schema.get_column_type(exp.Table(this="a"), "b").this, exp.DataType.Type.VARCHAR ) schema = MappingSchema({"a": {"b": {"c": "varchar"}}}) self.assertEqual( - schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")), + schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")).this, exp.DataType.Type.VARCHAR, ) self.assertEqual( - schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR + schema.get_column_type(exp.Table(this="b", db="a"), "c").this, exp.DataType.Type.VARCHAR ) schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}}) self.assertEqual( - schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")), + schema.get_column_type( + exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d") + ).this, exp.DataType.Type.VARCHAR, ) self.assertEqual( - schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"), + schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d").this, exp.DataType.Type.VARCHAR, ) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 1d1b966..1376849 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,6 +1,6 @@ import unittest -from sqlglot.tokens import Tokenizer +from sqlglot.tokens import Tokenizer, TokenType class TestTokens(unittest.TestCase): @@ -17,3 +17,48 @@ class TestTokens(unittest.TestCase): for sql, comment in sql_comment: self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment) + + def test_jinja(self): + tokenizer = Tokenizer() + + tokens = tokenizer.tokenize( + """ + SELECT + {{ x }}, + {{- x -}}, + {% for x in y -%} + a {{+ b }} + {% endfor %}; + """ + ) + + tokens = [(token.token_type, token.text) for token in tokens] + + self.assertEqual( + tokens, + [ + (TokenType.SELECT, "SELECT"), + (TokenType.BLOCK_START, "{{"), + (TokenType.VAR, "x"), + (TokenType.BLOCK_END, "}}"), + (TokenType.COMMA, ","), + (TokenType.BLOCK_START, "{{-"), + (TokenType.VAR, "x"), + (TokenType.BLOCK_END, "-}}"), + (TokenType.COMMA, ","), + (TokenType.BLOCK_START, "{%"), + (TokenType.FOR, "for"), + (TokenType.VAR, "x"), + (TokenType.IN, "in"), + (TokenType.VAR, "y"), + (TokenType.BLOCK_END, "-%}"), + (TokenType.VAR, "a"), + (TokenType.BLOCK_START, "{{+"), + (TokenType.VAR, "b"), + (TokenType.BLOCK_END, "}}"), + (TokenType.BLOCK_START, "{%"), + (TokenType.VAR, "endfor"), + (TokenType.BLOCK_END, "%}"), + (TokenType.SEMICOLON, ";"), + ], + ) |