diff options
Diffstat (limited to 'sqlglot')
26 files changed, 599 insertions, 124 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 42801ac..be10f3d 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -67,19 +67,22 @@ schema = MappingSchema() """The default schema used by SQLGlot (e.g. in the optimizer).""" -def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expression]]: +def parse( + sql: str, read: DialectType = None, dialect: DialectType = None, **opts +) -> t.List[t.Optional[Expression]]: """ Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement. Args: sql: the SQL code string to parse. read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). + dialect: the SQL dialect (alias for read). **opts: other `sqlglot.parser.Parser` options. Returns: The resulting syntax tree collection. """ - dialect = Dialect.get_or_raise(read)() + dialect = Dialect.get_or_raise(read or dialect)() return dialect.parse(sql, **opts) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 1549a07..4002cfe 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -386,7 +386,7 @@ def input_file_name() -> Column: def isnan(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ISNAN") + return Column.invoke_expression_over_column(col, expression.IsNan) def isnull(col: ColumnOrName) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index fd9965c..df9065f 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -211,6 +211,10 @@ class BigQuery(Dialect): "TZH": "%z", } + # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement + # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table + PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"} + @classmethod def normalize_identifier(cls, expression: E) -> E: # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least). diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 8f60df2..ce1a486 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -380,7 +380,7 @@ class ClickHouse(Dialect): ] def parameterizedagg_sql(self, expression: exp.Anonymous) -> str: - params = self.expressions(expression, "params", flat=True) + params = self.expressions(expression, key="params", flat=True) return self.func(expression.name, *expression.expressions) + f"({params})" def placeholder_sql(self, expression: exp.Placeholder) -> str: diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 8c84639..05e81ce 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -5,6 +5,7 @@ from enum import Enum from sqlglot import exp from sqlglot._typing import E +from sqlglot.errors import ParseError from sqlglot.generator import Generator from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser @@ -168,6 +169,10 @@ class Dialect(metaclass=_Dialect): # special syntax cast(x as date format 'yyyy') defaults to time_mapping FORMAT_MAPPING: t.Dict[str, str] = {} + # Columns that are auto-generated by the engine corresponding to this dialect + # Such columns may be excluded from SELECT * queries, for example + PSEUDOCOLUMNS: t.Set[str] = set() + # Autofilled tokenizer_class = Tokenizer parser_class = Parser @@ -497,6 +502,10 @@ def parse_date_delta_with_interval( return None interval = args[1] + + if not isinstance(interval, exp.Interval): + raise ParseError(f"INTERVAL expression expected but got '{interval}'") + expression = interval.this if expression and expression.is_string: expression = exp.Literal.number(expression.this) @@ -555,11 +564,11 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: - return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" + return self.sql(exp.cast(expression.this, "timestamp")) def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: - return f"CAST({self.sql(expression, 'this')} AS DATE)" + return self.sql(exp.cast(expression.this, "date")) def min_or_least(self: Generator, expression: exp.Min) -> str: @@ -608,8 +617,9 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: _dialect = Dialect.get_or_raise(dialect) time_format = self.format_time(expression) if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): - return f"CAST({str_to_time_sql(self, expression)} AS DATE)" - return f"CAST({self.sql(expression, 'this')} AS DATE)" + return self.sql(exp.cast(str_to_time_sql(self, expression), "date")) + + return self.sql(exp.cast(self.sql(expression, "this"), "date")) return _ts_or_ds_to_date_sql @@ -664,5 +674,15 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp return names +def simplify_literal(expression: E, copy: bool = True) -> E: + if not isinstance(expression.expression, exp.Literal): + from sqlglot.optimizer.simplify import simplify + + expression = exp.maybe_copy(expression, copy) + simplify(expression.expression) + + return expression + + def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index e131434..4e84085 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -359,14 +359,16 @@ class Hive(Dialect): TABLE_HINTS = False QUERY_HINTS = False INDEX_ON = "ON TABLE" + EXTRACT_ALLOWS_QUOTES = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.BIT: "BOOLEAN", exp.DataType.Type.DATETIME: "TIMESTAMP", - exp.DataType.Type.VARBINARY: "BINARY", + exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.TIME: "TIMESTAMP", exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", - exp.DataType.Type.BIT: "BOOLEAN", + exp.DataType.Type.VARBINARY: "BINARY", } TRANSFORMS = { @@ -396,6 +398,7 @@ class Hive(Dialect): exp.FromBase64: rename_func("UNBASE64"), exp.If: if_sql, exp.ILike: no_ilike_sql, + exp.IsNan: rename_func("ISNAN"), exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.JSONFormat: _json_format_sql, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 5d65f77..a54f076 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, parse_date_delta_with_interval, rename_func, + simplify_literal, strposition_to_locate_sql, ) from sqlglot.helper import seq_get @@ -303,6 +304,22 @@ class MySQL(Dialect): "NAMES": lambda self: self._parse_set_item_names(), } + CONSTRAINT_PARSERS = { + **parser.Parser.CONSTRAINT_PARSERS, + "FULLTEXT": lambda self: self._parse_index_constraint(kind="FULLTEXT"), + "INDEX": lambda self: self._parse_index_constraint(), + "KEY": lambda self: self._parse_index_constraint(), + "SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"), + } + + SCHEMA_UNNAMED_CONSTRAINTS = { + *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS, + "FULLTEXT", + "INDEX", + "KEY", + "SPATIAL", + } + PROFILE_TYPES = { "ALL", "BLOCK IO", @@ -327,6 +344,57 @@ class MySQL(Dialect): LOG_DEFAULTS_TO_LN = True + def _parse_index_constraint( + self, kind: t.Optional[str] = None + ) -> exp.IndexColumnConstraint: + if kind: + self._match_texts({"INDEX", "KEY"}) + + this = self._parse_id_var(any_token=False) + type_ = self._match(TokenType.USING) and self._advance_any() and self._prev.text + schema = self._parse_schema() + + options = [] + while True: + if self._match_text_seq("KEY_BLOCK_SIZE"): + self._match(TokenType.EQ) + opt = exp.IndexConstraintOption(key_block_size=self._parse_number()) + elif self._match(TokenType.USING): + opt = exp.IndexConstraintOption(using=self._advance_any() and self._prev.text) + elif self._match_text_seq("WITH", "PARSER"): + opt = exp.IndexConstraintOption(parser=self._parse_var(any_token=True)) + elif self._match(TokenType.COMMENT): + opt = exp.IndexConstraintOption(comment=self._parse_string()) + elif self._match_text_seq("VISIBLE"): + opt = exp.IndexConstraintOption(visible=True) + elif self._match_text_seq("INVISIBLE"): + opt = exp.IndexConstraintOption(visible=False) + elif self._match_text_seq("ENGINE_ATTRIBUTE"): + self._match(TokenType.EQ) + opt = exp.IndexConstraintOption(engine_attr=self._parse_string()) + elif self._match_text_seq("ENGINE_ATTRIBUTE"): + self._match(TokenType.EQ) + opt = exp.IndexConstraintOption(engine_attr=self._parse_string()) + elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"): + self._match(TokenType.EQ) + opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string()) + else: + opt = None + + if not opt: + break + + options.append(opt) + + return self.expression( + exp.IndexColumnConstraint, + this=this, + schema=schema, + kind=kind, + type=type_, + options=options, + ) + def _parse_show_mysql( self, this: str, @@ -454,6 +522,7 @@ class MySQL(Dialect): exp.StrToTime: _str_to_date_sql, exp.TableSample: no_tablesample_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime")), exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, @@ -485,6 +554,16 @@ class MySQL(Dialect): exp.DataType.Type.VARCHAR: "CHAR", } + def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: + # MySQL requires simple literal values for its LIMIT clause. + expression = simplify_literal(expression) + return super().limit_sql(expression, top=top) + + def offset_sql(self, expression: exp.Offset) -> str: + # MySQL requires simple literal values for its OFFSET clause. + expression = simplify_literal(expression) + return super().offset_sql(expression) + def xor_sql(self, expression: exp.Xor) -> str: if expression.expressions: return self.expressions(expression, sep=" XOR ") diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 69da133..1f63e9f 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -30,6 +30,9 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable: class Oracle(Dialect): ALIAS_POST_TABLESAMPLE = True + # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm + RESOLVES_IDENTIFIERS_AS_UPPERCASE = True + # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212 # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes TIME_MAPPING = { diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index d11cbd7..ef100b1 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -17,6 +17,7 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, no_trycast_sql, rename_func, + simplify_literal, str_position_sql, timestamptrunc_sql, timestrtotime_sql, @@ -39,16 +40,13 @@ DATE_DIFF_FACTOR = { def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: - from sqlglot.optimizer.simplify import simplify - this = self.sql(expression, "this") unit = expression.args.get("unit") - expression = simplify(expression.args["expression"]) + expression = simplify_literal(expression.copy(), copy=False).expression if not isinstance(expression, exp.Literal): self.unsupported("Cannot add non literal") - expression = expression.copy() expression.args["is_string"] = True return f"{this} {kind} {self.sql(exp.Interval(this=expression, unit=unit))}" diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 265c6e5..14ec3dd 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -192,6 +192,8 @@ class Presto(Dialect): "START": TokenType.BEGIN, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "ROW": TokenType.STRUCT, + "IPADDRESS": TokenType.IPADDRESS, + "IPPREFIX": TokenType.IPPREFIX, } class Parser(parser.Parser): diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 73f4370..b9aaa66 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing as t from sqlglot import exp +from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.spark2 import Spark2 from sqlglot.helper import seq_get @@ -47,7 +48,11 @@ class Spark(Spark2): exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)", exp.DataType.Type.UNIQUEIDENTIFIER: "STRING", } - TRANSFORMS = Spark2.Generator.TRANSFORMS.copy() + + TRANSFORMS = { + **Spark2.Generator.TRANSFORMS, + exp.StartsWith: rename_func("STARTSWITH"), + } TRANSFORMS.pop(exp.DateDiff) TRANSFORMS.pop(exp.Group) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index dcaa524..ceb48f8 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -19,9 +19,13 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str: kind = e.args["kind"] properties = e.args.get("properties") - if kind.upper() == "TABLE" and any( - isinstance(prop, exp.TemporaryProperty) - for prop in (properties.expressions if properties else []) + if ( + kind.upper() == "TABLE" + and e.expression + and any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ) ): return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" return create_with_partitions_sql(self, e) diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 4e8ffb4..3fac4f5 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -33,8 +33,10 @@ class Teradata(Dialect): **tokens.Tokenizer.KEYWORDS, "^=": TokenType.NEQ, "BYTEINT": TokenType.SMALLINT, + "COLLECT": TokenType.COMMAND, "GE": TokenType.GTE, "GT": TokenType.GT, + "HELP": TokenType.COMMAND, "INS": TokenType.INSERT, "LE": TokenType.LTE, "LT": TokenType.LT, diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 01d5001..0eb0906 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import re import typing as t @@ -10,6 +11,7 @@ from sqlglot.dialects.dialect import ( min_or_least, parse_date_delta, rename_func, + timestrtotime_sql, ) from sqlglot.expressions import DataType from sqlglot.helper import seq_get @@ -52,6 +54,8 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{ # N = Numeric, C=Currency TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} +DEFAULT_START_DATE = datetime.date(1900, 1, 1) + def _format_time_lambda( exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None @@ -166,6 +170,34 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s return f"STRING_AGG({self.format_args(this, separator)}){order}" +def _parse_date_delta( + exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None +) -> t.Callable[[t.List], E]: + def inner_func(args: t.List) -> E: + unit = seq_get(args, 0) + if unit and unit_mapping: + unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) + + start_date = seq_get(args, 1) + if start_date and start_date.is_number: + # Numeric types are valid DATETIME values + if start_date.is_int: + adds = DEFAULT_START_DATE + datetime.timedelta(days=int(start_date.this)) + start_date = exp.Literal.string(adds.strftime("%F")) + else: + # We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs. + # This is not a problem when generating T-SQL code, it is when transpiling to other dialects. + return exp_class(this=seq_get(args, 2), expression=start_date, unit=unit) + + return exp_class( + this=exp.TimeStrToTime(this=seq_get(args, 2)), + expression=exp.TimeStrToTime(this=start_date), + unit=unit, + ) + + return inner_func + + class TSQL(Dialect): RESOLVES_IDENTIFIERS_AS_UPPERCASE = None NULL_ORDERING = "nulls_are_small" @@ -298,7 +330,6 @@ class TSQL(Dialect): "SMALLDATETIME": TokenType.DATETIME, "SMALLMONEY": TokenType.SMALLMONEY, "SQL_VARIANT": TokenType.VARIANT, - "TIME": TokenType.TIMESTAMP, "TOP": TokenType.TOP, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "VARCHAR(MAX)": TokenType.TEXT, @@ -307,10 +338,6 @@ class TSQL(Dialect): "SYSTEM_USER": TokenType.CURRENT_USER, } - # TSQL allows @, # to appear as a variable/identifier prefix - SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy() - SINGLE_TOKENS.pop("#") - class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -320,7 +347,7 @@ class TSQL(Dialect): position=seq_get(args, 2), ), "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), - "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), + "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATEPART": _format_time_lambda(exp.TimeToStr), "EOMONTH": _parse_eomonth, @@ -518,6 +545,36 @@ class TSQL(Dialect): expressions = self._parse_csv(self._parse_function_parameter) return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) + def _parse_id_var( + self, + any_token: bool = True, + tokens: t.Optional[t.Collection[TokenType]] = None, + ) -> t.Optional[exp.Expression]: + is_temporary = self._match(TokenType.HASH) + is_global = is_temporary and self._match(TokenType.HASH) + + this = super()._parse_id_var(any_token=any_token, tokens=tokens) + if this: + if is_global: + this.set("global", True) + elif is_temporary: + this.set("temporary", True) + + return this + + def _parse_create(self) -> exp.Create | exp.Command: + create = super()._parse_create() + + if isinstance(create, exp.Create): + table = create.this.this if isinstance(create.this, exp.Schema) else create.this + if isinstance(table, exp.Table) and table.this.args.get("temporary"): + if not create.args.get("properties"): + create.set("properties", exp.Properties(expressions=[])) + + create.args["properties"].append("expressions", exp.TemporaryProperty()) + + return create + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True LIMIT_IS_TOP = True @@ -526,9 +583,11 @@ class TSQL(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DATETIME: "DATETIME2", + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.TIMESTAMP: "DATETIME2", + exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET", exp.DataType.Type.VARIANT: "SQL_VARIANT", } @@ -552,6 +611,8 @@ class TSQL(Dialect): exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this, ), + exp.TemporaryProperty: lambda self, e: "", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: _format_sql, } @@ -564,6 +625,22 @@ class TSQL(Dialect): LIMIT_FETCH = "FETCH" + def createable_sql( + self, + expression: exp.Create, + locations: dict[exp.Properties.Location, list[exp.Property]], + ) -> str: + sql = self.sql(expression, "this") + properties = expression.args.get("properties") + + if sql[:1] != "#" and any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ): + sql = f"#{sql}" + + return sql + def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" @@ -616,3 +693,13 @@ class TSQL(Dialect): this = self.sql(expression, "this") this = f" {this}" if this else "" return f"ROLLBACK TRANSACTION{this}" + + def identifier_sql(self, expression: exp.Identifier) -> str: + identifier = super().identifier_sql(expression) + + if expression.args.get("global"): + identifier = f"##{identifier}" + elif expression.args.get("temporary"): + identifier = f"#{identifier}" + + return identifier diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 9a6b440..f8e9fee 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -67,8 +67,9 @@ class Expression(metaclass=_Expression): uses to refer to it. comments: a list of comments that are associated with a given expression. This is used in order to preserve comments when transpiling SQL code. - _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the + type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the optimizer, in order to enable some transformations that require type information. + meta: a dictionary that can be used to store useful metadata for a given expression. Example: >>> class Foo(Expression): @@ -767,7 +768,7 @@ class Condition(Expression): **opts, ) -> In: return In( - this=_maybe_copy(self, copy), + this=maybe_copy(self, copy), expressions=[convert(e, copy=copy) for e in expressions], query=maybe_parse(query, copy=copy, **opts) if query else None, unnest=Unnest( @@ -781,7 +782,7 @@ class Condition(Expression): def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between: return Between( - this=_maybe_copy(self, copy), + this=maybe_copy(self, copy), low=convert(low, copy=copy, **opts), high=convert(high, copy=copy, **opts), ) @@ -990,7 +991,28 @@ class Uncache(Expression): arg_types = {"this": True, "exists": False} -class Create(Expression): +class DDL(Expression): + @property + def ctes(self): + with_ = self.args.get("with") + if not with_: + return [] + return with_.expressions + + @property + def named_selects(self) -> t.List[str]: + if isinstance(self.expression, Subqueryable): + return self.expression.named_selects + return [] + + @property + def selects(self) -> t.List[Expression]: + if isinstance(self.expression, Subqueryable): + return self.expression.selects + return [] + + +class Create(DDL): arg_types = { "with": False, "this": True, @@ -1206,6 +1228,19 @@ class MergeTreeTTL(Expression): } +# https://dev.mysql.com/doc/refman/8.0/en/create-table.html +class IndexConstraintOption(Expression): + arg_types = { + "key_block_size": False, + "using": False, + "parser": False, + "comment": False, + "visible": False, + "engine_attr": False, + "secondary_engine_attr": False, + } + + class ColumnConstraint(Expression): arg_types = {"this": False, "kind": True} @@ -1272,6 +1307,11 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): } +# https://dev.mysql.com/doc/refman/8.0/en/create-table.html +class IndexColumnConstraint(ColumnConstraintKind): + arg_types = {"this": False, "schema": True, "kind": False, "type": False, "options": False} + + class InlineLengthColumnConstraint(ColumnConstraintKind): pass @@ -1496,7 +1536,7 @@ class JoinHint(Expression): class Identifier(Expression): - arg_types = {"this": True, "quoted": False} + arg_types = {"this": True, "quoted": False, "global": False, "temporary": False} @property def quoted(self) -> bool: @@ -1525,7 +1565,7 @@ class Index(Expression): } -class Insert(Expression): +class Insert(DDL): arg_types = { "with": False, "this": True, @@ -1892,6 +1932,10 @@ class EngineProperty(Property): arg_types = {"this": True} +class HeapProperty(Property): + arg_types = {} + + class ToTableProperty(Property): arg_types = {"this": True} @@ -2182,7 +2226,7 @@ class Tuple(Expression): **opts, ) -> In: return In( - this=_maybe_copy(self, copy), + this=maybe_copy(self, copy), expressions=[convert(e, copy=copy) for e in expressions], query=maybe_parse(query, copy=copy, **opts) if query else None, unnest=Unnest( @@ -2212,7 +2256,7 @@ class Subqueryable(Unionable): Returns: Alias: the subquery """ - instance = _maybe_copy(self, copy) + instance = maybe_copy(self, copy) if not isinstance(alias, Expression): alias = TableAlias(this=to_identifier(alias)) if alias else None @@ -2865,7 +2909,7 @@ class Select(Subqueryable): self, expression: ExpOrStr, on: t.Optional[ExpOrStr] = None, - using: t.Optional[ExpOrStr | t.List[ExpOrStr]] = None, + using: t.Optional[ExpOrStr | t.Collection[ExpOrStr]] = None, append: bool = True, join_type: t.Optional[str] = None, join_alias: t.Optional[Identifier | str] = None, @@ -2943,6 +2987,7 @@ class Select(Subqueryable): arg="using", append=append, copy=copy, + into=Identifier, **opts, ) @@ -3092,7 +3137,7 @@ class Select(Subqueryable): Returns: Select: the modified expression. """ - instance = _maybe_copy(self, copy) + instance = maybe_copy(self, copy) on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons if on]) if ons else None instance.set("distinct", Distinct(on=on) if distinct else None) return instance @@ -3123,7 +3168,7 @@ class Select(Subqueryable): Returns: The new Create expression. """ - instance = _maybe_copy(self, copy) + instance = maybe_copy(self, copy) table_expression = maybe_parse( table, into=Table, @@ -3159,7 +3204,7 @@ class Select(Subqueryable): Returns: The modified expression. """ - inst = _maybe_copy(self, copy) + inst = maybe_copy(self, copy) inst.set("locks", [Lock(update=update)]) return inst @@ -3181,7 +3226,7 @@ class Select(Subqueryable): Returns: The modified expression. """ - inst = _maybe_copy(self, copy) + inst = maybe_copy(self, copy) inst.set( "hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints]) ) @@ -3376,6 +3421,8 @@ class DataType(Expression): HSTORE = auto() IMAGE = auto() INET = auto() + IPADDRESS = auto() + IPPREFIX = auto() INT = auto() INT128 = auto() INT256 = auto() @@ -3987,7 +4034,7 @@ class Case(Func): arg_types = {"this": False, "ifs": True, "default": False} def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case: - instance = _maybe_copy(self, copy) + instance = maybe_copy(self, copy) instance.append( "ifs", If( @@ -3998,7 +4045,7 @@ class Case(Func): return instance def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case: - instance = _maybe_copy(self, copy) + instance = maybe_copy(self, copy) instance.set("default", maybe_parse(condition, copy=copy, **opts)) return instance @@ -4263,6 +4310,10 @@ class Initcap(Func): arg_types = {"this": True, "expression": False} +class IsNan(Func): + _sql_names = ["IS_NAN", "ISNAN"] + + class JSONKeyValue(Expression): arg_types = {"this": True, "expression": True} @@ -4549,6 +4600,11 @@ class StandardHash(Func): arg_types = {"this": True, "expression": False} +class StartsWith(Func): + _sql_names = ["STARTS_WITH", "STARTSWITH"] + arg_types = {"this": True, "expression": True} + + class StrPosition(Func): arg_types = { "this": True, @@ -4804,7 +4860,7 @@ def maybe_parse( return sqlglot.parse_one(sql, read=dialect, into=into, **opts) -def _maybe_copy(instance: E, copy: bool = True) -> E: +def maybe_copy(instance: E, copy: bool = True) -> E: return instance.copy() if copy else instance @@ -4824,7 +4880,7 @@ def _apply_builder( ): if _is_wrong_expression(expression, into): expression = into(this=expression) - instance = _maybe_copy(instance, copy) + instance = maybe_copy(instance, copy) expression = maybe_parse( sql_or_expression=expression, prefix=prefix, @@ -4848,7 +4904,7 @@ def _apply_child_list_builder( properties=None, **opts, ): - instance = _maybe_copy(instance, copy) + instance = maybe_copy(instance, copy) parsed = [] for expression in expressions: if expression is not None: @@ -4887,7 +4943,7 @@ def _apply_list_builder( dialect=None, **opts, ): - inst = _maybe_copy(instance, copy) + inst = maybe_copy(instance, copy) expressions = [ maybe_parse( @@ -4923,7 +4979,7 @@ def _apply_conjunction_builder( if not expressions: return instance - inst = _maybe_copy(instance, copy) + inst = maybe_copy(instance, copy) existing = inst.args.get(arg) if append and existing is not None: @@ -5398,7 +5454,7 @@ def to_identifier(name, quoted=None, copy=True): return None if isinstance(name, Identifier): - identifier = _maybe_copy(name, copy) + identifier = maybe_copy(name, copy) elif isinstance(name, str): identifier = Identifier( this=name, @@ -5735,7 +5791,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression: Expression: the equivalent expression object. """ if isinstance(value, Expression): - return _maybe_copy(value, copy) + return maybe_copy(value, copy) if isinstance(value, str): return Literal.string(value) if isinstance(value, bool): diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 40ba88e..ed0a681 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -68,6 +68,7 @@ class Generator: exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}", exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), exp.ExternalProperty: lambda self, e: "EXTERNAL", + exp.HeapProperty: lambda self, e: "HEAP", exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), @@ -161,6 +162,9 @@ class Generator: # Whether or not to generate the (+) suffix for columns used in old-style join conditions COLUMN_JOIN_MARKS_SUPPORTED = False + # Whether or not to generate an unquoted value for EXTRACT's date part argument + EXTRACT_ALLOWS_QUOTES = True + # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE") @@ -224,6 +228,7 @@ class Generator: exp.FallbackProperty: exp.Properties.Location.POST_NAME, exp.FileFormatProperty: exp.Properties.Location.POST_WITH, exp.FreespaceProperty: exp.Properties.Location.POST_NAME, + exp.HeapProperty: exp.Properties.Location.POST_WITH, exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME, exp.JournalProperty: exp.Properties.Location.POST_NAME, exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA, @@ -265,9 +270,12 @@ class Generator: # Expressions whose comments are separated from them for better formatting WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( + exp.Delete, exp.Drop, exp.From, + exp.Insert, exp.Select, + exp.Update, exp.Where, exp.With, ) @@ -985,8 +993,9 @@ class Generator: ) -> str: if properties.expressions: expressions = self.expressions(properties, sep=sep, indent=False) - expressions = self.wrap(expressions) if wrapped else expressions - return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}" + if expressions: + expressions = self.wrap(expressions) if wrapped else expressions + return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}" return "" def with_properties(self, properties: exp.Properties) -> str: @@ -1905,7 +1914,7 @@ class Generator: return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}" def extract_sql(self, expression: exp.Extract) -> str: - this = self.sql(expression, "this") + this = self.sql(expression, "this") if self.EXTRACT_ALLOWS_QUOTES else expression.this.name expression_sql = self.sql(expression, "expression") return f"EXTRACT({this} FROM {expression_sql})" @@ -2370,7 +2379,12 @@ class Generator: elif arg_value is not None: args.append(arg_value) - return self.func(expression.sql_name(), *args) + if self.normalize_functions: + name = expression.sql_name() + else: + name = (expression._meta and expression.meta.get("name")) or expression.sql_name() + + return self.func(name, *args) def func( self, @@ -2412,7 +2426,7 @@ class Generator: return "" if flat: - return sep.join(self.sql(e) for e in expressions) + return sep.join(sql for sql in (self.sql(e) for e in expressions) if sql) num_sqls = len(expressions) @@ -2423,6 +2437,9 @@ class Generator: result_sqls = [] for i, e in enumerate(expressions): sql = self.sql(e, comment=False) + if not sql: + continue + comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else "" if self.pretty: @@ -2562,6 +2579,51 @@ class Generator: record_reader = f" RECORDREADER {record_reader}" if record_reader else "" return f"{transform}{row_format_before}{record_writer}{using}{schema}{row_format_after}{record_reader}" + def indexconstraintoption_sql(self, expression: exp.IndexConstraintOption) -> str: + key_block_size = self.sql(expression, "key_block_size") + if key_block_size: + return f"KEY_BLOCK_SIZE = {key_block_size}" + + using = self.sql(expression, "using") + if using: + return f"USING {using}" + + parser = self.sql(expression, "parser") + if parser: + return f"WITH PARSER {parser}" + + comment = self.sql(expression, "comment") + if comment: + return f"COMMENT {comment}" + + visible = expression.args.get("visible") + if visible is not None: + return "VISIBLE" if visible else "INVISIBLE" + + engine_attr = self.sql(expression, "engine_attr") + if engine_attr: + return f"ENGINE_ATTRIBUTE = {engine_attr}" + + secondary_engine_attr = self.sql(expression, "secondary_engine_attr") + if secondary_engine_attr: + return f"SECONDARY_ENGINE_ATTRIBUTE = {secondary_engine_attr}" + + self.unsupported("Unsupported index constraint option.") + return "" + + def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str: + kind = self.sql(expression, "kind") + kind = f"{kind} INDEX" if kind else "INDEX" + this = self.sql(expression, "this") + this = f" {this}" if this else "" + type_ = self.sql(expression, "type") + type_ = f" USING {type_}" if type_ else "" + schema = self.sql(expression, "schema") + schema = f" {schema}" if schema else "" + options = self.expressions(expression, key="options", sep=" ") + options = f" {options}" if options else "" + return f"{kind}{this}{type_}{schema}{options}" + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 728493d..af42f25 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -136,8 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken): def _eliminate_derived_table(scope, existing_ctes, taken): - # This ensures we don't drop the "pivot" arg from a pivoted subquery - if scope.parent.pivots: + # This makes sure that we don't: + # - drop the "pivot" arg from a pivoted subquery + # - eliminate a lateral correlated subquery + if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral): return None parent = scope.expression.parent diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index 99e605d..9d4860e 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -1,8 +1,23 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType +@t.overload def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: + ... + + +@t.overload +def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression: + ... + + +def normalize_identifiers(expression, dialect=None): """ Normalize all unquoted identifiers to either lower or upper case, depending on the dialect. This essentially makes those identifiers case-insensitive. @@ -16,6 +31,8 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar') >>> normalize_identifiers(expression).sql() 'SELECT bar.a AS a FROM "Foo".bar' + >>> normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake") + 'FOO' Args: expression: The expression to transform. @@ -24,4 +41,5 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: Returns: The transformed expression. """ + expression = exp.maybe_parse(expression, dialect=dialect) return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 2657188..9c34cef 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -39,6 +39,7 @@ def qualify_columns( """ schema = ensure_schema(schema) infer_schema = schema.empty if infer_schema is None else infer_schema + pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS for scope in traverse_scope(expression): resolver = Resolver(scope, schema, infer_schema=infer_schema) @@ -55,7 +56,7 @@ def qualify_columns( _expand_alias_refs(scope, resolver) if not isinstance(scope.expression, exp.UDTF): - _expand_stars(scope, resolver, using_column_tables) + _expand_stars(scope, resolver, using_column_tables, pseudocolumns) _qualify_outputs(scope) _expand_group_by(scope) _expand_order_by(scope, resolver) @@ -326,7 +327,10 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None: def _expand_stars( - scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any] + scope: Scope, + resolver: Resolver, + using_column_tables: t.Dict[str, t.Any], + pseudocolumns: t.Set[str], ) -> None: """Expand stars to lists of column selections""" @@ -367,14 +371,8 @@ def _expand_stars( columns = resolver.get_source_columns(table, only_visible=True) - # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement - # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table - if resolver.schema.dialect == "bigquery": - columns = [ - name - for name in columns - if name.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE") - ] + if pseudocolumns: + columns = [name for name in columns if name.upper() not in pseudocolumns] if columns and "*" not in columns: if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 31c9cc0..68aebdb 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -80,7 +80,9 @@ def qualify_tables( header = next(reader) columns = next(reader) schema.add_table( - source, {k: type(v).__name__ for k, v in zip(header, columns)} + source, + {k: type(v).__name__ for k, v in zip(header, columns)}, + match_depth=False, ) elif isinstance(source, Scope) and source.is_udtf: udtf = source.expression diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index a7dab35..fb12384 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -435,7 +435,10 @@ class Scope: @property def is_correlated_subquery(self): """Determine if this scope is a correlated subquery""" - return bool(self.is_subquery and self.external_columns) + return bool( + (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral))) + and self.external_columns + ) def rename_source(self, old_name, new_name): """Rename a source in this scope""" @@ -486,7 +489,7 @@ class Scope: def traverse_scope(expression: exp.Expression) -> t.List[Scope]: """ - Traverse an expression by it's "scopes". + Traverse an expression by its "scopes". "Scope" represents the current context of a Select statement. @@ -509,9 +512,12 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]: Returns: list[Scope]: scope instances """ - if not isinstance(expression, exp.Unionable): - return [] - return list(_traverse_scope(Scope(expression))) + if isinstance(expression, exp.Unionable) or ( + isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable) + ): + return list(_traverse_scope(Scope(expression))) + + return [] def build_scope(expression: exp.Expression) -> t.Optional[Scope]: @@ -539,7 +545,9 @@ def _traverse_scope(scope): elif isinstance(scope.expression, exp.Table): yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): - pass + yield from _traverse_udtfs(scope) + elif isinstance(scope.expression, exp.DDL): + yield from _traverse_ddl(scope) else: logger.warning( "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) @@ -576,10 +584,10 @@ def _traverse_ctes(scope): for cte in scope.ctes: recursive_scope = None - # if the scope is a recursive cte, it must be in the form of - # base_case UNION recursive. thus the recursive scope is the first - # section of the union. - if scope.expression.args["with"].recursive: + # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. + # thus the recursive scope is the first section of the union. + with_ = scope.expression.args.get("with") + if with_ and with_.recursive: union = cte.this if isinstance(union, exp.Union): @@ -692,8 +700,7 @@ def _traverse_tables(scope): # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # Until then, this means that only a single, unaliased derived table is allowed (rather, # the latest one wins. - alias = expression.alias - sources[alias] = child_scope + sources[expression.alias] = child_scope # append the final child_scope yielded scopes.append(child_scope) @@ -711,6 +718,47 @@ def _traverse_subqueries(scope): scope.subquery_scopes.append(top) +def _traverse_udtfs(scope): + if isinstance(scope.expression, exp.Unnest): + expressions = scope.expression.expressions + elif isinstance(scope.expression, exp.Lateral): + expressions = [scope.expression.this] + else: + expressions = [] + + sources = {} + for expression in expressions: + if isinstance(expression, exp.Subquery) and _is_derived_table(expression): + top = None + for child_scope in _traverse_scope( + scope.branch( + expression, + scope_type=ScopeType.DERIVED_TABLE, + outer_column_list=expression.alias_column_names, + ) + ): + yield child_scope + top = child_scope + sources[expression.alias] = child_scope + + scope.derived_table_scopes.append(top) + scope.table_scopes.append(top) + + scope.sources.update(sources) + + +def _traverse_ddl(scope): + yield from _traverse_ctes(scope) + + query_scope = scope.branch( + scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, chain_sources=scope.sources + ) + query_scope._collect() + query_scope._ctes = scope.ctes + query_scope._ctes + + yield from _traverse_scope(query_scope) + + def walk_in_scope(expression, bfs=True): """ Returns a generator object which visits all nodes in the syntrax tree, stopping at diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 09e3f2a..816f5fb 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -46,20 +46,24 @@ def unnest(select, parent_select, next_alias_name): if not predicate or parent_select is not predicate.parent_select: return - # this subquery returns a scalar and can just be converted to a cross join + # This subquery returns a scalar and can just be converted to a cross join if not isinstance(predicate, (exp.In, exp.Any)): - having = predicate.find_ancestor(exp.Having) column = exp.column(select.selects[0].alias_or_name, alias) - if having and having.parent_select is parent_select: + + clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join) + clause_parent_select = clause.parent_select if clause else None + + if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or ( + (not clause or clause_parent_select is not parent_select) + and ( + parent_select.args.get("group") + or any(projection.find(exp.AggFunc) for projection in parent_select.selects) + ) + ): column = exp.Max(this=column) - _replace(select.parent, column) - parent_select.join( - select, - join_type="CROSS", - join_alias=alias, - copy=False, - ) + _replace(select.parent, column) + parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False) return if select.find(exp.Limit, exp.Offset): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 5adec77..f714c8d 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -185,6 +185,8 @@ class Parser(metaclass=_Parser): TokenType.VARIANT, TokenType.OBJECT, TokenType.INET, + TokenType.IPADDRESS, + TokenType.IPPREFIX, TokenType.ENUM, *NESTED_TYPE_TOKENS, } @@ -603,6 +605,7 @@ class Parser(metaclass=_Parser): "FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs), "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty), "FREESPACE": lambda self: self._parse_freespace(), + "HEAP": lambda self: self.expression(exp.HeapProperty), "IMMUTABLE": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") ), @@ -832,6 +835,7 @@ class Parser(metaclass=_Parser): UNNEST_COLUMN_ONLY: bool = False ALIAS_POST_TABLESAMPLE: bool = False STRICT_STRING_CONCAT = False + NORMALIZE_FUNCTIONS = "upper" NULL_ORDERING: str = "nulls_are_small" SHOW_TRIE: t.Dict = {} SET_TRIE: t.Dict = {} @@ -1187,7 +1191,7 @@ class Parser(metaclass=_Parser): exists = self._parse_exists(not_=True) this = None - expression = None + expression: t.Optional[exp.Expression] = None indexes = None no_schema_binding = None begin = None @@ -1207,12 +1211,16 @@ class Parser(metaclass=_Parser): extend_props(self._parse_properties()) self._match(TokenType.ALIAS) - begin = self._match(TokenType.BEGIN) - return_ = self._match_text_seq("RETURN") - expression = self._parse_statement() - if return_: - expression = self.expression(exp.Return, this=expression) + if self._match(TokenType.COMMAND): + expression = self._parse_as_command(self._prev) + else: + begin = self._match(TokenType.BEGIN) + return_ = self._match_text_seq("RETURN") + expression = self._parse_statement() + + if return_: + expression = self.expression(exp.Return, this=expression) elif create_token.token_type == TokenType.INDEX: this = self._parse_index(index=self._parse_id_var()) elif create_token.token_type in self.DB_CREATABLES: @@ -1692,6 +1700,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Describe, this=this, kind=kind) def _parse_insert(self) -> exp.Insert: + comments = ensure_list(self._prev_comments) overwrite = self._match(TokenType.OVERWRITE) ignore = self._match(TokenType.IGNORE) local = self._match_text_seq("LOCAL") @@ -1709,6 +1718,7 @@ class Parser(metaclass=_Parser): alternative = self._match_texts(self.INSERT_ALTERNATIVES) and self._prev.text self._match(TokenType.INTO) + comments += ensure_list(self._prev_comments) self._match(TokenType.TABLE) this = self._parse_table(schema=True) @@ -1716,6 +1726,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.Insert, + comments=comments, this=this, exists=self._parse_exists(), partition=self._parse_partition(), @@ -1840,6 +1851,7 @@ class Parser(metaclass=_Parser): # This handles MySQL's "Multiple-Table Syntax" # https://dev.mysql.com/doc/refman/8.0/en/delete.html tables = None + comments = self._prev_comments if not self._match(TokenType.FROM, advance=False): tables = self._parse_csv(self._parse_table) or None @@ -1847,6 +1859,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.Delete, + comments=comments, tables=tables, this=self._match(TokenType.FROM) and self._parse_table(joins=True), using=self._match(TokenType.USING) and self._parse_table(joins=True), @@ -1856,11 +1869,13 @@ class Parser(metaclass=_Parser): ) def _parse_update(self) -> exp.Update: + comments = self._prev_comments this = self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS) expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality) returning = self._parse_returning() return self.expression( exp.Update, + comments=comments, **{ # type: ignore "this": this, "expressions": expressions, @@ -2235,7 +2250,12 @@ class Parser(metaclass=_Parser): return None if not this: - this = self._parse_function() or self._parse_id_var(any_token=False) + this = ( + self._parse_unnest() + or self._parse_function() + or self._parse_id_var(any_token=False) + ) + while self._match(TokenType.DOT): this = exp.Dot( this=this, @@ -3341,7 +3361,10 @@ class Parser(metaclass=_Parser): args = self._parse_csv(lambda: self._parse_lambda(alias=alias)) if function and not anonymous: - this = self.validate_expression(function(args), args) + func = self.validate_expression(function(args), args) + if not self.NORMALIZE_FUNCTIONS: + func.meta["name"] = this + this = func else: this = self.expression(exp.Anonymous, this=this, expressions=args) @@ -3842,13 +3865,11 @@ class Parser(metaclass=_Parser): args = self._parse_csv(self._parse_conjunction) index = self._index - if not self._match(TokenType.R_PAREN): + if not self._match(TokenType.R_PAREN) and args: # postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]]) - return self.expression( - exp.GroupConcat, - this=seq_get(args, 0), - separator=self._parse_order(this=seq_get(args, 1)), - ) + # bigquery: STRING_AGG([DISTINCT] expression [, separator] [ORDER BY key [{ASC | DESC}] [, ... ]] [LIMIT n]) + args[-1] = self._parse_limit(this=self._parse_order(this=args[-1])) + return self.expression(exp.GroupConcat, this=args[0], separator=seq_get(args, 1)) # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]). # This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that @@ -4172,7 +4193,7 @@ class Parser(metaclass=_Parser): self._match_r_paren() - return self.expression( + window = self.expression( exp.Window, this=this, partition_by=partition, @@ -4183,6 +4204,12 @@ class Parser(metaclass=_Parser): first=first, ) + # This covers Oracle's FIRST/LAST syntax: aggregate KEEP (...) OVER (...) + if self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS, advance=False): + return self._parse_window(window, alias=alias) + + return window + def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]: self._match(TokenType.BETWEEN) @@ -4276,19 +4303,19 @@ class Parser(metaclass=_Parser): def _parse_null(self) -> t.Optional[exp.Expression]: if self._match(TokenType.NULL): return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev) - return None + return self._parse_placeholder() def _parse_boolean(self) -> t.Optional[exp.Expression]: if self._match(TokenType.TRUE): return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev) if self._match(TokenType.FALSE): return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev) - return None + return self._parse_placeholder() def _parse_star(self) -> t.Optional[exp.Expression]: if self._match(TokenType.STAR): return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev) - return None + return self._parse_placeholder() def _parse_parameter(self) -> exp.Parameter: wrapped = self._match(TokenType.L_BRACE) diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 12cf0b1..7a3c88b 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -31,14 +31,19 @@ class Schema(abc.ABC): table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None, dialect: DialectType = None, + normalize: t.Optional[bool] = None, + match_depth: bool = True, ) -> None: """ Register or update a table. Some implementing classes may require column information to also be provided. + The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. Args: table: the `Table` expression instance or string representing the table. column_mapping: a column mapping that describes the structure of the table. dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. + match_depth: whether to enforce that the table must match the schema's depth or not. """ @abc.abstractmethod @@ -47,6 +52,7 @@ class Schema(abc.ABC): table: exp.Table | str, only_visible: bool = False, dialect: DialectType = None, + normalize: t.Optional[bool] = None, ) -> t.List[str]: """ Get the column names for a table. @@ -55,6 +61,7 @@ class Schema(abc.ABC): table: the `Table` expression instance. only_visible: whether to include invisible columns. dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. Returns: The list of column names. @@ -66,6 +73,7 @@ class Schema(abc.ABC): table: exp.Table | str, column: exp.Column, dialect: DialectType = None, + normalize: t.Optional[bool] = None, ) -> exp.DataType: """ Get the `sqlglot.exp.DataType` type of a column in the schema. @@ -74,6 +82,7 @@ class Schema(abc.ABC): table: the source table. column: the target column. dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. Returns: The resulting column type. @@ -99,7 +108,7 @@ class AbstractMappingSchema(t.Generic[T]): ) -> None: self.mapping = mapping or {} self.mapping_trie = new_trie( - tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth()) + tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) ) self._supported_table_args: t.Tuple[str, ...] = tuple() @@ -107,13 +116,13 @@ class AbstractMappingSchema(t.Generic[T]): def empty(self) -> bool: return not self.mapping - def _depth(self) -> int: + def depth(self) -> int: return dict_depth(self.mapping) @property def supported_table_args(self) -> t.Tuple[str, ...]: if not self._supported_table_args and self.mapping: - depth = self._depth() + depth = self.depth() if not depth: # None self._supported_table_args = tuple() @@ -191,6 +200,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): self.visible = visible or {} self.normalize = normalize self._type_mapping_cache: t.Dict[str, exp.DataType] = {} + self._depth = 0 super().__init__(self._normalize(schema or {})) @@ -200,6 +210,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): schema=mapping_schema.mapping, visible=mapping_schema.visible, dialect=mapping_schema.dialect, + normalize=mapping_schema.normalize, ) def copy(self, **kwargs) -> MappingSchema: @@ -208,6 +219,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): "schema": self.mapping.copy(), "visible": self.visible.copy(), "dialect": self.dialect, + "normalize": self.normalize, **kwargs, } ) @@ -217,19 +229,30 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None, dialect: DialectType = None, + normalize: t.Optional[bool] = None, + match_depth: bool = True, ) -> None: """ Register or update a table. Updates are only performed if a new column mapping is provided. + The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. Args: table: the `Table` expression instance or string representing the table. column_mapping: a column mapping that describes the structure of the table. dialect: the SQL dialect that will be used to parse `table` if it's a string. + normalize: whether to normalize identifiers according to the dialect of interest. + match_depth: whether to enforce that the table must match the schema's depth or not. """ - normalized_table = self._normalize_table(table, dialect=dialect) + normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) + + if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): + raise SchemaError( + f"Table {normalized_table.sql(dialect=self.dialect)} must match the " + f"schema's nesting level: {self.depth()}." + ) normalized_column_mapping = { - self._normalize_name(key, dialect=dialect): value + self._normalize_name(key, dialect=dialect, normalize=normalize): value for key, value in ensure_column_mapping(column_mapping).items() } @@ -247,8 +270,9 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): table: exp.Table | str, only_visible: bool = False, dialect: DialectType = None, + normalize: t.Optional[bool] = None, ) -> t.List[str]: - normalized_table = self._normalize_table(table, dialect=dialect) + normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) schema = self.find(normalized_table) if schema is None: @@ -265,11 +289,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): table: exp.Table | str, column: exp.Column, dialect: DialectType = None, + normalize: t.Optional[bool] = None, ) -> exp.DataType: - normalized_table = self._normalize_table(table, dialect=dialect) + normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) normalized_column_name = self._normalize_name( - column if isinstance(column, str) else column.this, dialect=dialect + column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize ) table_schema = self.find(normalized_table, raise_on_missing=False) @@ -293,12 +318,16 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): Returns: The normalized schema mapping. """ + normalized_mapping: t.Dict = {} flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1) - normalized_mapping: t.Dict = {} for keys in flattened_schema: columns = nested_get(schema, *zip(keys, keys)) - assert columns is not None + + if not isinstance(columns, dict): + raise SchemaError( + f"Table {'.'.join(keys[:-1])} must match the schema's nesting level: {len(flattened_schema[0])}." + ) normalized_keys = [ self._normalize_name(key, dialect=self.dialect, is_table=True) for key in keys @@ -312,7 +341,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): return normalized_mapping - def _normalize_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table: + def _normalize_table( + self, + table: exp.Table | str, + dialect: DialectType = None, + normalize: t.Optional[bool] = None, + ) -> exp.Table: normalized_table = exp.maybe_parse( table, into=exp.Table, dialect=dialect or self.dialect, copy=True ) @@ -322,15 +356,24 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): if isinstance(value, (str, exp.Identifier)): normalized_table.set( arg, - exp.to_identifier(self._normalize_name(value, dialect=dialect, is_table=True)), + exp.to_identifier( + self._normalize_name( + value, dialect=dialect, is_table=True, normalize=normalize + ) + ), ) return normalized_table def _normalize_name( - self, name: str | exp.Identifier, dialect: DialectType = None, is_table: bool = False + self, + name: str | exp.Identifier, + dialect: DialectType = None, + is_table: bool = False, + normalize: t.Optional[bool] = None, ) -> str: dialect = dialect or self.dialect + normalize = self.normalize if normalize is None else normalize try: identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) @@ -338,16 +381,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): return name if isinstance(name, str) else name.name name = identifier.name - if not self.normalize: + if not normalize: return name # This can be useful for normalize_identifier identifier.meta["is_table"] = is_table return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name - def _depth(self) -> int: - # The columns themselves are a mapping, but we don't want to include those - return super()._depth() - 1 + def depth(self) -> int: + if not self.empty and not self._depth: + # The columns themselves are a mapping, but we don't want to include those + self._depth = super().depth() - 1 + return self._depth def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: """ diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index a19ebaa..729e47f 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -147,6 +147,8 @@ class TokenType(AutoName): VARIANT = auto() OBJECT = auto() INET = auto() + IPADDRESS = auto() + IPPREFIX = auto() ENUM = auto() # keywords diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 1e6cfc8..7c7c2a7 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -100,7 +100,8 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression: outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) qualify_filters = expression.args["qualify"].pop().this - for expr in qualify_filters.find_all((exp.Window, exp.Column)): + select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column) + for expr in qualify_filters.find_all(select_candidates): if isinstance(expr, exp.Window): alias = find_new_name(expression.named_selects, "_w") expression.select(exp.alias_(expr, alias), copy=False) |