diff options
Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/dialects/dialect.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 19 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 30 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 6 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 10 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 25 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 7 | ||||
-rw-r--r-- | sqlglot/expressions.py | 50 | ||||
-rw-r--r-- | sqlglot/generator.py | 46 | ||||
-rw-r--r-- | sqlglot/lineage.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_joins.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/isolate_table_selects.py | 7 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 3 | ||||
-rw-r--r-- | sqlglot/optimizer/optimize_joins.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 5 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 10 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 12 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 37 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 55 | ||||
-rw-r--r-- | sqlglot/parser.py | 58 | ||||
-rw-r--r-- | sqlglot/tokens.py | 2 |
23 files changed, 254 insertions, 154 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 4fc93bf..5376dff 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -620,7 +620,16 @@ def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat return self.sql(this) -# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator +def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: + bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) + if bad_args: + self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") + + return self.func( + "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") + ) + + def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: names = [] for agg in aggregations: diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index d7e5a43..1d8a7fb 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import ( no_properties_sql, no_safe_divide_sql, pivot_column_names, + regexp_extract_sql, rename_func, str_position_sql, str_to_time_sql, @@ -88,19 +89,6 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) -def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str: - bad_args = list(filter(expression.args.get, ("position", "occurrence"))) - if bad_args: - self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}") - - return self.func( - "REGEXP_EXTRACT", - expression.args.get("this"), - expression.args.get("expression"), - expression.args.get("group"), - ) - - def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: sql = self.func("TO_JSON", expression.this, expression.args.get("options")) return f"CAST({sql} AS TEXT)" @@ -156,6 +144,9 @@ class DuckDB(Dialect): "LIST_REVERSE_SORT": _sort_array_reverse, "LIST_SORT": exp.SortArray.from_arg_list, "LIST_VALUE": exp.Array.from_arg_list, + "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( + this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) + ), "REGEXP_MATCHES": exp.RegexpLike.from_arg_list, "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"), "STRING_SPLIT": exp.Split.from_arg_list, @@ -227,7 +218,7 @@ class DuckDB(Dialect): exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), exp.Properties: no_properties_sql, - exp.RegexpExtract: _regexp_extract_sql, + exp.RegexpExtract: regexp_extract_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), exp.SafeDivide: no_safe_divide_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 5762efb..f968f6a 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -17,6 +17,7 @@ from sqlglot.dialects.dialect import ( no_recursive_cte_sql, no_safe_divide_sql, no_trycast_sql, + regexp_extract_sql, rename_func, right_to_substring_sql, strposition_to_locate_sql, @@ -230,23 +231,24 @@ class Hive(Dialect): **parser.Parser.FUNCTIONS, "BASE64": exp.ToBase64.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, + "COLLECT_SET": exp.SetAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") ), - "DATEDIFF": lambda args: exp.DateDiff( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - expression=exp.TsOrDsToDate(this=seq_get(args, 1)), + "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")( + [ + exp.TimeStrToTime(this=seq_get(args, 0)), + seq_get(args, 1), + ] ), "DATE_SUB": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)), unit=exp.Literal.string("DAY"), ), - "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")( - [ - exp.TimeStrToTime(this=seq_get(args, 0)), - seq_get(args, 1), - ] + "DATEDIFF": lambda args: exp.DateDiff( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + expression=exp.TsOrDsToDate(this=seq_get(args, 1)), ), "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), @@ -256,7 +258,9 @@ class Hive(Dialect): "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list, - "COLLECT_SET": exp.SetAgg.from_arg_list, + "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( + this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) + ), "SIZE": exp.ArraySize.from_arg_list, "SPLIT": exp.RegexpSplit.from_arg_list, "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), @@ -363,6 +367,7 @@ class Hive(Dialect): exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), + exp.RegexpExtract: regexp_extract_sql, exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), exp.RegexpSplit: rename_func("SPLIT"), exp.Right: right_to_substring_sql, @@ -422,5 +427,12 @@ class Hive(Dialect): expression = exp.DataType.build("text") elif expression.this in exp.DataType.TEMPORAL_TYPES: expression = exp.DataType.build(expression.this) + elif expression.is_type("float"): + size_expression = expression.find(exp.DataTypeSize) + if size_expression: + size = int(size_expression.name) + expression = ( + exp.DataType.build("float") if size <= 32 else exp.DataType.build("double") + ) return super().datatype_sql(expression) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index bae0e50..e4de934 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -193,6 +193,12 @@ class MySQL(Dialect): TokenType.VALUES, } + CONJUNCTION = { + **parser.Parser.CONJUNCTION, + TokenType.DAMP: exp.And, + TokenType.XOR: exp.Xor, + } + TABLE_ALIAS_TOKENS = ( parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS ) diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 2b77ef9..69da133 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -99,6 +99,9 @@ class Oracle(Dialect): LOCKING_READS_SUPPORTED = True JOIN_HINTS = False TABLE_HINTS = False + COLUMN_JOIN_MARKS_SUPPORTED = True + + LIMIT_FETCH = "FETCH" TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -110,6 +113,7 @@ class Oracle(Dialect): exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", exp.DataType.Type.VARCHAR: "VARCHAR2", exp.DataType.Type.NVARCHAR: "NVARCHAR2", + exp.DataType.Type.NCHAR: "NCHAR", exp.DataType.Type.TEXT: "CLOB", exp.DataType.Type.BINARY: "BLOB", exp.DataType.Type.VARBINARY: "BLOB", @@ -140,15 +144,9 @@ class Oracle(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - LIMIT_FETCH = "FETCH" - def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" - def column_sql(self, expression: exp.Column) -> str: - column = super().column_sql(expression) - return f"{column} (+)" if expression.args.get("join_mark") else column - def xmltable_sql(self, expression: exp.XMLTable) -> str: this = self.sql(expression, "this") passing = self.expressions(expression, key="passing") diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 1721588..7d35c67 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( no_ilike_sql, no_pivot_sql, no_safe_divide_sql, + regexp_extract_sql, rename_func, right_to_substring_sql, struct_extract_sql, @@ -215,6 +216,9 @@ class Presto(Dialect): this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8") ), "NOW": exp.CurrentTimestamp.from_arg_list, + "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( + this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) + ), "SEQUENCE": exp.GenerateSeries.from_arg_list, "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2) @@ -293,6 +297,7 @@ class Presto(Dialect): exp.LogicalOr: rename_func("BOOL_OR"), exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, + exp.RegexpExtract: regexp_extract_sql, exp.Right: right_to_substring_sql, exp.SafeBracket: lambda self, e: self.func( "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 19924cd..715a84c 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -223,13 +223,14 @@ class Snowflake(Dialect): "IFF": exp.If.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, + "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TIMEDIFF": _parse_datediff, "TIMESTAMPDIFF": _parse_datediff, "TO_ARRAY": exp.Array.from_arg_list, - "TO_VARCHAR": exp.ToChar.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, + "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": _zeroifnull_to_if, } @@ -361,12 +362,12 @@ class Snowflake(Dialect): "OBJECT_CONSTRUCT", *(arg for expression in e.expressions for arg in expression.flatten()), ), + exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, - exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.TimeToStr: lambda self, e: self.func( "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e) ), - exp.TimestampTrunc: timestamptrunc_sql, + exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), @@ -390,6 +391,24 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: + # Other dialects don't support all of the following parameters, so we need to + # generate default values as necessary to ensure the transpilation is correct + group = expression.args.get("group") + parameters = expression.args.get("parameters") or (group and exp.Literal.string("c")) + occurrence = expression.args.get("occurrence") or (parameters and exp.Literal.number(1)) + position = expression.args.get("position") or (occurrence and exp.Literal.number(1)) + + return self.func( + "REGEXP_SUBSTR", + expression.this, + expression.expression, + position, + occurrence, + parameters, + group, + ) + def except_op(self, expression: exp.Except) -> str: if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 92bb755..b77c2c0 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -302,6 +302,7 @@ class TSQL(Dialect): "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "VARCHAR(MAX)": TokenType.TEXT, "XML": TokenType.XML, + "OUTPUT": TokenType.RETURNING, "SYSTEM_USER": TokenType.CURRENT_USER, } @@ -469,6 +470,7 @@ class TSQL(Dialect): LOCKING_READS_SUPPORTED = True LIMIT_IS_TOP = True QUERY_HINTS = False + RETURNING_END = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -532,3 +534,8 @@ class TSQL(Dialect): table = expression.args.get("table") table = f"{table} " if table else "" return f"RETURNS {table}{self.sql(expression, 'this')}" + + def returning_sql(self, expression: exp.Returning) -> str: + into = self.sql(expression, "into") + into = self.seg(f"INTO {into}") if into else "" + return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 242e66c..264b8e9 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -878,11 +878,11 @@ class DerivedTable(Expression): return [c.name for c in table_alias.args.get("columns") or []] @property - def selects(self): + def selects(self) -> t.List[Expression]: return self.this.selects if isinstance(self.this, Subqueryable) else [] @property - def named_selects(self): + def named_selects(self) -> t.List[str]: return [select.output_name for select in self.selects] @@ -959,7 +959,7 @@ class Unionable(Expression): class UDTF(DerivedTable, Unionable): @property - def selects(self): + def selects(self) -> t.List[Expression]: alias = self.args.get("alias") return alias.columns if alias else [] @@ -1576,7 +1576,7 @@ class OnConflict(Expression): class Returning(Expression): - arg_types = {"expressions": True} + arg_types = {"expressions": True, "into": False} # https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html @@ -2194,11 +2194,11 @@ class Subqueryable(Unionable): return with_.expressions @property - def selects(self): + def selects(self) -> t.List[Expression]: raise NotImplementedError("Subqueryable objects must implement `selects`") @property - def named_selects(self): + def named_selects(self) -> t.List[str]: raise NotImplementedError("Subqueryable objects must implement `named_selects`") def with_( @@ -2282,7 +2282,6 @@ class Table(Expression): "pivots": False, "hints": False, "system_time": False, - "wrapped": False, } @property @@ -2300,13 +2299,27 @@ class Table(Expression): return self.text("catalog") @property + def selects(self) -> t.List[Expression]: + return [] + + @property + def named_selects(self) -> t.List[str]: + return [] + + @property def parts(self) -> t.List[Identifier]: """Return the parts of a table in order catalog, db, table.""" - return [ - t.cast(Identifier, self.args[part]) - for part in ("catalog", "db", "this") - if self.args.get(part) - ] + parts: t.List[Identifier] = [] + + for arg in ("catalog", "db", "this"): + part = self.args.get(arg) + + if isinstance(part, Identifier): + parts.append(part) + elif isinstance(part, Dot): + parts.extend(part.flatten()) + + return parts # See the TSQL "Querying data in a system-versioned temporal table" page @@ -2390,7 +2403,7 @@ class Union(Subqueryable): return this @property - def named_selects(self): + def named_selects(self) -> t.List[str]: return self.this.unnest().named_selects @property @@ -2398,7 +2411,7 @@ class Union(Subqueryable): return self.this.is_star or self.expression.is_star @property - def selects(self): + def selects(self) -> t.List[Expression]: return self.this.unnest().selects @property @@ -3517,6 +3530,10 @@ class Or(Connector): pass +class Xor(Connector): + pass + + class BitwiseAnd(Binary): pass @@ -4409,6 +4426,7 @@ class RegexpExtract(Func): "expression": True, "position": False, "occurrence": False, + "parameters": False, "group": False, } @@ -5756,7 +5774,9 @@ def table_name(table: Table | str, dialect: DialectType = None) -> str: raise ValueError(f"Cannot parse {table}") return ".".join( - part.sql(dialect=dialect) if not SAFE_IDENTIFIER_RE.match(part.name) else part.name + part.sql(dialect=dialect, identify=True) + if not SAFE_IDENTIFIER_RE.match(part.name) + else part.name for part in table.parts ) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 4ac988f..857eff1 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -155,6 +155,12 @@ class Generator: # Whether or not to generate the limit as TOP <value> instead of LIMIT <value> LIMIT_IS_TOP = False + # Whether or not to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ... + RETURNING_END = True + + # Whether or not to generate the (+) suffix for columns used in old-style join conditions + COLUMN_JOIN_MARKS_SUPPORTED = False + # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE") @@ -556,7 +562,13 @@ class Generator: return f"{default}CHARACTER SET={self.sql(expression, 'this')}" def column_sql(self, expression: exp.Column) -> str: - return ".".join( + join_mark = " (+)" if expression.args.get("join_mark") else "" + + if join_mark and not self.COLUMN_JOIN_MARKS_SUPPORTED: + join_mark = "" + self.unsupported("Outer join syntax using the (+) operator is not supported.") + + column = ".".join( self.sql(part) for part in ( expression.args.get("catalog"), @@ -567,6 +579,8 @@ class Generator: if part ) + return f"{column}{join_mark}" + def columnposition_sql(self, expression: exp.ColumnPosition) -> str: this = self.sql(expression, "this") this = f" {this}" if this else "" @@ -836,8 +850,11 @@ class Generator: limit = self.sql(expression, "limit") tables = self.expressions(expression, key="tables") tables = f" {tables}" if tables else "" - sql = f"DELETE{tables}{this}{using}{where}{returning}{limit}" - return self.prepend_ctes(expression, sql) + if self.RETURNING_END: + expression_sql = f"{this}{using}{where}{returning}{limit}" + else: + expression_sql = f"{returning}{this}{using}{where}{limit}" + return self.prepend_ctes(expression, f"DELETE{tables}{expression_sql}") def drop_sql(self, expression: exp.Drop) -> str: this = self.sql(expression, "this") @@ -887,7 +904,8 @@ class Generator: unique = "UNIQUE " if expression.args.get("unique") else "" primary = "PRIMARY " if expression.args.get("primary") else "" amp = "AMP " if expression.args.get("amp") else "" - name = f"{expression.name} " if expression.name else "" + name = self.sql(expression, "this") + name = f"{name} " if name else "" table = self.sql(expression, "table") table = f"{self.INDEX_ON} {table} " if table else "" using = self.sql(expression, "using") @@ -1134,7 +1152,13 @@ class Generator: expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}" conflict = self.sql(expression, "conflict") returning = self.sql(expression, "returning") - sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}{conflict}{returning}" + + if self.RETURNING_END: + expression_sql = f"{expression_sql}{conflict}{returning}" + else: + expression_sql = f"{returning}{expression_sql}{conflict}" + + sql = f"INSERT{alternative}{ignore}{this}{exists}{partition_sql}{where}{expression_sql}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression: exp.Intersect) -> str: @@ -1215,8 +1239,7 @@ class Generator: system_time = expression.args.get("system_time") system_time = f" {self.sql(expression, 'system_time')}" if system_time else "" - sql = f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}" - return f"({sql})" if expression.args.get("wrapped") else sql + return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}" def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " @@ -1276,7 +1299,11 @@ class Generator: where_sql = self.sql(expression, "where") returning = self.sql(expression, "returning") limit = self.sql(expression, "limit") - sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}{returning}{limit}" + if self.RETURNING_END: + expression_sql = f"{from_sql}{where_sql}{returning}{limit}" + else: + expression_sql = f"{returning}{from_sql}{where_sql}{limit}" + sql = f"UPDATE {this} SET {set_sql}{expression_sql}" return self.prepend_ctes(expression, sql) def values_sql(self, expression: exp.Values) -> str: @@ -2016,6 +2043,9 @@ class Generator: def and_sql(self, expression: exp.And) -> str: return self.connector_sql(expression, "AND") + def xor_sql(self, expression: exp.And) -> str: + return self.connector_sql(expression, "XOR") + def connector_sql(self, expression: exp.Connector, op: str) -> str: if not self.pretty: return self.binary(expression, op) diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 04a8073..9f5ae9a 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -104,7 +104,7 @@ def lineage( # Find the specific select clause that is the source of the column we want. # This can either be a specific, named select or a generic `*` clause. select = next( - (select for select in scope.selects if select.alias_or_name == column_name), + (select for select in scope.expression.selects if select.alias_or_name == column_name), exp.Star() if scope.expression.is_star else None, ) diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index cd8ba3b..3134e65 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -85,7 +85,7 @@ def _unique_outputs(scope): grouped_outputs = set() unique_outputs = set() - for select in scope.selects: + for select in scope.expression.selects: output = select.unalias() if output in grouped_expressions: grouped_outputs.add(output) @@ -105,7 +105,7 @@ def _unique_outputs(scope): def _has_single_output_row(scope): return isinstance(scope.expression, exp.Select) and ( - all(isinstance(e.unalias(), exp.AggFunc) for e in scope.selects) + all(isinstance(e.unalias(), exp.AggFunc) for e in scope.expression.selects) or _is_limit_1(scope) or not scope.expression.args.get("from") ) diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 84f50e9..5ae1fa0 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -113,7 +113,7 @@ def _eliminate_union(scope, existing_ctes, taken): taken[alias] = scope # Try to maintain the selections - expressions = scope.selects + expressions = scope.expression.selects selects = [ exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False) for e in expressions diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index 79e3ed5..a6524b8 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -12,7 +12,12 @@ def isolate_table_selects(expression, schema=None): continue for _, source in scope.selected_sources.values(): - if not isinstance(source, exp.Table) or not schema.column_names(source): + if ( + not isinstance(source, exp.Table) + or not schema.column_names(source) + or isinstance(source.parent, exp.Subquery) + or isinstance(source.parent.parent, exp.Table) + ): continue if not source.alias: diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index e156d5e..6ee057b 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -107,6 +107,7 @@ def merge_derived_tables(expression, leave_tables_isolated=False): _merge_order(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope) outer_scope.clear_cache() + return expression @@ -166,7 +167,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): if not inner_from: return False inner_from_table = inner_from.alias_or_name - inner_projections = {s.alias_or_name: s for s in inner_scope.selects} + inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects} return any( col.table != inner_from_table for selection in selections diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index d51276f..7b3b2b1 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -59,7 +59,7 @@ def reorder_joins(expression): dag = {name: other_table_names(join) for name, join in joins.items()} parent.set( "joins", - [joins[name] for name in tsort(dag) if name != from_.alias_or_name], + [joins[name] for name in tsort(dag) if name != from_.alias_or_name and name in joins], ) return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index fb1662d..58b988d 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -42,7 +42,10 @@ def pushdown_predicates(expression): # so we limit the selected sources to only itself for join in select.args.get("joins") or []: name = join.alias_or_name - pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) + if name in scope.selected_sources: + pushdown( + join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count + ) return expression diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index be3ddb2..97e8ff6 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -48,12 +48,12 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) left, right = scope.union_scopes referenced_columns[left] = parent_selections - if any(select.is_star for select in right.selects): + if any(select.is_star for select in right.expression.selects): referenced_columns[right] = parent_selections - elif not any(select.is_star for select in left.selects): + elif not any(select.is_star for select in left.expression.selects): referenced_columns[right] = [ - right.selects[i].alias_or_name - for i, select in enumerate(left.selects) + right.expression.selects[i].alias_or_name + for i, select in enumerate(left.expression.selects) if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections ] @@ -90,7 +90,7 @@ def _remove_unused_selections(scope, parent_selections, schema): removed = False star = False - for selection in scope.selects: + for selection in scope.expression.selects: name = selection.alias_or_name if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs: diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 435585c..7972b2b 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -192,13 +192,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if table and (not alias_expr or double_agg): column.set("table", table) elif not column.table and alias_expr and not double_agg: - if isinstance(alias_expr, exp.Literal): + if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table): if literal_index: column.replace(exp.Literal.number(i)) else: column.replace(alias_expr.copy()) - for i, projection in enumerate(scope.selects): + for i, projection in enumerate(scope.expression.selects): replace_columns(projection) if isinstance(projection, exp.Alias): @@ -239,7 +239,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver): ordered.set("this", new_expression) if scope.expression.args.get("group"): - selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects} + selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects} for ordered in ordereds: ordered = ordered.this @@ -270,7 +270,7 @@ def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: try: - return scope.selects[int(node.this) - 1].assert_is(exp.Alias) + return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias) except IndexError: raise OptimizeError(f"Unknown output column: {node.name}") @@ -347,7 +347,7 @@ def _expand_stars( if not pivot_output_columns: pivot_output_columns = [col.alias_or_name for col in pivot.expressions] - for expression in scope.selects: + for expression in scope.expression.selects: if isinstance(expression, exp.Star): tables = list(scope.selected_sources) _add_except_columns(expression, tables, except_columns) @@ -446,7 +446,7 @@ def _qualify_outputs(scope: Scope): new_selections = [] for i, (selection, aliased_column) in enumerate( - itertools.zip_longest(scope.selects, scope.outer_column_list) + itertools.zip_longest(scope.expression.selects, scope.outer_column_list) ): if isinstance(selection, exp.Subquery): if not selection.output_name: diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index af8c716..31c9cc0 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -15,7 +15,8 @@ def qualify_tables( schema: t.Optional[Schema] = None, ) -> E: """ - Rewrite sqlglot AST to have fully qualified, unnested tables. + Rewrite sqlglot AST to have fully qualified tables. Join constructs such as + (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t. Examples: >>> import sqlglot @@ -23,18 +24,9 @@ def qualify_tables( >>> qualify_tables(expression, db="db").sql() 'SELECT 1 FROM db.tbl AS tbl' >>> - >>> expression = sqlglot.parse_one("SELECT * FROM (tbl)") + >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t") >>> qualify_tables(expression).sql() - 'SELECT * FROM tbl AS tbl' - >>> - >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)") - >>> qualify_tables(expression).sql() - 'SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2' - - Note: - This rule effectively enforces a left-to-right join order, since all joins - are unnested. This means that the optimizer doesn't necessarily preserve the - original join order, e.g. when parentheses are used to specify it explicitly. + 'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t' Args: expression: Expression to qualify @@ -49,6 +41,13 @@ def qualify_tables( for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): + if isinstance(derived_table, exp.Subquery): + unnested = derived_table.unnest() + if isinstance(unnested, exp.Table): + joins = unnested.args.pop("joins", None) + derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False)) + derived_table.this.set("joins", joins) + if not derived_table.args.get("alias"): alias_ = next_alias_name() derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) @@ -66,19 +65,9 @@ def qualify_tables( if not source.args.get("catalog"): source.set("catalog", exp.to_identifier(catalog)) - # Unnest joins attached in tables by appending them to the closest query - for join in source.args.get("joins") or []: - scope.expression.append("joins", join) - - source.set("joins", None) - source.set("wrapped", None) - if not source.alias: - source = source.replace( - alias( - source, name or source.name or next_alias_name(), copy=True, table=True - ) - ) + # Mutates the source by attaching an alias to it + alias(source, name or source.name or next_alias_name(), copy=False, table=True) pivots = source.args.get("pivots") if pivots and not pivots[0].alias: diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 7dcfb37..b2b4230 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -122,7 +122,11 @@ class Scope: self._udtfs.append(node) elif isinstance(node, exp.CTE): self._ctes.append(node) - elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): + elif ( + isinstance(node, exp.Subquery) + and isinstance(parent, (exp.From, exp.Join)) + and _is_subquery_scope(node) + ): self._derived_tables.append(node) elif isinstance(node, exp.Subqueryable): self._subqueries.append(node) @@ -274,6 +278,7 @@ class Scope: not ancestor or column.table or isinstance(ancestor, exp.Select) + or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) or ( isinstance(ancestor, exp.Order) and ( @@ -341,23 +346,6 @@ class Scope: } @property - def selects(self): - """ - Select expressions of this scope. - - For example, for the following expression: - SELECT 1 as a, 2 as b FROM x - - The outputs are the "1 as a" and "2 as b" expressions. - - Returns: - list[exp.Expression]: expressions - """ - if isinstance(self.expression, exp.Union): - return self.expression.unnest().selects - return self.expression.selects - - @property def external_columns(self): """ Columns that appear to reference sources in outer scopes. @@ -548,6 +536,8 @@ def _traverse_scope(scope): yield from _traverse_union(scope) elif isinstance(scope.expression, exp.Subquery): yield from _traverse_subqueries(scope) + elif isinstance(scope.expression, exp.Table): + yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): pass else: @@ -620,6 +610,15 @@ def _traverse_ctes(scope): scope.sources.update(sources) +def _is_subquery_scope(expression: exp.Subquery) -> bool: + """ + We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a new scope. + If an alias is present, it shadows all names under the Subquery, so that's an + exception to this rule. + """ + return bool(not isinstance(expression.unnest(), exp.Table) or expression.alias) + + def _traverse_tables(scope): sources = {} @@ -629,9 +628,8 @@ def _traverse_tables(scope): if from_: expressions.append(from_.this) - for expression in (scope.expression, *scope.find_all(exp.Table)): - for join in expression.args.get("joins") or []: - expressions.append(join.this) + for join in scope.expression.args.get("joins") or []: + expressions.append(join.this) if isinstance(scope.expression, exp.Table): expressions.append(scope.expression) @@ -655,6 +653,8 @@ def _traverse_tables(scope): sources[find_new_name(sources, table_name)] = expression else: sources[source_name] = expression + + expressions.extend(join.this for join in expression.args.get("joins") or []) continue if not isinstance(expression, exp.DerivedTable): @@ -664,10 +664,15 @@ def _traverse_tables(scope): lateral_sources = sources scope_type = ScopeType.UDTF scopes = scope.udtf_scopes - else: + elif _is_subquery_scope(expression): lateral_sources = None scope_type = ScopeType.DERIVED_TABLE scopes = scope.derived_table_scopes + else: + # Makes sure we check for possible sources in nested table constructs + expressions.append(expression.this) + expressions.extend(join.this for join in expression.args.get("joins") or []) + continue for child_scope in _traverse_scope( scope.branch( @@ -728,7 +733,11 @@ def walk_in_scope(expression, bfs=True): continue if ( isinstance(node, exp.CTE) - or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join))) + or ( + isinstance(node, exp.Subquery) + and isinstance(parent, (exp.From, exp.Join)) + and _is_subquery_scope(node) + ) or isinstance(node, exp.UDTF) or isinstance(node, exp.Subqueryable) ): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c7f4050..508a273 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -1708,6 +1708,8 @@ class Parser(metaclass=_Parser): self._match(TokenType.TABLE) this = self._parse_table(schema=True) + returning = self._parse_returning() + return self.expression( exp.Insert, this=this, @@ -1717,7 +1719,7 @@ class Parser(metaclass=_Parser): and self._parse_conjunction(), expression=self._parse_ddl_select(), conflict=self._parse_on_conflict(), - returning=self._parse_returning(), + returning=returning or self._parse_returning(), overwrite=overwrite, alternative=alternative, ignore=ignore, @@ -1761,8 +1763,11 @@ class Parser(metaclass=_Parser): def _parse_returning(self) -> t.Optional[exp.Returning]: if not self._match(TokenType.RETURNING): return None - - return self.expression(exp.Returning, expressions=self._parse_csv(self._parse_column)) + return self.expression( + exp.Returning, + expressions=self._parse_csv(self._parse_expression), + into=self._match(TokenType.INTO) and self._parse_table_part(), + ) def _parse_row(self) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]: if not self._match(TokenType.FORMAT): @@ -1824,25 +1829,30 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.FROM, advance=False): tables = self._parse_csv(self._parse_table) or None + returning = self._parse_returning() + return self.expression( exp.Delete, tables=tables, this=self._match(TokenType.FROM) and self._parse_table(joins=True), using=self._match(TokenType.USING) and self._parse_table(joins=True), where=self._parse_where(), - returning=self._parse_returning(), + returning=returning or self._parse_returning(), limit=self._parse_limit(), ) def _parse_update(self) -> exp.Update: + this = self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS) + expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality) + returning = self._parse_returning() return self.expression( exp.Update, **{ # type: ignore - "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), - "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), + "this": this, + "expressions": expressions, "from": self._parse_from(joins=True), "where": self._parse_where(), - "returning": self._parse_returning(), + "returning": returning or self._parse_returning(), "limit": self._parse_limit(), }, ) @@ -1969,31 +1979,9 @@ class Parser(metaclass=_Parser): self._match_r_paren() - alias = None - - # Ensure "wrapped" tables are not parsed as Subqueries. The exception to this is when there's - # an alias that can be applied to the parentheses, because that would shadow all wrapped table - # names, and so we want to parse it as a Subquery to represent the inner scope appropriately. - # Additionally, we want the node under the Subquery to be an actual query, so we will replace - # the table reference with a star query that selects from it. - if isinstance(this, exp.Table): - alias = self._parse_table_alias() - if not alias: - this.set("wrapped", True) - return this - - this.set("wrapped", None) - joins = this.args.pop("joins", None) - this = this.replace(exp.select("*").from_(this.copy(), copy=False)) - this.set("joins", joins) - - subquery = self._parse_subquery(this, parse_alias=parse_subquery_alias and not alias) - if subquery and alias: - subquery.set("alias", alias) - # We return early here so that the UNION isn't attached to the subquery by the # following call to _parse_set_operations, but instead becomes the parent node - return subquery + return self._parse_subquery(this, parse_alias=parse_subquery_alias) elif self._match(TokenType.VALUES): this = self.expression( exp.Values, @@ -3086,7 +3074,13 @@ class Parser(metaclass=_Parser): if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): this = exp.DataType( this=exp.DataType.Type.ARRAY, - expressions=[exp.DataType.build(type_token.value, expressions=expressions)], + expressions=[ + exp.DataType( + this=exp.DataType.Type[type_token.value], + expressions=expressions, + nested=nested, + ) + ], nested=True, ) @@ -3147,7 +3141,7 @@ class Parser(metaclass=_Parser): return value return exp.DataType( - this=exp.DataType.Type[type_token.value.upper()], + this=exp.DataType.Type[type_token.value], expressions=expressions, nested=nested, values=values, diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 999bde2..ed14594 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -52,6 +52,7 @@ class TokenType(AutoName): PARAMETER = auto() SESSION_PARAMETER = auto() DAMP = auto() + XOR = auto() BLOCK_START = auto() BLOCK_END = auto() @@ -590,6 +591,7 @@ class Tokenizer(metaclass=_Tokenizer): "OFFSET": TokenType.OFFSET, "ON": TokenType.ON, "OR": TokenType.OR, + "XOR": TokenType.XOR, "ORDER BY": TokenType.ORDER_BY, "ORDINALITY": TokenType.ORDINALITY, "OUTER": TokenType.OUTER, |