From bea2635be022e272ddac349f5e396ec901fc37e5 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 12 Dec 2022 16:42:38 +0100 Subject: Merging upstream version 10.2.6. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/dataframe/sql/dataframe.py | 2 +- sqlglot/dialects/bigquery.py | 33 ++- sqlglot/dialects/hive.py | 15 +- sqlglot/dialects/oracle.py | 1 + sqlglot/dialects/redshift.py | 10 + sqlglot/dialects/snowflake.py | 1 + sqlglot/executor/env.py | 9 +- sqlglot/executor/python.py | 4 +- sqlglot/expressions.py | 106 +++++++-- sqlglot/generator.py | 452 ++++++++++++++++++++---------------- sqlglot/helper.py | 8 + sqlglot/optimizer/annotate_types.py | 37 ++- sqlglot/optimizer/canonicalize.py | 25 +- sqlglot/optimizer/simplify.py | 235 +++++++++++++------ sqlglot/parser.py | 136 ++++++++--- sqlglot/schema.py | 22 +- sqlglot/tokens.py | 19 +- 18 files changed, 747 insertions(+), 370 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index b027ac7..3733b20 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -30,7 +30,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.1.3" +__version__ = "10.2.6" pretty = False diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 548c322..3c45741 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -317,7 +317,7 @@ class DataFrame: sqlglot.schema.add_table( cache_table_name, { - expression.alias_or_name: expression.type.name + expression.alias_or_name: expression.type.sql("spark") for expression in select_expression.expressions }, ) diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 5b44912..6be68ac 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -110,17 +110,17 @@ class BigQuery(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "BEGIN": TokenType.COMMAND, + "BEGIN TRANSACTION": TokenType.BEGIN, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, "CURRENT_TIME": TokenType.CURRENT_TIME, "GEOGRAPHY": TokenType.GEOGRAPHY, - "INT64": TokenType.BIGINT, "FLOAT64": TokenType.DOUBLE, + "INT64": TokenType.BIGINT, + "NOT DETERMINISTIC": TokenType.VOLATILE, "QUALIFY": TokenType.QUALIFY, "UNKNOWN": TokenType.NULL, "WINDOW": TokenType.WINDOW, - "NOT DETERMINISTIC": TokenType.VOLATILE, - "BEGIN": TokenType.COMMAND, - "BEGIN TRANSACTION": TokenType.BEGIN, } KEYWORDS.pop("DIV") @@ -131,6 +131,7 @@ class BigQuery(Dialect): "DATE_ADD": _date_add(exp.DateAdd), "DATETIME_ADD": _date_add(exp.DatetimeAdd), "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)), + "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, "TIME_ADD": _date_add(exp.TimeAdd), "TIMESTAMP_ADD": _date_add(exp.TimestampAdd), "DATE_SUB": _date_add(exp.DateSub), @@ -144,6 +145,7 @@ class BigQuery(Dialect): FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, + "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]), } FUNCTION_PARSERS.pop("TRIM") @@ -161,7 +163,6 @@ class BigQuery(Dialect): class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, - exp.Array: inline_array_sql, exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateSub: _date_add_sql("DATE", "SUB"), @@ -183,6 +184,7 @@ class BigQuery(Dialect): exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" 
else "NOT DETERMINISTIC", + exp.RegexpLike: rename_func("REGEXP_CONTAINS"), } TYPE_MAPPING = { @@ -210,24 +212,31 @@ class BigQuery(Dialect): EXPLICIT_UNION = True - def transaction_sql(self, *_): + def array_sql(self, expression: exp.Array) -> str: + first_arg = seq_get(expression.expressions, 0) + if isinstance(first_arg, exp.Subqueryable): + return f"ARRAY{self.wrap(self.sql(first_arg))}" + + return inline_array_sql(self, expression) + + def transaction_sql(self, *_) -> str: return "BEGIN TRANSACTION" - def commit_sql(self, *_): + def commit_sql(self, *_) -> str: return "COMMIT TRANSACTION" - def rollback_sql(self, *_): + def rollback_sql(self, *_) -> str: return "ROLLBACK TRANSACTION" - def in_unnest_op(self, unnest): - return self.sql(unnest) + def in_unnest_op(self, expression: exp.Unnest) -> str: + return self.sql(expression) - def except_op(self, expression): + def except_op(self, expression: exp.Except) -> str: if not expression.args.get("distinct", False): self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery") return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" - def intersect_op(self, expression): + def intersect_op(self, expression: exp.Intersect) -> str: if not expression.args.get("distinct", False): self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery") return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index cbb39c2..70c1c6c 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -190,6 +190,7 @@ class Hive(Dialect): "ADD FILES": TokenType.COMMAND, "ADD JAR": TokenType.COMMAND, "ADD JARS": TokenType.COMMAND, + "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, } class Parser(parser.Parser): @@ -238,6 +239,13 @@ class Hive(Dialect): "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), } + PROPERTY_PARSERS = { + **parser.Parser.PROPERTY_PARSERS, + TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties( + expressions=self._parse_wrapped_csv(self._parse_property) + ), + } + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -297,6 +305,8 @@ class Hive(Dialect): exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}", + exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}", + exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), } @@ -308,12 +318,15 @@ class Hive(Dialect): exp.SchemaCommentProperty, exp.LocationProperty, exp.TableFormatProperty, + exp.RowFormatDelimitedProperty, + exp.RowFormatSerdeProperty, + exp.SerdeProperties, } def with_properties(self, properties): return self.properties( properties, - prefix="TBLPROPERTIES", + prefix=self.seg("TBLPROPERTIES"), ) def datatype_sql(self, expression): diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index ceaf9ba..f507513 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -98,6 +98,7 @@ class Oracle(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "MINUS": TokenType.EXCEPT, "START": TokenType.BEGIN, "TOP": TokenType.TOP, "VARCHAR2": TokenType.VARCHAR, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index cd50979..55ed0a6 100644 --- 
a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -1,6 +1,7 @@ from __future__ import annotations from sqlglot import exp, transforms +from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.postgres import Postgres from sqlglot.tokens import TokenType @@ -13,12 +14,20 @@ class Redshift(Postgres): "HH": "%H", } + class Parser(Postgres.Parser): + FUNCTIONS = { + **Postgres.Parser.FUNCTIONS, # type: ignore + "DECODE": exp.Matches.from_arg_list, + "NVL": exp.Coalesce.from_arg_list, + } + class Tokenizer(Postgres.Tokenizer): ESCAPES = ["\\"] KEYWORDS = { **Postgres.Tokenizer.KEYWORDS, # type: ignore "COPY": TokenType.COMMAND, + "ENCODE": TokenType.ENCODE, "GEOMETRY": TokenType.GEOMETRY, "GEOGRAPHY": TokenType.GEOGRAPHY, "HLLSKETCH": TokenType.HLLSKETCH, @@ -50,4 +59,5 @@ class Redshift(Postgres): exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), + exp.Matches: rename_func("DECODE"), } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 46155ff..75dc9dc 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -198,6 +198,7 @@ class Snowflake(Dialect): "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TIMESTAMPNTZ": TokenType.TIMESTAMP, + "MINUS": TokenType.EXCEPT, "SAMPLE": TokenType.TABLE_SAMPLE, } diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index e6cfcdd..ad9397e 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -19,10 +19,13 @@ class reverse_key: return other.obj < self.obj -def filter_nulls(func): +def filter_nulls(func, empty_null=True): @wraps(func) def _func(values): - return func(v for v in values if v is not None) + filtered = tuple(v for v in values if v is not None) + if not filtered and empty_null: + return None + return func(filtered) return _func @@ -126,7 +129,7 @@ ENV = { # aggs "SUM": filter_nulls(sum), "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore - "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)), + "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False), "MAX": filter_nulls(max), "MIN": filter_nulls(min), # scalar functions diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 908b80a..9f22c45 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -310,9 +310,9 @@ class PythonExecutor: if i == length - 1: context.set_range(start, end - 1) add_row() - elif step.limit > 0: + elif step.limit > 0 and not group_by: context.set_range(0, 0) - table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations)) + table.append(context.eval_tuple(aggregations)) context = self.context({step.name: table, **{name: table for name in context.tables}}) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 96b32f1..7249574 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -43,14 +43,14 @@ class Expression(metaclass=_Expression): key = "Expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "type", "comments") + __slots__ = ("args", "parent", "arg_key", "comments", "_type") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None - self.type = None self.comments = None + self._type: t.Optional[DataType] = None for arg_key, value in 
self.args.items(): self._set_parent(arg_key, value) @@ -122,6 +122,16 @@ class Expression(metaclass=_Expression): return "NULL" return self.alias or self.name + @property + def type(self) -> t.Optional[DataType]: + return self._type + + @type.setter + def type(self, dtype: t.Optional[DataType | DataType.Type | str]) -> None: + if dtype and not isinstance(dtype, DataType): + dtype = DataType.build(dtype) + self._type = dtype # type: ignore + def __deepcopy__(self, memo): copy = self.__class__(**deepcopy(self.args)) copy.comments = self.comments @@ -348,7 +358,7 @@ class Expression(metaclass=_Expression): indent += "".join([" "] * level) left = f"({self.key.upper()} " - args = { + args: t.Dict[str, t.Any] = { k: ", ".join( v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) for v in ensure_collection(vs) @@ -612,6 +622,7 @@ class Create(Expression): "properties": False, "temporary": False, "transient": False, + "external": False, "replace": False, "unique": False, "materialized": False, @@ -744,13 +755,17 @@ class DefaultColumnConstraint(ColumnConstraintKind): pass +class EncodeColumnConstraint(ColumnConstraintKind): + pass + + class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): # this: True -> ALWAYS, this: False -> BY DEFAULT arg_types = {"this": True, "expression": False} class NotNullColumnConstraint(ColumnConstraintKind): - pass + arg_types = {"allow_null": False} class PrimaryKeyColumnConstraint(ColumnConstraintKind): @@ -766,7 +781,7 @@ class Constraint(Expression): class Delete(Expression): - arg_types = {"with": False, "this": True, "using": False, "where": False} + arg_types = {"with": False, "this": False, "using": False, "where": False} class Drop(Expression): @@ -850,7 +865,7 @@ class Insert(Expression): arg_types = { "with": False, "this": True, - "expression": True, + "expression": False, "overwrite": False, "exists": False, "partition": False, @@ -1125,6 +1140,27 @@ class VolatilityProperty(Property): arg_types = {"this": True} +class RowFormatDelimitedProperty(Property): + # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml + arg_types = { + "fields": False, + "escaped": False, + "collection_items": False, + "map_keys": False, + "lines": False, + "null": False, + "serde": False, + } + + +class RowFormatSerdeProperty(Property): + arg_types = {"this": True} + + +class SerdeProperties(Property): + arg_types = {"expressions": True} + + class Properties(Expression): arg_types = {"expressions": True} @@ -1169,18 +1205,6 @@ class Reference(Expression): arg_types = {"this": True, "expressions": True} -class RowFormat(Expression): - # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml - arg_types = { - "fields": False, - "escaped": False, - "collection_items": False, - "map_keys": False, - "lines": False, - "null": False, - } - - class Tuple(Expression): arg_types = {"expressions": False} @@ -1208,6 +1232,9 @@ class Subqueryable(Unionable): alias=TableAlias(this=to_identifier(alias)), ) + def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + raise NotImplementedError + @property def ctes(self): with_ = self.args.get("with") @@ -1320,6 +1347,32 @@ class Union(Subqueryable): **QUERY_MODIFIERS, } + def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + """ + Set the LIMIT expression. 
+ + Example: + >>> select("1").union(select("1")).limit(1).sql() + 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1' + + Args: + expression (str | int | Expression): the SQL code string to parse. + This can also be an integer. + If a `Limit` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Limit`. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: The limited subqueryable. + """ + return ( + select("*") + .from_(self.subquery(alias="_l_0", copy=copy)) + .limit(expression, dialect=dialect, copy=False, **opts) + ) + @property def named_selects(self): return self.this.unnest().named_selects @@ -1356,7 +1409,7 @@ class Unnest(UDTF): class Update(Expression): arg_types = { "with": False, - "this": True, + "this": False, "expressions": True, "from": False, "where": False, @@ -2057,15 +2110,20 @@ class DataType(Expression): Type.TEXT, } - NUMERIC_TYPES = { + INTEGER_TYPES = { Type.INT, Type.TINYINT, Type.SMALLINT, Type.BIGINT, + } + + FLOAT_TYPES = { Type.FLOAT, Type.DOUBLE, } + NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES} + TEMPORAL_TYPES = { Type.TIMESTAMP, Type.TIMESTAMPTZ, @@ -2968,6 +3026,14 @@ class Use(Expression): pass +class Merge(Expression): + arg_types = {"this": True, "using": True, "on": True, "expressions": True} + + +class When(Func): + arg_types = {"this": True, "then": True} + + def _norm_args(expression): args = {} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 47774fc..beffb91 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -189,12 +189,12 @@ class Generator: self._max_text_width = max_text_width self._comments = comments - def generate(self, expression): + def generate(self, expression: t.Optional[exp.Expression]) -> str: """ Generates a SQL string by interpreting the given syntax tree. Args - expression (Expression): the syntax tree. + expression: the syntax tree. Returns the SQL string. 
@@ -213,23 +213,23 @@ class Generator: return sql - def unsupported(self, message): + def unsupported(self, message: str) -> None: if self.unsupported_level == ErrorLevel.IMMEDIATE: raise UnsupportedError(message) self.unsupported_messages.append(message) - def sep(self, sep=" "): + def sep(self, sep: str = " ") -> str: return f"{sep.strip()}\n" if self.pretty else sep - def seg(self, sql, sep=" "): + def seg(self, sql: str, sep: str = " ") -> str: return f"{self.sep(sep)}{sql}" - def pad_comment(self, comment): + def pad_comment(self, comment: str) -> str: comment = " " + comment if comment[0].strip() else comment comment = comment + " " if comment[-1].strip() else comment return comment - def maybe_comment(self, sql, expression): + def maybe_comment(self, sql: str, expression: exp.Expression) -> str: comments = expression.comments if self._comments else None if not comments: @@ -243,7 +243,7 @@ class Generator: return f"{sql} {comments}" - def wrap(self, expression): + def wrap(self, expression: exp.Expression | str) -> str: this_sql = self.indent( self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) @@ -253,21 +253,28 @@ class Generator: ) return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}" - def no_identify(self, func): + def no_identify(self, func: t.Callable[[], str]) -> str: original = self.identify self.identify = False result = func() self.identify = original return result - def normalize_func(self, name): + def normalize_func(self, name: str) -> str: if self.normalize_functions == "upper": return name.upper() if self.normalize_functions == "lower": return name.lower() return name - def indent(self, sql, level=0, pad=None, skip_first=False, skip_last=False): + def indent( + self, + sql: str, + level: int = 0, + pad: t.Optional[int] = None, + skip_first: bool = False, + skip_last: bool = False, + ) -> str: if not self.pretty: return sql @@ -281,7 +288,12 @@ class Generator: for i, line in enumerate(lines) ) - def sql(self, expression, key=None, comment=True): + def sql( + self, + expression: t.Optional[str | exp.Expression], + key: t.Optional[str] = None, + comment: bool = True, + ) -> str: if not expression: return "" @@ -313,12 +325,12 @@ class Generator: return self.maybe_comment(sql, expression) if self._comments and comment else sql - def uncache_sql(self, expression): + def uncache_sql(self, expression: exp.Uncache) -> str: table = self.sql(expression, "this") exists_sql = " IF EXISTS" if expression.args.get("exists") else "" return f"UNCACHE TABLE{exists_sql} {table}" - def cache_sql(self, expression): + def cache_sql(self, expression: exp.Cache) -> str: lazy = " LAZY" if expression.args.get("lazy") else "" table = self.sql(expression, "this") options = expression.args.get("options") @@ -328,13 +340,13 @@ class Generator: sql = f"CACHE{lazy} TABLE {table}{options}{sql}" return self.prepend_ctes(expression, sql) - def characterset_sql(self, expression): + def characterset_sql(self, expression: exp.CharacterSet) -> str: if isinstance(expression.parent, exp.Cast): return f"CHAR CHARACTER SET {self.sql(expression, 'this')}" default = "DEFAULT " if expression.args.get("default") else "" return f"{default}CHARACTER SET={self.sql(expression, 'this')}" - def column_sql(self, expression): + def column_sql(self, expression: exp.Column) -> str: return ".".join( part for part in [ @@ -345,7 +357,7 @@ class Generator: if part ) - def columndef_sql(self, expression): + def columndef_sql(self, expression: exp.ColumnDef) -> str: column = self.sql(expression, 
"this") kind = self.sql(expression, "kind") constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) @@ -354,46 +366,52 @@ class Generator: return f"{column} {kind}" return f"{column} {kind} {constraints}" - def columnconstraint_sql(self, expression): + def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: this = self.sql(expression, "this") kind_sql = self.sql(expression, "kind") return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql - def autoincrementcolumnconstraint_sql(self, _): + def autoincrementcolumnconstraint_sql(self, _) -> str: return self.token_sql(TokenType.AUTO_INCREMENT) - def checkcolumnconstraint_sql(self, expression): + def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str: this = self.sql(expression, "this") return f"CHECK ({this})" - def commentcolumnconstraint_sql(self, expression): + def commentcolumnconstraint_sql(self, expression: exp.CommentColumnConstraint) -> str: comment = self.sql(expression, "this") return f"COMMENT {comment}" - def collatecolumnconstraint_sql(self, expression): + def collatecolumnconstraint_sql(self, expression: exp.CollateColumnConstraint) -> str: collate = self.sql(expression, "this") return f"COLLATE {collate}" - def defaultcolumnconstraint_sql(self, expression): + def encodecolumnconstraint_sql(self, expression: exp.EncodeColumnConstraint) -> str: + encode = self.sql(expression, "this") + return f"ENCODE {encode}" + + def defaultcolumnconstraint_sql(self, expression: exp.DefaultColumnConstraint) -> str: default = self.sql(expression, "this") return f"DEFAULT {default}" - def generatedasidentitycolumnconstraint_sql(self, expression): + def generatedasidentitycolumnconstraint_sql( + self, expression: exp.GeneratedAsIdentityColumnConstraint + ) -> str: return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY" - def notnullcolumnconstraint_sql(self, _): - return "NOT NULL" + def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: + return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" - def primarykeycolumnconstraint_sql(self, expression): + def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str: desc = expression.args.get("desc") if desc is not None: return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" return f"PRIMARY KEY" - def uniquecolumnconstraint_sql(self, _): + def uniquecolumnconstraint_sql(self, _) -> str: return "UNIQUE" - def create_sql(self, expression): + def create_sql(self, expression: exp.Create) -> str: this = self.sql(expression, "this") kind = self.sql(expression, "kind").upper() expression_sql = self.sql(expression, "expression") @@ -402,47 +420,58 @@ class Generator: transient = ( " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" ) + external = " EXTERNAL" if expression.args.get("external") else "" replace = " OR REPLACE" if expression.args.get("replace") else "" exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" unique = " UNIQUE" if expression.args.get("unique") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else "" properties = self.sql(expression, "properties") - expression_sql = f"CREATE{replace}{temporary}{transient}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}" + modifiers = "".join( + ( + replace, + temporary, + transient, + external, + unique, + materialized, + ) + ) + expression_sql = f"CREATE{modifiers} 
{kind}{exists_sql} {this}{properties} {expression_sql}" return self.prepend_ctes(expression, expression_sql) - def describe_sql(self, expression): + def describe_sql(self, expression: exp.Describe) -> str: return f"DESCRIBE {self.sql(expression, 'this')}" - def prepend_ctes(self, expression, sql): + def prepend_ctes(self, expression: exp.Expression, sql: str) -> str: with_ = self.sql(expression, "with") if with_: sql = f"{with_}{self.sep()}{sql}" return sql - def with_sql(self, expression): + def with_sql(self, expression: exp.With) -> str: sql = self.expressions(expression, flat=True) recursive = "RECURSIVE " if expression.args.get("recursive") else "" return f"WITH {recursive}{sql}" - def cte_sql(self, expression): + def cte_sql(self, expression: exp.CTE) -> str: alias = self.sql(expression, "alias") return f"{alias} AS {self.wrap(expression)}" - def tablealias_sql(self, expression): + def tablealias_sql(self, expression: exp.TableAlias) -> str: alias = self.sql(expression, "this") columns = self.expressions(expression, key="columns", flat=True) columns = f"({columns})" if columns else "" return f"{alias}{columns}" - def bitstring_sql(self, expression): + def bitstring_sql(self, expression: exp.BitString) -> str: return self.sql(expression, "this") - def hexstring_sql(self, expression): + def hexstring_sql(self, expression: exp.HexString) -> str: return self.sql(expression, "this") - def datatype_sql(self, expression): + def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) nested = "" @@ -455,13 +484,13 @@ class Generator: ) return f"{type_sql}{nested}" - def directory_sql(self, expression): + def directory_sql(self, expression: exp.Directory) -> str: local = "LOCAL " if expression.args.get("local") else "" row_format = self.sql(expression, "row_format") row_format = f" {row_format}" if row_format else "" return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}" - def delete_sql(self, expression): + def delete_sql(self, expression: exp.Delete) -> str: this = self.sql(expression, "this") using_sql = ( f" USING {self.expressions(expression, 'using', sep=', USING ')}" @@ -472,7 +501,7 @@ class Generator: sql = f"DELETE FROM {this}{using_sql}{where_sql}" return self.prepend_ctes(expression, sql) - def drop_sql(self, expression): + def drop_sql(self, expression: exp.Drop) -> str: this = self.sql(expression, "this") kind = expression.args["kind"] exists_sql = " IF EXISTS " if expression.args.get("exists") else " " @@ -481,46 +510,46 @@ class Generator: cascade = " CASCADE" if expression.args.get("cascade") else "" return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}" - def except_sql(self, expression): + def except_sql(self, expression: exp.Except) -> str: return self.prepend_ctes( expression, self.set_operation(expression, self.except_op(expression)), ) - def except_op(self, expression): + def except_op(self, expression: exp.Except) -> str: return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}" - def fetch_sql(self, expression): + def fetch_sql(self, expression: exp.Fetch) -> str: direction = expression.args.get("direction") direction = f" {direction.upper()}" if direction else "" count = expression.args.get("count") count = f" {count}" if count else "" return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY" - def filter_sql(self, expression): + def filter_sql(self, expression: exp.Filter) -> str: this = self.sql(expression, "this") where = 
self.sql(expression, "expression")[1:] # where has a leading space return f"{this} FILTER({where})" - def hint_sql(self, expression): + def hint_sql(self, expression: exp.Hint) -> str: if self.sql(expression, "this"): self.unsupported("Hints are not supported") return "" - def index_sql(self, expression): + def index_sql(self, expression: exp.Index) -> str: this = self.sql(expression, "this") table = self.sql(expression, "table") columns = self.sql(expression, "columns") return f"{this} ON {table} {columns}" - def identifier_sql(self, expression): + def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name text = text.lower() if self.normalize else text if expression.args.get("quoted") or self.identify: text = f"{self.identifier_start}{text}{self.identifier_end}" return text - def partition_sql(self, expression): + def partition_sql(self, expression: exp.Partition) -> str: keys = csv( *[ f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name @@ -529,7 +558,7 @@ class Generator: ) return f"PARTITION({keys})" - def properties_sql(self, expression): + def properties_sql(self, expression: exp.Properties) -> str: root_properties = [] with_properties = [] @@ -544,21 +573,21 @@ class Generator: exp.Properties(expressions=root_properties) ) + self.with_properties(exp.Properties(expressions=with_properties)) - def root_properties(self, properties): + def root_properties(self, properties: exp.Properties) -> str: if properties.expressions: return self.sep() + self.expressions(properties, indent=False, sep=" ") return "" - def properties(self, properties, prefix="", sep=", "): + def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str: if properties.expressions: expressions = self.expressions(properties, sep=sep, indent=False) - return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}" + return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}" return "" - def with_properties(self, properties): - return self.properties(properties, prefix="WITH") + def with_properties(self, properties: exp.Properties) -> str: + return self.properties(properties, prefix=self.seg("WITH")) - def property_sql(self, expression): + def property_sql(self, expression: exp.Property) -> str: property_cls = expression.__class__ if property_cls == exp.Property: return f"{expression.name}={self.sql(expression, 'value')}" @@ -569,12 +598,12 @@ class Generator: return f"{property_name}={self.sql(expression, 'this')}" - def likeproperty_sql(self, expression): + def likeproperty_sql(self, expression: exp.LikeProperty) -> str: options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions) options = f" {options}" if options else "" return f"LIKE {self.sql(expression, 'this')}{options}" - def insert_sql(self, expression): + def insert_sql(self, expression: exp.Insert) -> str: overwrite = expression.args.get("overwrite") if isinstance(expression.this, exp.Directory): @@ -592,19 +621,19 @@ class Generator: sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}" return self.prepend_ctes(expression, sql) - def intersect_sql(self, expression): + def intersect_sql(self, expression: exp.Intersect) -> str: return self.prepend_ctes( expression, self.set_operation(expression, self.intersect_op(expression)), ) - def intersect_op(self, expression): + def intersect_op(self, expression: exp.Intersect) -> str: return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}" - def introducer_sql(self, 
expression): + def introducer_sql(self, expression: exp.Introducer) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" - def rowformat_sql(self, expression): + def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str: fields = expression.args.get("fields") fields = f" FIELDS TERMINATED BY {fields}" if fields else "" escaped = expression.args.get("escaped") @@ -619,7 +648,7 @@ class Generator: null = f" NULL DEFINED AS {null}" if null else "" return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}" - def table_sql(self, expression, sep=" AS "): + def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str: table = ".".join( part for part in [ @@ -642,7 +671,7 @@ class Generator: return f"{table}{alias}{laterals}{joins}{pivots}" - def tablesample_sql(self, expression): + def tablesample_sql(self, expression: exp.TableSample) -> str: if self.alias_post_tablesample and expression.this.alias: this = self.sql(expression.this, "this") alias = f" AS {self.sql(expression.this, 'alias')}" @@ -665,7 +694,7 @@ class Generator: seed = f" SEED ({seed})" if seed else "" return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}" - def pivot_sql(self, expression): + def pivot_sql(self, expression: exp.Pivot) -> str: this = self.sql(expression, "this") unpivot = expression.args.get("unpivot") direction = "UNPIVOT" if unpivot else "PIVOT" @@ -673,10 +702,10 @@ class Generator: field = self.sql(expression, "field") return f"{this} {direction}({expressions} FOR {field})" - def tuple_sql(self, expression): + def tuple_sql(self, expression: exp.Tuple) -> str: return f"({self.expressions(expression, flat=True)})" - def update_sql(self, expression): + def update_sql(self, expression: exp.Update) -> str: this = self.sql(expression, "this") set_sql = self.expressions(expression, flat=True) from_sql = self.sql(expression, "from") @@ -684,7 +713,7 @@ class Generator: sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}" return self.prepend_ctes(expression, sql) - def values_sql(self, expression): + def values_sql(self, expression: exp.Values) -> str: alias = self.sql(expression, "alias") args = self.expressions(expression) if not alias: @@ -694,19 +723,19 @@ class Generator: return f"(VALUES{self.seg('')}{args}){alias}" return f"VALUES{self.seg('')}{args}{alias}" - def var_sql(self, expression): + def var_sql(self, expression: exp.Var) -> str: return self.sql(expression, "this") - def into_sql(self, expression): + def into_sql(self, expression: exp.Into) -> str: temporary = " TEMPORARY" if expression.args.get("temporary") else "" unlogged = " UNLOGGED" if expression.args.get("unlogged") else "" return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}" - def from_sql(self, expression): + def from_sql(self, expression: exp.From) -> str: expressions = self.expressions(expression, flat=True) return f"{self.seg('FROM')} {expressions}" - def group_sql(self, expression): + def group_sql(self, expression: exp.Group) -> str: group_by = self.op_expressions("GROUP BY", expression) grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) grouping_sets = ( @@ -718,11 +747,11 @@ class Generator: rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" return f"{group_by}{grouping_sets}{cube}{rollup}" - def having_sql(self, expression): + def having_sql(self, expression: exp.Having) -> str: this = self.indent(self.sql(expression, "this")) return 
f"{self.seg('HAVING')}{self.sep()}{this}" - def join_sql(self, expression): + def join_sql(self, expression: exp.Join) -> str: op_sql = self.seg( " ".join( op @@ -753,12 +782,12 @@ class Generator: this_sql = self.sql(expression, "this") return f"{expression_sql}{op_sql} {this_sql}{on_sql}" - def lambda_sql(self, expression, arrow_sep="->"): + def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str: args = self.expressions(expression, flat=True) args = f"({args})" if len(args.split(",")) > 1 else args return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}") - def lateral_sql(self, expression): + def lateral_sql(self, expression: exp.Lateral) -> str: this = self.sql(expression, "this") if isinstance(expression.this, exp.Subquery): @@ -776,15 +805,15 @@ class Generator: return f"LATERAL {this}{table}{columns}" - def limit_sql(self, expression): + def limit_sql(self, expression: exp.Limit) -> str: this = self.sql(expression, "this") return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}" - def offset_sql(self, expression): + def offset_sql(self, expression: exp.Offset) -> str: this = self.sql(expression, "this") return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" - def literal_sql(self, expression): + def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" if expression.is_string: if self._replace_backslash: @@ -793,7 +822,7 @@ class Generator: text = f"{self.quote_start}{text}{self.quote_end}" return text - def loaddata_sql(self, expression): + def loaddata_sql(self, expression: exp.LoadData) -> str: local = " LOCAL" if expression.args.get("local") else "" inpath = f" INPATH {self.sql(expression, 'inpath')}" overwrite = " OVERWRITE" if expression.args.get("overwrite") else "" @@ -806,27 +835,27 @@ class Generator: serde = f" SERDE {serde}" if serde else "" return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}" - def null_sql(self, *_): + def null_sql(self, *_) -> str: return "NULL" - def boolean_sql(self, expression): + def boolean_sql(self, expression: exp.Boolean) -> str: return "TRUE" if expression.this else "FALSE" - def order_sql(self, expression, flat=False): + def order_sql(self, expression: exp.Order, flat: bool = False) -> str: this = self.sql(expression, "this") this = f"{this} " if this else this - return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) + return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore - def cluster_sql(self, expression): + def cluster_sql(self, expression: exp.Cluster) -> str: return self.op_expressions("CLUSTER BY", expression) - def distribute_sql(self, expression): + def distribute_sql(self, expression: exp.Distribute) -> str: return self.op_expressions("DISTRIBUTE BY", expression) - def sort_sql(self, expression): + def sort_sql(self, expression: exp.Sort) -> str: return self.op_expressions("SORT BY", expression) - def ordered_sql(self, expression): + def ordered_sql(self, expression: exp.Ordered) -> str: desc = expression.args.get("desc") asc = not desc @@ -857,7 +886,7 @@ class Generator: return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" - def query_modifiers(self, expression, *sqls): + def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: return csv( *sqls, *[self.sql(sql) for sql in expression.args.get("laterals", [])], @@ -876,7 +905,7 @@ class Generator: sep="", ) - def select_sql(self, expression): + def 
select_sql(self, expression: exp.Select) -> str: hint = self.sql(expression, "hint") distinct = self.sql(expression, "distinct") distinct = f" {distinct}" if distinct else "" @@ -890,36 +919,36 @@ class Generator: ) return self.prepend_ctes(expression, sql) - def schema_sql(self, expression): + def schema_sql(self, expression: exp.Schema) -> str: this = self.sql(expression, "this") this = f"{this} " if this else "" sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}" return f"{this}{sql}" - def star_sql(self, expression): + def star_sql(self, expression: exp.Star) -> str: except_ = self.expressions(expression, key="except", flat=True) except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else "" replace = self.expressions(expression, key="replace", flat=True) replace = f"{self.seg('REPLACE')} ({replace})" if replace else "" return f"*{except_}{replace}" - def structkwarg_sql(self, expression): + def structkwarg_sql(self, expression: exp.StructKwarg) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" - def parameter_sql(self, expression): + def parameter_sql(self, expression: exp.Parameter) -> str: return f"@{self.sql(expression, 'this')}" - def sessionparameter_sql(self, expression): + def sessionparameter_sql(self, expression: exp.SessionParameter) -> str: this = self.sql(expression, "this") kind = expression.text("kind") if kind: kind = f"{kind}." return f"@@{kind}{this}" - def placeholder_sql(self, expression): + def placeholder_sql(self, expression: exp.Placeholder) -> str: return f":{expression.name}" if expression.name else "?" - def subquery_sql(self, expression): + def subquery_sql(self, expression: exp.Subquery) -> str: alias = self.sql(expression, "alias") sql = self.query_modifiers( @@ -931,22 +960,22 @@ class Generator: return self.prepend_ctes(expression, sql) - def qualify_sql(self, expression): + def qualify_sql(self, expression: exp.Qualify) -> str: this = self.indent(self.sql(expression, "this")) return f"{self.seg('QUALIFY')}{self.sep()}{this}" - def union_sql(self, expression): + def union_sql(self, expression: exp.Union) -> str: return self.prepend_ctes( expression, self.set_operation(expression, self.union_op(expression)), ) - def union_op(self, expression): + def union_op(self, expression: exp.Union) -> str: kind = " DISTINCT" if self.EXPLICIT_UNION else "" kind = kind if expression.args.get("distinct") else " ALL" return f"UNION{kind}" - def unnest_sql(self, expression): + def unnest_sql(self, expression: exp.Unnest) -> str: args = self.expressions(expression, flat=True) alias = expression.args.get("alias") if alias and self.unnest_column_only: @@ -958,11 +987,11 @@ class Generator: ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else "" return f"UNNEST({args}){ordinality}{alias}" - def where_sql(self, expression): + def where_sql(self, expression: exp.Where) -> str: this = self.indent(self.sql(expression, "this")) return f"{self.seg('WHERE')}{self.sep()}{this}" - def window_sql(self, expression): + def window_sql(self, expression: exp.Window) -> str: this = self.sql(expression, "this") partition = self.expressions(expression, key="partition_by", flat=True) @@ -988,7 +1017,7 @@ class Generator: return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})" - def window_spec_sql(self, expression): + def window_spec_sql(self, expression: exp.WindowSpec) -> str: kind = self.sql(expression, "kind") start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") end 
= ( @@ -997,33 +1026,33 @@ class Generator: ) return f"{kind} BETWEEN {start} AND {end}" - def withingroup_sql(self, expression): + def withingroup_sql(self, expression: exp.WithinGroup) -> str: this = self.sql(expression, "this") - expression = self.sql(expression, "expression")[1:] # order has a leading space - return f"{this} WITHIN GROUP ({expression})" + expression_sql = self.sql(expression, "expression")[1:] # order has a leading space + return f"{this} WITHIN GROUP ({expression_sql})" - def between_sql(self, expression): + def between_sql(self, expression: exp.Between) -> str: this = self.sql(expression, "this") low = self.sql(expression, "low") high = self.sql(expression, "high") return f"{this} BETWEEN {low} AND {high}" - def bracket_sql(self, expression): + def bracket_sql(self, expression: exp.Bracket) -> str: expressions = apply_index_offset(expression.expressions, self.index_offset) - expressions = ", ".join(self.sql(e) for e in expressions) + expressions_sql = ", ".join(self.sql(e) for e in expressions) - return f"{self.sql(expression, 'this')}[{expressions}]" + return f"{self.sql(expression, 'this')}[{expressions_sql}]" - def all_sql(self, expression): + def all_sql(self, expression: exp.All) -> str: return f"ALL {self.wrap(expression)}" - def any_sql(self, expression): + def any_sql(self, expression: exp.Any) -> str: return f"ANY {self.wrap(expression)}" - def exists_sql(self, expression): + def exists_sql(self, expression: exp.Exists) -> str: return f"EXISTS{self.wrap(expression)}" - def case_sql(self, expression): + def case_sql(self, expression: exp.Case) -> str: this = self.sql(expression, "this") statements = [f"CASE {this}" if this else "CASE"] @@ -1043,17 +1072,17 @@ class Generator: return " ".join(statements) - def constraint_sql(self, expression): + def constraint_sql(self, expression: exp.Constraint) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) return f"CONSTRAINT {this} {expressions}" - def extract_sql(self, expression): + def extract_sql(self, expression: exp.Extract) -> str: this = self.sql(expression, "this") expression_sql = self.sql(expression, "expression") return f"EXTRACT({this} FROM {expression_sql})" - def trim_sql(self, expression): + def trim_sql(self, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") @@ -1064,16 +1093,16 @@ class Generator: else: return f"TRIM({target})" - def concat_sql(self, expression): + def concat_sql(self, expression: exp.Concat) -> str: if len(expression.expressions) == 1: return self.sql(expression.expressions[0]) return self.function_fallback_sql(expression) - def check_sql(self, expression): + def check_sql(self, expression: exp.Check) -> str: this = self.sql(expression, key="this") return f"CHECK ({this})" - def foreignkey_sql(self, expression): + def foreignkey_sql(self, expression: exp.ForeignKey) -> str: expressions = self.expressions(expression, flat=True) reference = self.sql(expression, "reference") reference = f" {reference}" if reference else "" @@ -1083,16 +1112,16 @@ class Generator: update = f" ON UPDATE {update}" if update else "" return f"FOREIGN KEY ({expressions}){reference}{delete}{update}" - def unique_sql(self, expression): + def unique_sql(self, expression: exp.Unique) -> str: columns = self.expressions(expression, key="expressions") return f"UNIQUE ({columns})" - def if_sql(self, expression): + def if_sql(self, expression: exp.If) -> str: return self.case_sql( 
exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) ) - def in_sql(self, expression): + def in_sql(self, expression: exp.In) -> str: query = expression.args.get("query") unnest = expression.args.get("unnest") field = expression.args.get("field") @@ -1106,24 +1135,24 @@ class Generator: in_sql = f"({self.expressions(expression, flat=True)})" return f"{self.sql(expression, 'this')} IN {in_sql}" - def in_unnest_op(self, unnest): + def in_unnest_op(self, unnest: exp.Unnest) -> str: return f"(SELECT {self.sql(unnest)})" - def interval_sql(self, expression): + def interval_sql(self, expression: exp.Interval) -> str: unit = self.sql(expression, "unit") unit = f" {unit}" if unit else "" return f"INTERVAL {self.sql(expression, 'this')}{unit}" - def reference_sql(self, expression): + def reference_sql(self, expression: exp.Reference) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) return f"REFERENCES {this}({expressions})" - def anonymous_sql(self, expression): + def anonymous_sql(self, expression: exp.Anonymous) -> str: args = self.format_args(*expression.expressions) return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" - def paren_sql(self, expression): + def paren_sql(self, expression: exp.Paren) -> str: if isinstance(expression.unnest(), exp.Select): sql = self.wrap(expression) else: @@ -1132,35 +1161,35 @@ class Generator: return self.prepend_ctes(expression, sql) - def neg_sql(self, expression): + def neg_sql(self, expression: exp.Neg) -> str: # This makes sure we don't convert "- - 5" to "--5", which is a comment this_sql = self.sql(expression, "this") sep = " " if this_sql[0] == "-" else "" return f"-{sep}{this_sql}" - def not_sql(self, expression): + def not_sql(self, expression: exp.Not) -> str: return f"NOT {self.sql(expression, 'this')}" - def alias_sql(self, expression): + def alias_sql(self, expression: exp.Alias) -> str: to_sql = self.sql(expression, "alias") to_sql = f" AS {to_sql}" if to_sql else "" return f"{self.sql(expression, 'this')}{to_sql}" - def aliases_sql(self, expression): + def aliases_sql(self, expression: exp.Aliases) -> str: return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})" - def attimezone_sql(self, expression): + def attimezone_sql(self, expression: exp.AtTimeZone) -> str: this = self.sql(expression, "this") zone = self.sql(expression, "zone") return f"{this} AT TIME ZONE {zone}" - def add_sql(self, expression): + def add_sql(self, expression: exp.Add) -> str: return self.binary(expression, "+") - def and_sql(self, expression): + def and_sql(self, expression: exp.And) -> str: return self.connector_sql(expression, "AND") - def connector_sql(self, expression, op): + def connector_sql(self, expression: exp.Connector, op: str) -> str: if not self.pretty: return self.binary(expression, op) @@ -1168,53 +1197,53 @@ class Generator: sep = "\n" if self.text_width(sqls) > self._max_text_width else " " return f"{sep}{op} ".join(sqls) - def bitwiseand_sql(self, expression): + def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str: return self.binary(expression, "&") - def bitwiseleftshift_sql(self, expression): + def bitwiseleftshift_sql(self, expression: exp.BitwiseLeftShift) -> str: return self.binary(expression, "<<") - def bitwisenot_sql(self, expression): + def bitwisenot_sql(self, expression: exp.BitwiseNot) -> str: return f"~{self.sql(expression, 'this')}" - def bitwiseor_sql(self, expression): + def bitwiseor_sql(self, expression: 
exp.BitwiseOr) -> str: return self.binary(expression, "|") - def bitwiserightshift_sql(self, expression): + def bitwiserightshift_sql(self, expression: exp.BitwiseRightShift) -> str: return self.binary(expression, ">>") - def bitwisexor_sql(self, expression): + def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str: return self.binary(expression, "^") - def cast_sql(self, expression): + def cast_sql(self, expression: exp.Cast) -> str: return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" - def currentdate_sql(self, expression): + def currentdate_sql(self, expression: exp.CurrentDate) -> str: zone = self.sql(expression, "this") return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE" - def collate_sql(self, expression): + def collate_sql(self, expression: exp.Collate) -> str: return self.binary(expression, "COLLATE") - def command_sql(self, expression): + def command_sql(self, expression: exp.Command) -> str: return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" - def transaction_sql(self, *_): + def transaction_sql(self, *_) -> str: return "BEGIN" - def commit_sql(self, expression): + def commit_sql(self, expression: exp.Commit) -> str: chain = expression.args.get("chain") if chain is not None: chain = " AND CHAIN" if chain else " AND NO CHAIN" return f"COMMIT{chain or ''}" - def rollback_sql(self, expression): + def rollback_sql(self, expression: exp.Rollback) -> str: savepoint = expression.args.get("savepoint") savepoint = f" TO {savepoint}" if savepoint else "" return f"ROLLBACK{savepoint}" - def distinct_sql(self, expression): + def distinct_sql(self, expression: exp.Distinct) -> str: this = self.expressions(expression, flat=True) this = f" {this}" if this else "" @@ -1222,13 +1251,13 @@ class Generator: on = f" ON {on}" if on else "" return f"DISTINCT{this}{on}" - def ignorenulls_sql(self, expression): + def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str: return f"{self.sql(expression, 'this')} IGNORE NULLS" - def respectnulls_sql(self, expression): + def respectnulls_sql(self, expression: exp.RespectNulls) -> str: return f"{self.sql(expression, 'this')} RESPECT NULLS" - def intdiv_sql(self, expression): + def intdiv_sql(self, expression: exp.IntDiv) -> str: return self.sql( exp.Cast( this=exp.Div(this=expression.this, expression=expression.expression), @@ -1236,79 +1265,79 @@ class Generator: ) ) - def dpipe_sql(self, expression): + def dpipe_sql(self, expression: exp.DPipe) -> str: return self.binary(expression, "||") - def div_sql(self, expression): + def div_sql(self, expression: exp.Div) -> str: return self.binary(expression, "/") - def distance_sql(self, expression): + def distance_sql(self, expression: exp.Distance) -> str: return self.binary(expression, "<->") - def dot_sql(self, expression): + def dot_sql(self, expression: exp.Dot) -> str: return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}" - def eq_sql(self, expression): + def eq_sql(self, expression: exp.EQ) -> str: return self.binary(expression, "=") - def escape_sql(self, expression): + def escape_sql(self, expression: exp.Escape) -> str: return self.binary(expression, "ESCAPE") - def gt_sql(self, expression): + def gt_sql(self, expression: exp.GT) -> str: return self.binary(expression, ">") - def gte_sql(self, expression): + def gte_sql(self, expression: exp.GTE) -> str: return self.binary(expression, ">=") - def ilike_sql(self, expression): + def ilike_sql(self, expression: exp.ILike) -> str: return 
self.binary(expression, "ILIKE") - def is_sql(self, expression): + def is_sql(self, expression: exp.Is) -> str: return self.binary(expression, "IS") - def like_sql(self, expression): + def like_sql(self, expression: exp.Like) -> str: return self.binary(expression, "LIKE") - def similarto_sql(self, expression): + def similarto_sql(self, expression: exp.SimilarTo) -> str: return self.binary(expression, "SIMILAR TO") - def lt_sql(self, expression): + def lt_sql(self, expression: exp.LT) -> str: return self.binary(expression, "<") - def lte_sql(self, expression): + def lte_sql(self, expression: exp.LTE) -> str: return self.binary(expression, "<=") - def mod_sql(self, expression): + def mod_sql(self, expression: exp.Mod) -> str: return self.binary(expression, "%") - def mul_sql(self, expression): + def mul_sql(self, expression: exp.Mul) -> str: return self.binary(expression, "*") - def neq_sql(self, expression): + def neq_sql(self, expression: exp.NEQ) -> str: return self.binary(expression, "<>") - def nullsafeeq_sql(self, expression): + def nullsafeeq_sql(self, expression: exp.NullSafeEQ) -> str: return self.binary(expression, "IS NOT DISTINCT FROM") - def nullsafeneq_sql(self, expression): + def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str: return self.binary(expression, "IS DISTINCT FROM") - def or_sql(self, expression): + def or_sql(self, expression: exp.Or) -> str: return self.connector_sql(expression, "OR") - def sub_sql(self, expression): + def sub_sql(self, expression: exp.Sub) -> str: return self.binary(expression, "-") - def trycast_sql(self, expression): + def trycast_sql(self, expression: exp.TryCast) -> str: return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" - def use_sql(self, expression): + def use_sql(self, expression: exp.Use) -> str: return f"USE {self.sql(expression, 'this')}" - def binary(self, expression, op): + def binary(self, expression: exp.Binary, op: str) -> str: return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" - def function_fallback_sql(self, expression): + def function_fallback_sql(self, expression: exp.Func) -> str: args = [] for arg_value in expression.args.values(): if isinstance(arg_value, list): @@ -1319,19 +1348,26 @@ class Generator: return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})" - def format_args(self, *args): - args = tuple(self.sql(arg) for arg in args if arg is not None) - if self.pretty and self.text_width(args) > self._max_text_width: - return self.indent("\n" + f",\n".join(args) + "\n", skip_first=True, skip_last=True) - return ", ".join(args) + def format_args(self, *args: t.Optional[str | exp.Expression]) -> str: + arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None) + if self.pretty and self.text_width(arg_sqls) > self._max_text_width: + return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True) + return ", ".join(arg_sqls) - def text_width(self, args): + def text_width(self, args: t.Iterable) -> int: return sum(len(arg) for arg in args) - def format_time(self, expression): + def format_time(self, expression: exp.Expression) -> t.Optional[str]: return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie) - def expressions(self, expression, key=None, flat=False, indent=True, sep=", "): + def expressions( + self, + expression: exp.Expression, + key: t.Optional[str] = None, + flat: bool = False, + indent: bool = True, + sep: str = ", ", + ) -> str: expressions = 
expression.args.get(key or "expressions") if not expressions: @@ -1359,45 +1395,67 @@ class Generator: else: result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}") - result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) - return self.indent(result_sqls, skip_first=False) if indent else result_sqls + result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) + return self.indent(result_sql, skip_first=False) if indent else result_sql - def op_expressions(self, op, expression, flat=False): + def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str: expressions_sql = self.expressions(expression, flat=flat) if flat: return f"{op} {expressions_sql}" return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}" - def naked_property(self, expression): + def naked_property(self, expression: exp.Property) -> str: property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__) if not property_name: self.unsupported(f"Unsupported property {expression.__class__.__name__}") return f"{property_name} {self.sql(expression, 'this')}" - def set_operation(self, expression, op): + def set_operation(self, expression: exp.Expression, op: str) -> str: this = self.sql(expression, "this") op = self.seg(op) return self.query_modifiers( expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" ) - def token_sql(self, token_type): + def token_sql(self, token_type: TokenType) -> str: return self.TOKEN_MAPPING.get(token_type, token_type.name) - def userdefinedfunction_sql(self, expression): + def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: this = self.sql(expression, "this") expressions = self.no_identify(lambda: self.expressions(expression)) return f"{this}({expressions})" - def userdefinedfunctionkwarg_sql(self, expression): + def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str: this = self.sql(expression, "this") kind = self.sql(expression, "kind") return f"{this} {kind}" - def joinhint_sql(self, expression): + def joinhint_sql(self, expression: exp.JoinHint) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) return f"{this}({expressions})" - def kwarg_sql(self, expression): + def kwarg_sql(self, expression: exp.Kwarg) -> str: return self.binary(expression, "=>") + + def when_sql(self, expression: exp.When) -> str: + this = self.sql(expression, "this") + then_expression = expression.args.get("then") + if isinstance(then_expression, exp.Insert): + then = f"INSERT {self.sql(then_expression, 'this')}" + if "expression" in then_expression.args: + then += f" VALUES {self.sql(then_expression, 'expression')}" + elif isinstance(then_expression, exp.Update): + if isinstance(then_expression.args.get("expressions"), exp.Star): + then = f"UPDATE {self.sql(then_expression, 'expressions')}" + else: + then = f"UPDATE SET {self.expressions(then_expression, flat=True)}" + else: + then = self.sql(then_expression) + return f"WHEN {this} THEN {then}" + + def merge_sql(self, expression: exp.Merge) -> str: + this = self.sql(expression, "this") + using = f"USING {self.sql(expression, 'using')}" + on = f"ON {self.sql(expression, 'on')}" + return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}" diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 8c5808d..ed37e6c 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -385,3 +385,11 @@ def dict_depth(d: t.Dict) -> 
int: except StopIteration: # d.values() returns an empty sequence return 1 + + +def first(it: t.Iterable[T]) -> T: + """Returns the first element from an iterable. + + Useful for sets. + """ + return next(i for i in it) diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 191ea52..be17f15 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -14,7 +14,7 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): >>> schema = {"y": {"cola": "SMALLINT"}} >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) - >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola" + >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" Args: @@ -41,9 +41,12 @@ class TypeAnnotator: expr_type: lambda self, expr: self._annotate_binary(expr) for expr_type in subclasses(exp.__name__, exp.Binary) }, - exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this), - exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this), + exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), + exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), + exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr), exp.Alias: lambda self, expr: self._annotate_unary(expr), + exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Literal: lambda self, expr: self._annotate_literal(expr), exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), @@ -52,6 +55,9 @@ class TypeAnnotator: expr, exp.DataType.Type.BIGINT ), exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"), + exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"), + exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True), exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), @@ -263,10 +269,10 @@ class TypeAnnotator: } # First annotate the current scope's column references for col in scope.columns: - source = scope.sources[col.table] + source = scope.sources.get(col.table) if isinstance(source, exp.Table): col.type = self.schema.get_column_type(source, col) - else: + elif source: col.type = selects[col.table][col.name].type # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) @@ -280,6 +286,7 @@ class TypeAnnotator: return expression # We've already inferred the expression's type annotator = self.annotators.get(expression.__class__) + return ( annotator(self, expression) if annotator @@ -295,18 +302,23 @@ class TypeAnnotator: def _maybe_coerce(self, type1, type2): # We propagate the NULL / UNKNOWN types upwards if found + if isinstance(type1, exp.DataType): + type1 = type1.this + if isinstance(type2, exp.DataType): + type2 = type2.this + if exp.DataType.Type.NULL in (type1, type2): return 
exp.DataType.Type.NULL if exp.DataType.Type.UNKNOWN in (type1, type2): return exp.DataType.Type.UNKNOWN - return type2 if type2 in self.coerces_to[type1] else type1 + return type2 if type2 in self.coerces_to.get(type1, {}) else type1 def _annotate_binary(self, expression): self._annotate_args(expression) - left_type = expression.left.type - right_type = expression.right.type + left_type = expression.left.type.this + right_type = expression.right.type.this if isinstance(expression, (exp.And, exp.Or)): if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: @@ -348,7 +360,7 @@ class TypeAnnotator: expression.type = target_type return self._annotate_args(expression) - def _annotate_by_args(self, expression, *args): + def _annotate_by_args(self, expression, *args, promote=False): self._annotate_args(expression) expressions = [] for arg in args: @@ -360,4 +372,11 @@ class TypeAnnotator: last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) expression.type = last_datatype or exp.DataType.Type.UNKNOWN + + if promote: + if expression.type.this in exp.DataType.INTEGER_TYPES: + expression.type = exp.DataType.Type.BIGINT + elif expression.type.this in exp.DataType.FLOAT_TYPES: + expression.type = exp.DataType.Type.DOUBLE + return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index 9b3d98a..33529a5 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -13,13 +13,16 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: expression: The expression to canonicalize. """ exp.replace_children(expression, canonicalize) + expression = add_text_to_concat(expression) expression = coerce_type(expression) + expression = remove_redundant_casts(expression) + return expression def add_text_to_concat(node: exp.Expression) -> exp.Expression: - if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES: + if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: node = exp.Concat(this=node.this, expression=node.expression) return node @@ -30,14 +33,30 @@ def coerce_type(node: exp.Expression) -> exp.Expression: elif isinstance(node, exp.Between): _coerce_date(node.this, node.args["low"]) elif isinstance(node, exp.Extract): - if node.expression.type not in exp.DataType.TEMPORAL_TYPES: + if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES: _replace_cast(node.expression, "datetime") return node +def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: + if ( + isinstance(expression, exp.Cast) + and expression.to.type + and expression.this.type + and expression.to.type.this == expression.this.type.this + ): + return expression.this + return expression + + def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: for a, b in itertools.permutations([a, b]): - if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE: + if ( + a.type + and a.type.this == exp.DataType.Type.DATE + and b.type + and b.type.this != exp.DataType.Type.DATE + ): _replace_cast(b, "date") diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index c432c59..c0719f2 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -7,7 +7,7 @@ from decimal import Decimal from sqlglot import exp from sqlglot.expressions import FALSE, NULL, TRUE from sqlglot.generator import Generator -from sqlglot.helper import while_changing +from sqlglot.helper import first, while_changing GENERATOR = 
Generator(normalize=True, identify=True) @@ -30,6 +30,7 @@ def simplify(expression): def _simplify(expression, root=True): node = expression + node = rewrite_between(node) node = uniq_sort(node) node = absorb_and_eliminate(node) exp.replace_children(node, lambda e: _simplify(e, False)) @@ -49,6 +50,19 @@ def simplify(expression): return expression +def rewrite_between(expression: exp.Expression) -> exp.Expression: + """Rewrite x between y and z to x >= y AND x <= z. + + This is done because comparison simplification is only done on lt/lte/gt/gte. + """ + if isinstance(expression, exp.Between): + return exp.and_( + exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), + exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), + ) + return expression + + def simplify_not(expression): """ Demorgan's Law @@ -57,7 +71,7 @@ def simplify_not(expression): """ if isinstance(expression, exp.Not): if isinstance(expression.this, exp.Null): - return NULL + return exp.null() if isinstance(expression.this, exp.Paren): condition = expression.this.unnest() if isinstance(condition, exp.And): @@ -65,11 +79,11 @@ def simplify_not(expression): if isinstance(condition, exp.Or): return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) if isinstance(condition, exp.Null): - return NULL + return exp.null() if always_true(expression.this): - return FALSE + return exp.false() if expression.this == FALSE: - return TRUE + return exp.true() if isinstance(expression.this, exp.Not): # double negation # NOT NOT x -> x @@ -91,40 +105,119 @@ def flatten(expression): def simplify_connectors(expression): - if isinstance(expression, exp.Connector): - left = expression.left - right = expression.right - - if left == right: - return left - - if isinstance(expression, exp.And): - if FALSE in (left, right): - return FALSE - if NULL in (left, right): - return NULL - if always_true(left) and always_true(right): - return TRUE - if always_true(left): - return right - if always_true(right): - return left - elif isinstance(expression, exp.Or): - if always_true(left) or always_true(right): - return TRUE - if left == FALSE and right == FALSE: - return FALSE - if ( - (left == NULL and right == NULL) - or (left == NULL and right == FALSE) - or (left == FALSE and right == NULL) - ): - return NULL - if left == FALSE: - return right - if right == FALSE: + def _simplify_connectors(expression, left, right): + if isinstance(expression, exp.Connector): + if left == right: return left - return expression + if isinstance(expression, exp.And): + if FALSE in (left, right): + return exp.false() + if NULL in (left, right): + return exp.null() + if always_true(left) and always_true(right): + return exp.true() + if always_true(left): + return right + if always_true(right): + return left + return _simplify_comparison(expression, left, right) + elif isinstance(expression, exp.Or): + if always_true(left) or always_true(right): + return exp.true() + if left == FALSE and right == FALSE: + return exp.false() + if ( + (left == NULL and right == NULL) + or (left == NULL and right == FALSE) + or (left == FALSE and right == NULL) + ): + return exp.null() + if left == FALSE: + return right + if right == FALSE: + return left + return _simplify_comparison(expression, left, right, or_=True) + return None + + return _flat_simplify(expression, _simplify_connectors) + + +LT_LTE = (exp.LT, exp.LTE) +GT_GTE = (exp.GT, exp.GTE) + +COMPARISONS = ( + *LT_LTE, + *GT_GTE, + exp.EQ, + exp.NEQ, +) + +INVERSE_COMPARISONS = { + exp.LT: 
exp.GT, + exp.GT: exp.LT, + exp.LTE: exp.GTE, + exp.GTE: exp.LTE, +} + + +def _simplify_comparison(expression, left, right, or_=False): + if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): + ll, lr = left.args.values() + rl, rr = right.args.values() + + largs = {ll, lr} + rargs = {rl, rr} + + matching = largs & rargs + columns = {m for m in matching if isinstance(m, exp.Column)} + + if matching and columns: + try: + l = first(largs - columns) + r = first(rargs - columns) + except StopIteration: + return expression + + # make sure the comparison is always of the form x > 1 instead of 1 < x + if left.__class__ in INVERSE_COMPARISONS and l == ll: + left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll) + if right.__class__ in INVERSE_COMPARISONS and r == rl: + right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl) + + if l.is_number and r.is_number: + l = float(l.name) + r = float(r.name) + elif l.is_string and r.is_string: + l = l.name + r = r.name + else: + return None + + for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): + if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): + return left if (av > bv if or_ else av <= bv) else right + if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): + return left if (av < bv if or_ else av >= bv) else right + + # we can't ever shortcut to true because the column could be null + if isinstance(a, exp.LT) and isinstance(b, GT_GTE): + if not or_ and av <= bv: + return exp.false() + elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): + if not or_ and av >= bv: + return exp.false() + elif isinstance(a, exp.EQ): + if isinstance(b, exp.LT): + return exp.false() if av >= bv else a + if isinstance(b, exp.LTE): + return exp.false() if av > bv else a + if isinstance(b, exp.GT): + return exp.false() if av <= bv else a + if isinstance(b, exp.GTE): + return exp.false() if av < bv else a + if isinstance(b, exp.NEQ): + return exp.false() if av == bv else a + return None def remove_compliments(expression): @@ -135,7 +228,7 @@ def remove_compliments(expression): A OR NOT A -> TRUE """ if isinstance(expression, exp.Connector): - compliment = FALSE if isinstance(expression, exp.And) else TRUE + compliment = exp.false() if isinstance(expression, exp.And) else exp.true() for a, b in itertools.permutations(expression.flatten(), 2): if is_complement(a, b): @@ -211,27 +304,7 @@ def absorb_and_eliminate(expression): def simplify_literals(expression): if isinstance(expression, exp.Binary): - operands = [] - queue = deque(expression.flatten(unnest=False)) - size = len(queue) - - while queue: - a = queue.popleft() - - for b in queue: - result = _simplify_binary(expression, a, b) - - if result: - queue.remove(b) - queue.append(result) - break - else: - operands.append(a) - - if len(operands) < size: - return functools.reduce( - lambda a, b: expression.__class__(this=a, expression=b), operands - ) + return _flat_simplify(expression, _simplify_binary) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: @@ -254,20 +327,13 @@ def _simplify_binary(expression, a, b): if c == NULL: if isinstance(a, exp.Literal): - return TRUE if not_ else FALSE + return exp.true() if not_ else exp.false() if a == NULL: - return FALSE if not_ else TRUE - elif isinstance(expression, exp.NullSafeEQ): - if a == b: - return TRUE - elif isinstance(expression, exp.NullSafeNEQ): - if a == b: - return FALSE + return exp.false() if not_ else exp.true() + elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)): + 
return None elif NULL in (a, b): - return NULL - - if isinstance(expression, exp.EQ) and a == b: - return TRUE + return exp.null() if a.is_number and b.is_number: a = int(a.name) if a.is_int else Decimal(a.name) @@ -388,4 +454,27 @@ def date_literal(date): def boolean_literal(condition): - return TRUE if condition else FALSE + return exp.true() if condition else exp.false() + + +def _flat_simplify(expression, simplifier): + operands = [] + queue = deque(expression.flatten(unnest=False)) + size = len(queue) + + while queue: + a = queue.popleft() + + for b in queue: + result = simplifier(expression, a, b) + + if result: + queue.remove(b) + queue.append(result) + break + else: + operands.append(a) + + if len(operands) < size: + return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) + return expression diff --git a/sqlglot/parser.py b/sqlglot/parser.py index bdf0d2d..55ab453 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -185,6 +185,7 @@ class Parser(metaclass=_Parser): TokenType.LOCAL, TokenType.LOCATION, TokenType.MATERIALIZED, + TokenType.MERGE, TokenType.NATURAL, TokenType.NEXT, TokenType.ONLY, @@ -211,7 +212,6 @@ class Parser(metaclass=_Parser): TokenType.TABLE, TokenType.TABLE_FORMAT, TokenType.TEMPORARY, - TokenType.TRANSIENT, TokenType.TOP, TokenType.TRAILING, TokenType.TRUE, @@ -229,6 +229,8 @@ class Parser(metaclass=_Parser): TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY} + UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET} + TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} FUNC_TOKENS = { @@ -241,6 +243,7 @@ class Parser(metaclass=_Parser): TokenType.FORMAT, TokenType.IDENTIFIER, TokenType.ISNULL, + TokenType.MERGE, TokenType.OFFSET, TokenType.PRIMARY_KEY, TokenType.REPLACE, @@ -407,6 +410,7 @@ class Parser(metaclass=_Parser): TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), TokenType.END: lambda self: self._parse_commit_or_rollback(), TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), + TokenType.MERGE: lambda self: self._parse_merge(), } UNARY_PARSERS = { @@ -474,6 +478,7 @@ class Parser(metaclass=_Parser): TokenType.SORTKEY: lambda self: self._parse_sortkey(), TokenType.LIKE: lambda self: self._parse_create_like(), TokenType.RETURNS: lambda self: self._parse_returns(), + TokenType.ROW: lambda self: self._parse_row(), TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty), @@ -495,6 +500,8 @@ class Parser(metaclass=_Parser): TokenType.VOLATILE: lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") ), + TokenType.WITH: lambda self: self._parse_wrapped_csv(self._parse_property), + TokenType.PROPERTIES: lambda self: self._parse_wrapped_csv(self._parse_property), } CONSTRAINT_PARSERS = { @@ -802,7 +809,8 @@ class Parser(metaclass=_Parser): def _parse_create(self): replace = self._match_pair(TokenType.OR, TokenType.REPLACE) temporary = self._match(TokenType.TEMPORARY) - transient = self._match(TokenType.TRANSIENT) + transient = self._match_text_seq("TRANSIENT") + external = self._match_text_seq("EXTERNAL") unique = self._match(TokenType.UNIQUE) materialized = self._match(TokenType.MATERIALIZED) @@ -846,6 +854,7 @@ class Parser(metaclass=_Parser): properties=properties, temporary=temporary, transient=transient, + 
external=external, replace=replace, unique=unique, materialized=materialized, @@ -861,8 +870,12 @@ class Parser(metaclass=_Parser): if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY): return self._parse_sortkey(compound=True) - if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False): - key = self._parse_var() + assignment = self._match_pair( + TokenType.VAR, TokenType.EQ, advance=False + ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False) + + if assignment: + key = self._parse_var() or self._parse_string() self._match(TokenType.EQ) return self.expression(exp.Property, this=key, value=self._parse_column()) @@ -871,7 +884,10 @@ class Parser(metaclass=_Parser): def _parse_property_assignment(self, exp_class): self._match(TokenType.EQ) self._match(TokenType.ALIAS) - return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number()) + return self.expression( + exp_class, + this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), + ) def _parse_partitioned_by(self): self._match(TokenType.EQ) @@ -881,7 +897,7 @@ class Parser(metaclass=_Parser): ) def _parse_distkey(self): - return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var)) + return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) def _parse_create_like(self): table = self._parse_table(schema=True) @@ -898,7 +914,7 @@ class Parser(metaclass=_Parser): def _parse_sortkey(self, compound=False): return self.expression( - exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound + exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound ) def _parse_character_set(self, default=False): @@ -929,23 +945,11 @@ class Parser(metaclass=_Parser): properties = [] while True: - if self._match(TokenType.WITH): - properties.extend(self._parse_wrapped_csv(self._parse_property)) - elif self._match(TokenType.PROPERTIES): - properties.extend( - self._parse_wrapped_csv( - lambda: self.expression( - exp.Property, - this=self._parse_string(), - value=self._match(TokenType.EQ) and self._parse_string(), - ) - ) - ) - else: - identified_property = self._parse_property() - if not identified_property: - break - properties.append(identified_property) + identified_property = self._parse_property() + if not identified_property: + break + for p in ensure_collection(identified_property): + properties.append(p) if properties: return self.expression(exp.Properties, expressions=properties) @@ -963,7 +967,7 @@ class Parser(metaclass=_Parser): exp.Directory, this=self._parse_var_or_string(), local=local, - row_format=self._parse_row_format(), + row_format=self._parse_row_format(match_row=True), ) else: self._match(TokenType.INTO) @@ -978,10 +982,18 @@ class Parser(metaclass=_Parser): overwrite=overwrite, ) - def _parse_row_format(self): - if not self._match_pair(TokenType.ROW, TokenType.FORMAT): + def _parse_row(self): + if not self._match(TokenType.FORMAT): + return None + return self._parse_row_format() + + def _parse_row_format(self, match_row=False): + if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT): return None + if self._match_text_seq("SERDE"): + return self.expression(exp.RowFormatSerdeProperty, this=self._parse_string()) + self._match_text_seq("DELIMITED") kwargs = {} @@ -998,7 +1010,7 @@ class Parser(metaclass=_Parser): kwargs["lines"] = self._parse_string() if self._match_text_seq("NULL", "DEFINED", "AS"): kwargs["null"] = 
self._parse_string() - return self.expression(exp.RowFormat, **kwargs) + return self.expression(exp.RowFormatDelimitedProperty, **kwargs) def _parse_load_data(self): local = self._match(TokenType.LOCAL) @@ -1032,7 +1044,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.Update, **{ - "this": self._parse_table(schema=True), + "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "from": self._parse_from(), "where": self._parse_where(), @@ -1183,9 +1195,11 @@ class Parser(metaclass=_Parser): alias=alias, ) - def _parse_table_alias(self): + def _parse_table_alias(self, alias_tokens=None): any_token = self._match(TokenType.ALIAS) - alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS) + alias = self._parse_id_var( + any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS + ) columns = None if self._match(TokenType.L_PAREN): @@ -1337,7 +1351,7 @@ class Parser(metaclass=_Parser): columns=self._parse_expression(), ) - def _parse_table(self, schema=False): + def _parse_table(self, schema=False, alias_tokens=None): lateral = self._parse_lateral() if lateral: @@ -1372,7 +1386,7 @@ class Parser(metaclass=_Parser): table = self._parse_id_var() if not table: - self.raise_error("Expected table name") + self.raise_error(f"Expected table name but got {self._curr}") this = self.expression( exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots() @@ -1384,7 +1398,7 @@ class Parser(metaclass=_Parser): if self.alias_post_tablesample: table_sample = self._parse_table_sample() - alias = self._parse_table_alias() + alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) if alias: this.set("alias", alias) @@ -2092,10 +2106,14 @@ class Parser(metaclass=_Parser): kind = self.expression(exp.CheckColumnConstraint, this=constraint) elif self._match(TokenType.COLLATE): kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var()) + elif self._match(TokenType.ENCODE): + kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var()) elif self._match(TokenType.DEFAULT): kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction()) elif self._match_pair(TokenType.NOT, TokenType.NULL): kind = exp.NotNullColumnConstraint() + elif self._match(TokenType.NULL): + kind = exp.NotNullColumnConstraint(allow_null=True) elif self._match(TokenType.SCHEMA_COMMENT): kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string()) elif self._match(TokenType.PRIMARY_KEY): @@ -2234,7 +2252,7 @@ class Parser(metaclass=_Parser): return self._parse_window(this) def _parse_extract(self): - this = self._parse_var() or self._parse_type() + this = self._parse_function() or self._parse_var() or self._parse_type() if self._match(TokenType.FROM): return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) @@ -2635,6 +2653,54 @@ class Parser(metaclass=_Parser): parser = self._find_parser(self.SET_PARSERS, self._set_trie) return parser(self) if parser else self._default_parse_set_item() + def _parse_merge(self): + self._match(TokenType.INTO) + target = self._parse_table(schema=True) + + self._match(TokenType.USING) + using = self._parse_table() + + self._match(TokenType.ON) + on = self._parse_conjunction() + + whens = [] + while self._match(TokenType.WHEN): + this = self._parse_conjunction() + self._match(TokenType.THEN) + + if self._match(TokenType.INSERT): + _this = 
self._parse_star() + if _this: + then = self.expression(exp.Insert, this=_this) + else: + then = self.expression( + exp.Insert, + this=self._parse_value(), + expression=self._match(TokenType.VALUES) and self._parse_value(), + ) + elif self._match(TokenType.UPDATE): + expressions = self._parse_star() + if expressions: + then = self.expression(exp.Update, expressions=expressions) + else: + then = self.expression( + exp.Update, + expressions=self._match(TokenType.SET) + and self._parse_csv(self._parse_equality), + ) + elif self._match(TokenType.DELETE): + then = self.expression(exp.Var, this=self._prev.text) + + whens.append(self.expression(exp.When, this=this, then=then)) + + return self.expression( + exp.Merge, + this=target, + using=using, + on=on, + expressions=whens, + ) + def _parse_set(self): return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) diff --git a/sqlglot/schema.py b/sqlglot/schema.py index f6f303b..8a264a2 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -47,7 +47,7 @@ class Schema(abc.ABC): """ @abc.abstractmethod - def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type: + def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType: """ Get the :class:`sqlglot.exp.DataType` type of a column in the schema. @@ -160,8 +160,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): super().__init__(schema) self.visible = visible or {} self.dialect = dialect - self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = { - "STR": exp.DataType.Type.TEXT, + self._type_mapping_cache: t.Dict[str, exp.DataType] = { + "STR": exp.DataType.build("text"), } @classmethod @@ -231,18 +231,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): visible = self._nested_get(self.table_parts(table_), self.visible) return [col for col in schema if col in visible] # type: ignore - def get_column_type( - self, table: exp.Table | str, column: exp.Column | str - ) -> exp.DataType.Type: + def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType: column_name = column if isinstance(column, str) else column.name table_ = exp.to_table(table) if table_: - table_schema = self.find(table_) - schema_type = table_schema.get(column_name).upper() # type: ignore - return self._convert_type(schema_type) + table_schema = self.find(table_, raise_on_missing=False) + if table_schema: + schema_type = table_schema.get(column_name).upper() # type: ignore + return self._convert_type(schema_type) + return exp.DataType(this=exp.DataType.Type.UNKNOWN) raise SchemaError(f"Could not convert table '{table}'") - def _convert_type(self, schema_type: str) -> exp.DataType.Type: + def _convert_type(self, schema_type: str) -> exp.DataType: """ Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object. 
@@ -257,7 +257,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect) if expression is None: raise ValueError(f"Could not parse {schema_type}") - self._type_mapping_cache[schema_type] = expression.this + self._type_mapping_cache[schema_type] = expression # type: ignore except AttributeError: raise SchemaError(f"Failed to convert type {schema_type}") diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 8a7a38e..b25ef8d 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -49,6 +49,9 @@ class TokenType(AutoName): PARAMETER = auto() SESSION_PARAMETER = auto() + BLOCK_START = auto() + BLOCK_END = auto() + SPACE = auto() BREAK = auto() @@ -156,6 +159,7 @@ class TokenType(AutoName): DIV = auto() DROP = auto() ELSE = auto() + ENCODE = auto() END = auto() ENGINE = auto() ESCAPE = auto() @@ -207,6 +211,7 @@ class TokenType(AutoName): LOCATION = auto() MAP = auto() MATERIALIZED = auto() + MERGE = auto() MOD = auto() NATURAL = auto() NEXT = auto() @@ -255,6 +260,7 @@ class TokenType(AutoName): SELECT = auto() SEMI = auto() SEPARATOR = auto() + SERDE_PROPERTIES = auto() SET = auto() SHOW = auto() SIMILAR_TO = auto() @@ -267,7 +273,6 @@ class TokenType(AutoName): TABLE_FORMAT = auto() TABLE_SAMPLE = auto() TEMPORARY = auto() - TRANSIENT = auto() TOP = auto() THEN = auto() TRAILING = auto() @@ -420,6 +425,16 @@ class Tokenizer(metaclass=_Tokenizer): ESCAPES = ["'"] KEYWORDS = { + **{ + f"{key}{postfix}": TokenType.BLOCK_START + for key in ("{{", "{%", "{#") + for postfix in ("", "+", "-") + }, + **{ + f"{prefix}{key}": TokenType.BLOCK_END + for key in ("}}", "%}", "#}") + for prefix in ("", "+", "-") + }, "/*+": TokenType.HINT, "==": TokenType.EQ, "::": TokenType.DCOLON, @@ -523,6 +538,7 @@ class Tokenizer(metaclass=_Tokenizer): "LOCAL": TokenType.LOCAL, "LOCATION": TokenType.LOCATION, "MATERIALIZED": TokenType.MATERIALIZED, + "MERGE": TokenType.MERGE, "NATURAL": TokenType.NATURAL, "NEXT": TokenType.NEXT, "NO ACTION": TokenType.NO_ACTION, @@ -582,7 +598,6 @@ class Tokenizer(metaclass=_Tokenizer): "TABLESAMPLE": TokenType.TABLE_SAMPLE, "TEMP": TokenType.TEMPORARY, "TEMPORARY": TokenType.TEMPORARY, - "TRANSIENT": TokenType.TRANSIENT, "THEN": TokenType.THEN, "TRUE": TokenType.TRUE, "TRAILING": TokenType.TRAILING, -- cgit v1.2.3
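
A minimal usage sketch (not part of the patch itself), assuming the sqlglot 10.2.6 API as merged above: the hunks add MERGE statement support (tokenizer MERGE token, parser._parse_merge, generator.merge_sql/when_sql) and switch type annotation from bare DataType.Type enums to full exp.DataType nodes. The table and column names below (target, source, id, name) are hypothetical.

# Sketch only: exercises the MERGE support and typed annotations added above.
# Assumes sqlglot 10.2.6; table/column names are illustrative, not from the patch.
import sqlglot
from sqlglot.optimizer.annotate_types import annotate_types

merge_sql = """
MERGE INTO target USING source ON target.id = source.id
WHEN MATCHED THEN UPDATE SET target.name = source.name
WHEN NOT MATCHED THEN INSERT (id, name) VALUES (source.id, source.name)
"""

# _parse_merge builds an exp.Merge with exp.When children; merge_sql/when_sql
# in the generator render the statement back out.
print(sqlglot.parse_one(merge_sql).sql())

# annotate_types now stores full exp.DataType nodes rather than DataType.Type
# enums, so callers read .this or render the type with .sql().
annotated = annotate_types(sqlglot.parse_one("SELECT 1 + 2.5 AS x"))
print(annotated.expressions[0].type.this)   # expected: Type.DOUBLE
print(annotated.expressions[0].type.sql())  # expected: DOUBLE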