From 684905e3de7854a3806ffa55e0d1a09431ba5a19 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 15 Oct 2022 15:53:00 +0200 Subject: Merging upstream version 7.1.3. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/dialects/clickhouse.py | 3 +- sqlglot/dialects/dialect.py | 25 ++--- sqlglot/dialects/duckdb.py | 13 ++- sqlglot/dialects/hive.py | 8 +- sqlglot/dialects/postgres.py | 1 + sqlglot/dialects/presto.py | 6 +- sqlglot/dialects/tableau.py | 5 +- sqlglot/expressions.py | 137 +++++++++++++++++++++++-- sqlglot/generator.py | 124 +++++++++++++++++------ sqlglot/optimizer/eliminate_ctes.py | 42 ++++++++ sqlglot/optimizer/eliminate_joins.py | 160 ++++++++++++++++++++++++++++++ sqlglot/optimizer/eliminate_subqueries.py | 2 +- sqlglot/optimizer/merge_subqueries.py | 18 ++++ sqlglot/optimizer/optimizer.py | 4 + sqlglot/optimizer/pushdown_predicates.py | 19 ++-- sqlglot/optimizer/scope.py | 26 +++++ sqlglot/parser.py | 145 ++++++++++++++++++++++----- sqlglot/planner.py | 36 +------ sqlglot/tokens.py | 15 ++- 20 files changed, 660 insertions(+), 131 deletions(-) create mode 100644 sqlglot/optimizer/eliminate_ctes.py create mode 100644 sqlglot/optimizer/eliminate_joins.py (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 0228bdd..247085b 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -23,7 +23,7 @@ from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType -__version__ = "6.3.1" +__version__ = "7.1.3" pretty = False diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index da5c856..f446e6d 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -1,7 +1,6 @@ from sqlglot import exp from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql from sqlglot.generator import Generator -from sqlglot.helper import csv from sqlglot.parser import Parser, parse_var_map from sqlglot.tokens import Tokenizer, TokenType @@ -66,7 +65,7 @@ class ClickHouse(Dialect): TRANSFORMS = { **Generator.TRANSFORMS, exp.Array: inline_array_sql, - exp.StrPosition: lambda self, e: f"position({csv(self.sql(e, 'this'), self.sql(e, 'substr'), self.sql(e, 'position'))})", + exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index f7c6cb5..531c72a 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -2,7 +2,7 @@ from enum import Enum from sqlglot import exp from sqlglot.generator import Generator -from sqlglot.helper import csv, list_get +from sqlglot.helper import list_get from sqlglot.parser import Parser from sqlglot.time import format_time from sqlglot.tokens import Tokenizer @@ -177,11 +177,11 @@ class Dialect(metaclass=_Dialect): def rename_func(name): def _rename(self, expression): args = ( - self.expressions(expression, flat=True) + expression.expressions if isinstance(expression, exp.Func) and expression.is_var_len_args - else csv(*[self.sql(e) for e in expression.args.values()]) + else expression.args.values() ) - return f"{name}({args})" + return f"{name}({self.format_args(*args)})" return _rename @@ -189,15 +189,11 @@ def rename_func(name): def approx_count_distinct_sql(self, expression): if expression.args.get("accuracy"): self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") - return f"APPROX_COUNT_DISTINCT({self.sql(expression, 'this')})" + return f"APPROX_COUNT_DISTINCT({self.format_args(expression.this)})" def if_sql(self, expression): - expressions = csv( - self.sql(expression, "this"), - self.sql(expression, "true"), - self.sql(expression, "false"), - ) + expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false")) return f"IF({expressions})" @@ -254,6 +250,11 @@ def no_trycast_sql(self, expression): return self.cast_sql(expression) +def no_properties_sql(self, expression): + self.unsupported("Properties unsupported") + return "" + + def str_position_sql(self, expression): this = self.sql(expression, "this") substr = self.sql(expression, "substr") @@ -275,13 +276,13 @@ def var_map_sql(self, expression): if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): self.unsupported("Cannot convert array columns into map.") - return f"MAP({self.sql(keys)}, {self.sql(values)})" + return f"MAP({self.format_args(keys, values)})" args = [] for key, value in zip(keys.expressions, values.expressions): args.append(self.sql(key)) args.append(self.sql(value)) - return f"MAP({csv(*args)})" + return f"MAP({self.format_args(*args)})" def format_time_lambda(exp_class, dialect, default=None): diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index e09c3dd..f3ff6d3 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_sql, format_time_lambda, no_pivot_sql, + no_properties_sql, no_safe_divide_sql, no_tablesample_sql, rename_func, @@ -68,6 +69,12 @@ def _struct_pack_sql(self, expression): return f"STRUCT_PACK({', '.join(args)})" +def _datatype_sql(self, expression): + if expression.this == exp.DataType.Type.ARRAY: + return f"{self.expressions(expression, flat=True)}[]" + return self.datatype_sql(expression) + + class DuckDB(Dialect): class Tokenizer(Tokenizer): KEYWORDS = { @@ -106,6 +113,8 @@ class DuckDB(Dialect): } class Generator(Generator): + STRUCT_DELIMITER = ("(", ")") + TRANSFORMS = { **Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, @@ -113,8 +122,9 @@ class DuckDB(Dialect): exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), + exp.DataType: _datatype_sql, exp.DateAdd: _date_add, - exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""", exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", @@ -124,6 +134,7 @@ class DuckDB(Dialect): exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.Pivot: no_pivot_sql, + exp.Properties: no_properties_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), exp.SafeDivide: no_safe_divide_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 55d7bcc..8888df8 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -14,7 +14,7 @@ from sqlglot.dialects.dialect import ( var_map_sql, ) from sqlglot.generator import Generator -from sqlglot.helper import csv, list_get +from sqlglot.helper import list_get from sqlglot.parser import Parser, parse_var_map from sqlglot.tokens import Tokenizer @@ -32,7 +32,7 @@ def _property_sql(self, expression): def _str_to_unix(self, expression): - return f"UNIX_TIMESTAMP({csv(self.sql(expression, 'this'), _time_format(self, expression))})" + return f"UNIX_TIMESTAMP({self.format_args(expression.this, _time_format(self, expression))})" def _str_to_date(self, expression): @@ -226,7 +226,7 @@ class Hive(Dialect): exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.SetAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", - exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})", + exp.StrPosition: lambda self, e: f"LOCATE({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})", exp.StrToDate: _str_to_date, exp.StrToTime: _str_to_time, exp.StrToUnix: _str_to_unix, @@ -241,7 +241,7 @@ class Hive(Dialect): exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.TsOrDsToDate: _to_date_sql, exp.TryCast: no_trycast_sql, - exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})", + exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.format_args(e.this, _time_format(self, e))})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}", diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 731e28e..c91ff4b 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -167,6 +167,7 @@ class Postgres(Dialect): **Tokenizer.KEYWORDS, "ALWAYS": TokenType.ALWAYS, "BY DEFAULT": TokenType.BY_DEFAULT, + "COMMENT ON": TokenType.COMMENT_ON, "IDENTITY": TokenType.IDENTITY, "GENERATED": TokenType.GENERATED, "DOUBLE PRECISION": TokenType.DOUBLE, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 85647c5..8dfb2fd 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -11,7 +11,7 @@ from sqlglot.dialects.dialect import ( ) from sqlglot.dialects.mysql import MySQL from sqlglot.generator import Generator -from sqlglot.helper import csv, list_get +from sqlglot.helper import list_get from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType @@ -26,7 +26,7 @@ def _concat_ws_sql(self, expression): sep, *args = expression.expressions sep = self.sql(sep) if len(args) > 1: - return f"ARRAY_JOIN(ARRAY[{csv(*(self.sql(e) for e in args))}], {sep})" + return f"ARRAY_JOIN(ARRAY[{self.format_args(*args)}], {sep})" return f"ARRAY_JOIN({self.sql(args[0])}, {sep})" @@ -66,7 +66,7 @@ def _no_sort_array(self, expression): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: comparator = None - args = csv(self.sql(expression, "this"), comparator) + args = self.format_args(expression.this, comparator) return f"ARRAY_SORT({args})" diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index e571749..45aa041 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -1,7 +1,6 @@ from sqlglot import exp from sqlglot.dialects.dialect import Dialect from sqlglot.generator import Generator -from sqlglot.helper import list_get from sqlglot.parser import Parser @@ -16,7 +15,7 @@ def _coalesce_sql(self, expression): def _count_sql(self, expression): this = expression.this if isinstance(this, exp.Distinct): - return f"COUNTD({self.sql(this, 'this')})" + return f"COUNTD({self.expressions(this, flat=True)})" return f"COUNT({self.sql(expression, 'this')})" @@ -33,5 +32,5 @@ class Tableau(Dialect): FUNCTIONS = { **Parser.FUNCTIONS, "IFNULL": exp.Coalesce.from_arg_list, - "COUNTD": lambda args: exp.Count(this=exp.Distinct(this=list_get(args, 0))), + "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), } diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index f2ffd12..39f4452 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,3 +1,4 @@ +import datetime import numbers import re from collections import deque @@ -508,7 +509,7 @@ class DerivedTable(Expression): return [select.alias_or_name for select in self.selects] -class Unionable: +class Unionable(Expression): def union(self, expression, distinct=True, dialect=None, **opts): """ Builds a UNION expression. @@ -614,6 +615,10 @@ class Create(Expression): } +class Describe(Expression): + pass + + class UserDefinedFunction(Expression): arg_types = {"this": True, "expressions": False} @@ -741,6 +746,11 @@ class Check(Expression): pass +class Directory(Expression): + # https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-dml-insert-overwrite-directory-hive.html + arg_types = {"this": True, "local": False, "row_format": False} + + class ForeignKey(Expression): arg_types = { "expressions": True, @@ -804,6 +814,18 @@ class Introducer(Expression): arg_types = {"this": True, "expression": True} +class LoadData(Expression): + arg_types = { + "this": True, + "local": False, + "overwrite": False, + "inpath": True, + "partition": False, + "input_format": False, + "serde": False, + } + + class Partition(Expression): pass @@ -1037,6 +1059,18 @@ class Reference(Expression): arg_types = {"this": True, "expressions": True} +class RowFormat(Expression): + # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml + arg_types = { + "fields": False, + "escaped": False, + "collection_items": False, + "map_keys": False, + "lines": False, + "null": False, + } + + class Tuple(Expression): arg_types = {"expressions": False} @@ -1071,6 +1105,14 @@ class Subqueryable(Unionable): return [] return with_.expressions + @property + def selects(self): + raise NotImplementedError("Subqueryable objects must implement `selects`") + + @property + def named_selects(self): + raise NotImplementedError("Subqueryable objects must implement `named_selects`") + def with_( self, alias, @@ -1158,7 +1200,7 @@ class Table(Expression): } -class Union(Subqueryable, Expression): +class Union(Subqueryable): arg_types = { "with": False, "this": True, @@ -1169,7 +1211,11 @@ class Union(Subqueryable, Expression): @property def named_selects(self): - return self.args["this"].unnest().named_selects + return self.this.unnest().named_selects + + @property + def selects(self): + return self.this.unnest().selects @property def left(self): @@ -1222,7 +1268,7 @@ class Schema(Expression): arg_types = {"this": False, "expressions": True} -class Select(Subqueryable, Expression): +class Select(Subqueryable): arg_types = { "with": False, "expressions": False, @@ -2075,7 +2121,7 @@ class Bracket(Condition): class Distinct(Expression): - arg_types = {"this": False, "on": False} + arg_types = {"expressions": False, "on": False} class In(Predicate): @@ -2233,6 +2279,14 @@ class Case(Func): class Cast(Func): arg_types = {"this": True, "to": True} + @property + def name(self): + return self.this.name + + @property + def to(self): + return self.args["to"] + class TryCast(Cast): pass @@ -2666,7 +2720,7 @@ def _norm_args(expression): else: arg = _norm_arg(arg) - if arg is not None: + if arg is not None and arg is not False: args[k] = arg return args @@ -3012,6 +3066,30 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts): return update +def delete(table, where=None, dialect=None, **opts): + """ + Builds a delete statement. + + Example: + >>> delete("my_table", where="id > 1").sql() + 'DELETE FROM my_table WHERE id > 1' + + Args: + where (str|Condition): sql conditional parsed into a WHERE statement + dialect (str): the dialect used to parse the input expressions. + **opts: other options to use to parse the input expressions. + + Returns: + Delete: the syntax tree for the DELETE statement. + """ + return Delete( + this=maybe_parse(table, into=Table, dialect=dialect, **opts), + where=Where(this=where) + if isinstance(where, Condition) + else maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts), + ) + + def condition(expression, dialect=None, **opts): """ Initialize a logical condition expression. @@ -3131,6 +3209,25 @@ def to_identifier(alias, quoted=None): return identifier +def to_table(sql_path, **kwargs): + """ + Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. + Example: + >>> to_table('catalog.db.table_name').sql() + 'catalog.db.table_name' + + Args: + sql_path(str): `[catalog].[schema].[table]` string + Returns: + Table: A table expression + """ + table_parts = sql_path.split(".") + catalog, db, table_name = [ + to_identifier(x) if x is not None else x for x in [None] * (3 - len(table_parts)) + table_parts + ] + return Table(this=table_name, db=db, catalog=catalog, **kwargs) + + def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): """ Create an Alias expression. @@ -3216,6 +3313,28 @@ def table_(table, db=None, catalog=None, quoted=None): ) +def values(values, alias=None): + """Build VALUES statement. + + Example: + >>> values([(1, '2')]).sql() + "VALUES (1, '2')" + + Args: + values (list[tuple[str | Expression]]): values statements that will be converted to SQL + alias (str): optional alias + dialect (str): the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + Values: the Values expression object + """ + return Values( + expressions=[convert(tup) for tup in values], + alias=to_identifier(alias) if alias else None, + ) + + def convert(value): """Convert a python value into an expression object. @@ -3246,6 +3365,12 @@ def convert(value): keys=[convert(k) for k in value.keys()], values=[convert(v) for v in value.values()], ) + if isinstance(value, datetime.datetime): + datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S")) + return TimeStrToTime(this=datetime_literal) + if isinstance(value, datetime.date): + date_literal = Literal.string(value.strftime("%Y-%m-%d")) + return DateStrToDate(this=date_literal) raise ValueError(f"Cannot convert {value}") diff --git a/sqlglot/generator.py b/sqlglot/generator.py index b7e295d..bb7fd71 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -2,7 +2,7 @@ import logging from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors -from sqlglot.helper import apply_index_offset, csv, ensure_list +from sqlglot.helper import apply_index_offset, csv from sqlglot.time import format_time from sqlglot.tokens import TokenType @@ -43,14 +43,18 @@ class Generator: Default: 3 leading_comma (bool): if the the comma is leading or trailing in select statements Default: False + max_text_width: The max number of characters in a segment before creating new lines in pretty mode. + The default is on the smaller end because the length only represents a segment and not the true + line length. + Default: 80 """ TRANSFORMS = { exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}", - exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", - exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", - exp.VarMap: lambda self, e: f"MAP({self.sql(e.args['keys'])}, {self.sql(e.args['values'])})", + exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", + exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})", + exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", + exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), exp.ReturnsProperty: lambda self, e: self.naked_property(e), @@ -111,6 +115,7 @@ class Generator: "_replace_backslash", "_escaped_quote_end", "_leading_comma", + "_max_text_width", ) def __init__( @@ -135,6 +140,7 @@ class Generator: null_ordering=None, max_unsupported=3, leading_comma=False, + max_text_width=80, ): import sqlglot @@ -162,6 +168,7 @@ class Generator: self._replace_backslash = self.escape == "\\" self._escaped_quote_end = self.escape + self.quote_end self._leading_comma = leading_comma + self._max_text_width = max_text_width def generate(self, expression): """ @@ -268,7 +275,7 @@ class Generator: raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") def annotation_sql(self, expression): - return self.sql(expression, "expression") + return f"{self.sql(expression, 'expression')} # {expression.name.strip()}" def uncache_sql(self, expression): table = self.sql(expression, "this") @@ -364,6 +371,9 @@ class Generator: ) return self.prepend_ctes(expression, expression_sql) + def describe_sql(self, expression): + return f"DESCRIBE {self.sql(expression, 'this')}" + def prepend_ctes(self, expression, sql): with_ = self.sql(expression, "with") if with_: @@ -405,6 +415,12 @@ class Generator: ) return f"{type_sql}{nested}" + def directory_sql(self, expression): + local = "LOCAL " if expression.args.get("local") else "" + row_format = self.sql(expression, "row_format") + row_format = f" {row_format}" if row_format else "" + return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}" + def delete_sql(self, expression): this = self.sql(expression, "this") where_sql = self.sql(expression, "where") @@ -513,13 +529,19 @@ class Generator: return f"{key}={value}" def insert_sql(self, expression): - kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO" - this = self.sql(expression, "this") + overwrite = expression.args.get("overwrite") + + if isinstance(expression.this, exp.Directory): + this = "OVERWRITE " if overwrite else "INTO " + else: + this = "OVERWRITE TABLE " if overwrite else "INTO " + + this = f"{this}{self.sql(expression, 'this')}" exists = " IF EXISTS " if expression.args.get("exists") else " " partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else "" expression_sql = self.sql(expression, "expression") sep = self.sep() if partition_sql else "" - sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}" + sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression): @@ -534,6 +556,21 @@ class Generator: def introducer_sql(self, expression): return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + def rowformat_sql(self, expression): + fields = expression.args.get("fields") + fields = f" FIELDS TERMINATED BY {fields}" if fields else "" + escaped = expression.args.get("escaped") + escaped = f" ESCAPED BY {escaped}" if escaped else "" + items = expression.args.get("collection_items") + items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else "" + keys = expression.args.get("map_keys") + keys = f" MAP KEYS TERMINATED BY {keys}" if keys else "" + lines = expression.args.get("lines") + lines = f" LINES TERMINATED BY {lines}" if lines else "" + null = expression.args.get("null") + null = f" NULL DEFINED AS {null}" if null else "" + return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}" + def table_sql(self, expression): table = ".".join( part @@ -688,6 +725,19 @@ class Generator: return f"{self.quote_start}{text}{self.quote_end}" return text + def loaddata_sql(self, expression): + local = " LOCAL" if expression.args.get("local") else "" + inpath = f" INPATH {self.sql(expression, 'inpath')}" + overwrite = " OVERWRITE" if expression.args.get("overwrite") else "" + this = f" INTO TABLE {self.sql(expression, 'this')}" + partition = self.sql(expression, "partition") + partition = f" {partition}" if partition else "" + input_format = self.sql(expression, "input_format") + input_format = f" INPUTFORMAT {input_format}" if input_format else "" + serde = self.sql(expression, "serde") + serde = f" SERDE {serde}" if serde else "" + return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}" + def null_sql(self, *_): return "NULL" @@ -885,20 +935,24 @@ class Generator: return f"EXISTS{self.wrap(expression)}" def case_sql(self, expression): - this = self.indent(self.sql(expression, "this"), skip_first=True) - this = f" {this}" if this else "" - ifs = [] + this = self.sql(expression, "this") + statements = [f"CASE {this}" if this else "CASE"] for e in expression.args["ifs"]: - ifs.append(self.indent(f"WHEN {self.sql(e, 'this')}")) - ifs.append(self.indent(f"THEN {self.sql(e, 'true')}")) + statements.append(f"WHEN {self.sql(e, 'this')}") + statements.append(f"THEN {self.sql(e, 'true')}") + + default = self.sql(expression, "default") + + if default: + statements.append(f"ELSE {default}") - if expression.args.get("default") is not None: - ifs.append(self.indent(f"ELSE {self.sql(expression, 'default')}")) + statements.append("END") - ifs = "".join(self.seg(self.indent(e, skip_first=True)) for e in ifs) - statement = f"CASE{this}{ifs}{self.seg('END')}" - return statement + if self.pretty and self.text_width(statements) > self._max_text_width: + return self.indent("\n".join(statements), skip_first=True, skip_last=True) + + return " ".join(statements) def constraint_sql(self, expression): this = self.sql(expression, "this") @@ -970,7 +1024,7 @@ class Generator: return f"REFERENCES {this}({expressions})" def anonymous_sql(self, expression): - args = self.indent(self.expressions(expression, flat=True), skip_first=True, skip_last=True) + args = self.format_args(*expression.expressions) return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" def paren_sql(self, expression): @@ -1008,7 +1062,9 @@ class Generator: if not self.pretty: return self.binary(expression, op) - return f"\n{op} ".join(self.sql(e) for e in expression.flatten(unnest=False)) + sqls = tuple(self.sql(e) for e in expression.flatten(unnest=False)) + sep = "\n" if self.text_width(sqls) > self._max_text_width else " " + return f"{sep}{op} ".join(sqls) def bitwiseand_sql(self, expression): return self.binary(expression, "&") @@ -1039,7 +1095,7 @@ class Generator: return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" def distinct_sql(self, expression): - this = self.sql(expression, "this") + this = self.expressions(expression, flat=True) this = f" {this}" if this else "" on = self.sql(expression, "on") @@ -1128,13 +1184,23 @@ class Generator: def function_fallback_sql(self, expression): args = [] - for arg_key in expression.arg_types: - arg_value = ensure_list(expression.args.get(arg_key) or []) - for a in arg_value: - args.append(self.sql(a)) - - args_str = self.indent(", ".join(args), skip_first=True, skip_last=True) - return f"{self.normalize_func(expression.sql_name())}({args_str})" + for arg_value in expression.args.values(): + if isinstance(arg_value, list): + for value in arg_value: + args.append(value) + elif arg_value: + args.append(arg_value) + + return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})" + + def format_args(self, *args): + args = tuple(self.sql(arg) for arg in args if arg is not None) + if self.pretty and self.text_width(args) > self._max_text_width: + return self.indent("\n" + f",\n".join(args) + "\n", skip_first=True, skip_last=True) + return ", ".join(args) + + def text_width(self, args): + return sum(len(arg) for arg in args) def format_time(self, expression): return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie) diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py new file mode 100644 index 0000000..7b862c6 --- /dev/null +++ b/sqlglot/optimizer/eliminate_ctes.py @@ -0,0 +1,42 @@ +from sqlglot.optimizer.scope import Scope, build_scope + + +def eliminate_ctes(expression): + """ + Remove unused CTEs from an expression. + + Example: + >>> import sqlglot + >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z" + >>> expression = sqlglot.parse_one(sql) + >>> eliminate_ctes(expression).sql() + 'SELECT a FROM z' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + root = build_scope(expression) + + ref_count = root.ref_count() + + # Traverse the scope tree in reverse so we can remove chains of unused CTEs + for scope in reversed(list(root.traverse())): + if scope.is_cte: + count = ref_count[id(scope)] + if count <= 0: + cte_node = scope.expression.parent + with_node = cte_node.parent + cte_node.pop() + + # Pop the entire WITH clause if this is the last CTE + if len(with_node.expressions) <= 0: + with_node.pop() + + # Decrement the ref count for all sources this CTE selects from + for _, source in scope.selected_sources.values(): + if isinstance(source, Scope): + ref_count[id(source)] -= 1 + + return expression diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py new file mode 100644 index 0000000..0854336 --- /dev/null +++ b/sqlglot/optimizer/eliminate_joins.py @@ -0,0 +1,160 @@ +from sqlglot import expressions as exp +from sqlglot.optimizer.normalize import normalized +from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.optimizer.simplify import simplify + + +def eliminate_joins(expression): + """ + Remove unused joins from an expression. + + This only removes joins when we know that the join condition doesn't produce duplicate rows. + + Example: + >>> import sqlglot + >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b" + >>> expression = sqlglot.parse_one(sql) + >>> eliminate_joins(expression).sql() + 'SELECT x.a FROM x' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + for scope in traverse_scope(expression): + # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used. + # It's probably possible to infer this from the outputs of derived tables. + # But for now, let's just skip this rule. + if scope.unqualified_columns: + continue + + joins = scope.expression.args.get("joins", []) + + # Reverse the joins so we can remove chains of unused joins + for join in reversed(joins): + alias = join.this.alias_or_name + if _should_eliminate_join(scope, join, alias): + join.pop() + scope.remove_source(alias) + return expression + + +def _should_eliminate_join(scope, join, alias): + inner_source = scope.sources.get(alias) + return ( + isinstance(inner_source, Scope) + and not _join_is_used(scope, join, alias) + and ( + (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join)) + or (not join.args.get("on") and _has_single_output_row(inner_source)) + ) + ) + + +def _join_is_used(scope, join, alias): + # We need to find all columns that reference this join. + # But columns in the ON clause shouldn't count. + on = join.args.get("on") + if on: + on_clause_columns = set(id(column) for column in on.find_all(exp.Column)) + else: + on_clause_columns = set() + return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns) + + +def _is_joined_on_all_unique_outputs(scope, join): + unique_outputs = _unique_outputs(scope) + if not unique_outputs: + return False + + _, join_keys, _ = join_condition(join) + remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys) + return not remaining_unique_outputs + + +def _unique_outputs(scope): + """Determine output columns of `scope` that must have a unique combination per row""" + if scope.expression.args.get("distinct"): + return set(scope.expression.named_selects) + + group = scope.expression.args.get("group") + if group: + grouped_expressions = set(group.expressions) + grouped_outputs = set() + + unique_outputs = set() + for select in scope.selects: + output = select.unalias() + if output in grouped_expressions: + grouped_outputs.add(output) + unique_outputs.add(select.alias_or_name) + + # All the grouped expressions must be in the output + if not grouped_expressions.difference(grouped_outputs): + return unique_outputs + else: + return set() + + if _has_single_output_row(scope): + return set(scope.expression.named_selects) + + return set() + + +def _has_single_output_row(scope): + return isinstance(scope.expression, exp.Select) and ( + all(isinstance(e.unalias(), exp.AggFunc) for e in scope.selects) + or _is_limit_1(scope) + or not scope.expression.args.get("from") + ) + + +def _is_limit_1(scope): + limit = scope.expression.args.get("limit") + return limit and limit.expression.this == "1" + + +def join_condition(join): + """ + Extract the join condition from a join expression. + + Args: + join (exp.Join) + Returns: + tuple[list[str], list[str], exp.Expression]: + Tuple of (source key, join key, remaining predicate) + """ + name = join.this.alias_or_name + on = join.args.get("on") or exp.TRUE + on = on.copy() + source_key = [] + join_key = [] + + # find the join keys + # SELECT + # FROM x + # JOIN y + # ON x.a = y.b AND y.b > 1 + # + # should pull y.b as the join key and x.a as the source key + if normalized(on): + for condition in on.flatten() if isinstance(on, exp.And) else [on]: + if isinstance(condition, exp.EQ): + left, right = condition.unnest_operands() + left_tables = exp.column_table_names(left) + right_tables = exp.column_table_names(right) + + if name in left_tables and name not in right_tables: + join_key.append(left) + source_key.append(right) + condition.replace(exp.TRUE) + elif name in right_tables and name not in left_tables: + join_key.append(right) + source_key.append(left) + condition.replace(exp.TRUE) + + on = simplify(on) + remaining_condition = None if on == exp.TRUE else on + + return source_key, join_key, remaining_condition diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 38e1299..44cdc94 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -8,7 +8,7 @@ from sqlglot.optimizer.simplify import simplify def eliminate_subqueries(expression): """ - Rewrite subqueries as CTES, deduplicating if possible. + Rewrite derived tables as CTES, deduplicating if possible. Example: >>> import sqlglot diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 3e435f5..3c51c18 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -119,6 +119,23 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): Returns: bool: True if can be merged """ + + def _is_a_window_expression_in_unmergable_operation(): + window_expressions = inner_select.find_all(exp.Window) + window_alias_names = {window.parent.alias_or_name for window in window_expressions} + inner_select_name = inner_select.parent.alias_or_name + unmergable_window_columns = [ + column + for column in outer_scope.columns + if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc) + ] + window_expressions_in_unmergable = [ + column + for column in unmergable_window_columns + if column.table == inner_select_name and column.name in window_alias_names + ] + return any(window_expressions_in_unmergable) + return ( isinstance(outer_scope.expression, exp.Select) and isinstance(inner_select, exp.Select) @@ -137,6 +154,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): and inner_select.args.get("where") and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])) ) + and not _is_a_window_expression_in_unmergable_operation() ) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 9a09327..2c28ab8 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,3 +1,5 @@ +from sqlglot.optimizer.eliminate_ctes import eliminate_ctes +from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects from sqlglot.optimizer.isolate_table_selects import isolate_table_selects @@ -23,6 +25,8 @@ RULES = ( optimize_joins, eliminate_subqueries, merge_subqueries, + eliminate_joins, + eliminate_ctes, quote_identities, ) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 9c8d71d..583d059 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -1,8 +1,6 @@ -from collections import defaultdict - from sqlglot import exp from sqlglot.optimizer.normalize import normalized -from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.scope import build_scope from sqlglot.optimizer.simplify import simplify @@ -22,15 +20,10 @@ def pushdown_predicates(expression): Returns: sqlglot.Expression: optimized expression """ - scope_ref_count = defaultdict(lambda: 0) - scopes = traverse_scope(expression) - scopes.reverse() - - for scope in scopes: - for _, source in scope.selected_sources.values(): - scope_ref_count[id(source)] += 1 + root = build_scope(expression) + scope_ref_count = root.ref_count() - for scope in scopes: + for scope in reversed(list(root.traverse())): select = scope.expression where = select.args.get("where") if where: @@ -152,9 +145,11 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): return {} nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: + # We can't push down window expressions + has_window_expression = any(select for select in node.selects if select.find(exp.Window)) # we can't push down predicates to select statements if they are referenced in # multiple places. - if not node.args.get("group") and scope_ref_count[id(source)] < 2: + if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression: nodes[table] = node return nodes diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 89de517..68298a0 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,4 +1,5 @@ import itertools +from collections import defaultdict from enum import Enum, auto from sqlglot import exp @@ -314,6 +315,16 @@ class Scope: self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] return self._external_columns + @property + def unqualified_columns(self): + """ + Unqualified columns in the current scope. + + Returns: + list[exp.Column]: Unqualified columns + """ + return [c for c in self.columns if not c.table] + @property def join_hints(self): """ @@ -403,6 +414,21 @@ class Scope: yield from child_scope.traverse() yield self + def ref_count(self): + """ + Count the number of times each scope in this tree is referenced. + + Returns: + dict[int, int]: Mapping of Scope instance ID to reference count + """ + scope_ref_count = defaultdict(lambda: 0) + + for scope in self.traverse(): + for _, source in scope.selected_sources.values(): + scope_ref_count[id(source)] += 1 + + return scope_ref_count + def traverse_scope(expression): """ diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c29e520..b378f12 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -135,11 +135,13 @@ class Parser: TokenType.BOTH, TokenType.BUCKET, TokenType.CACHE, + TokenType.CALL, TokenType.COLLATE, TokenType.COMMIT, TokenType.CONSTRAINT, TokenType.DEFAULT, TokenType.DELETE, + TokenType.DESCRIBE, TokenType.DETERMINISTIC, TokenType.EXECUTE, TokenType.ENGINE, @@ -160,6 +162,7 @@ class Parser: TokenType.LAZY, TokenType.LANGUAGE, TokenType.LEADING, + TokenType.LOCAL, TokenType.LOCATION, TokenType.MATERIALIZED, TokenType.NATURAL, @@ -176,6 +179,7 @@ class Parser: TokenType.REFERENCES, TokenType.RETURNS, TokenType.ROWS, + TokenType.SCHEMA, TokenType.SCHEMA_COMMENT, TokenType.SEED, TokenType.SEMI, @@ -294,6 +298,11 @@ class Parser: COLUMN_OPERATORS = { TokenType.DOT: None, + TokenType.DCOLON: lambda self, this, to: self.expression( + exp.Cast, + this=this, + to=to, + ), TokenType.ARROW: lambda self, this, path: self.expression( exp.JSONExtract, this=this, @@ -342,8 +351,10 @@ class Parser: STATEMENT_PARSERS = { TokenType.CREATE: lambda self: self._parse_create(), + TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DROP: lambda self: self._parse_drop(), TokenType.INSERT: lambda self: self._parse_insert(), + TokenType.LOAD_DATA: lambda self: self._parse_load_data(), TokenType.UPDATE: lambda self: self._parse_update(), TokenType.DELETE: lambda self: self._parse_delete(), TokenType.CACHE: lambda self: self._parse_cache(), @@ -449,7 +460,14 @@ class Parser: MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) - CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE} + CREATABLES = { + TokenType.TABLE, + TokenType.VIEW, + TokenType.FUNCTION, + TokenType.INDEX, + TokenType.PROCEDURE, + TokenType.SCHEMA, + } STRICT_CAST = True @@ -650,7 +668,7 @@ class Parser: materialized = self._match(TokenType.MATERIALIZED) kind = self._match_set(self.CREATABLES) and self._prev.text if not kind: - self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE") + self.raise_error(f"Expected {self.CREATABLES}") return return self.expression( @@ -677,7 +695,7 @@ class Parser: create_token = self._match_set(self.CREATABLES) and self._prev if not create_token: - self.raise_error("Expected TABLE, VIEW, INDEX, FUNCTION, or PROCEDURE") + self.raise_error(f"Expected {self.CREATABLES}") return exists = self._parse_exists(not_=True) @@ -692,7 +710,7 @@ class Parser: expression = self._parse_select_or_expression() elif create_token.token_type == TokenType.INDEX: this = self._parse_index() - elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW): + elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW, TokenType.SCHEMA): this = self._parse_table(schema=True) properties = self._parse_properties() if self._match(TokenType.ALIAS): @@ -836,19 +854,74 @@ class Parser: return self.expression(exp.Properties, expressions=properties) return None + def _parse_describe(self): + self._match(TokenType.TABLE) + + return self.expression(exp.Describe, this=self._parse_id_var()) + def _parse_insert(self): overwrite = self._match(TokenType.OVERWRITE) - self._match(TokenType.INTO) - self._match(TokenType.TABLE) + local = self._match(TokenType.LOCAL) + if self._match_text("DIRECTORY"): + this = self.expression( + exp.Directory, + this=self._parse_var_or_string(), + local=local, + row_format=self._parse_row_format(), + ) + else: + self._match(TokenType.INTO) + self._match(TokenType.TABLE) + this = self._parse_table(schema=True) return self.expression( exp.Insert, - this=self._parse_table(schema=True), + this=this, exists=self._parse_exists(), partition=self._parse_partition(), expression=self._parse_select(nested=True), overwrite=overwrite, ) + def _parse_row_format(self): + if not self._match_pair(TokenType.ROW, TokenType.FORMAT): + return None + + self._match_text("DELIMITED") + + kwargs = {} + + if self._match_text("FIELDS", "TERMINATED", "BY"): + kwargs["fields"] = self._parse_string() + if self._match_text("ESCAPED", "BY"): + kwargs["escaped"] = self._parse_string() + if self._match_text("COLLECTION", "ITEMS", "TERMINATED", "BY"): + kwargs["collection_items"] = self._parse_string() + if self._match_text("MAP", "KEYS", "TERMINATED", "BY"): + kwargs["map_keys"] = self._parse_string() + if self._match_text("LINES", "TERMINATED", "BY"): + kwargs["lines"] = self._parse_string() + if self._match_text("NULL", "DEFINED", "AS"): + kwargs["null"] = self._parse_string() + return self.expression(exp.RowFormat, **kwargs) + + def _parse_load_data(self): + local = self._match(TokenType.LOCAL) + self._match_text("INPATH") + inpath = self._parse_string() + overwrite = self._match(TokenType.OVERWRITE) + self._match_pair(TokenType.INTO, TokenType.TABLE) + + return self.expression( + exp.LoadData, + this=self._parse_table(schema=True), + local=local, + overwrite=overwrite, + inpath=inpath, + partition=self._parse_partition(), + input_format=self._match_text("INPUTFORMAT") and self._parse_string(), + serde=self._match_text("SERDE") and self._parse_string(), + ) + def _parse_delete(self): self._match(TokenType.FROM) @@ -1484,6 +1557,14 @@ class Parser: if self._match_set(self.RANGE_PARSERS): this = self.RANGE_PARSERS[self._prev.token_type](self, this) + elif self._match(TokenType.ISNULL): + this = self.expression(exp.Is, this=this, expression=exp.Null()) + + # Postgres supports ISNULL and NOTNULL for conditions. + # https://blog.andreiavram.ro/postgresql-null-composite-type/ + if self._match(TokenType.NOTNULL): + this = self.expression(exp.Is, this=this, expression=exp.Null()) + this = self.expression(exp.Not, this=this) if negate: this = self.expression(exp.Not, this=this) @@ -1582,12 +1663,6 @@ class Parser: return self._parse_column() return type_token - while self._match(TokenType.DCOLON): - type_token = self._parse_types() - if not type_token: - self.raise_error("Expected type") - this = self.expression(exp.Cast, this=this, to=type_token) - return this def _parse_types(self): @@ -1601,6 +1676,11 @@ class Parser: is_struct = type_token == TokenType.STRUCT expressions = None + if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): + return exp.DataType( + this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value)], nested=True + ) + if self._match(TokenType.L_BRACKET): self._retreat(index) return None @@ -1611,7 +1691,7 @@ class Parser: elif nested: expressions = self._parse_csv(self._parse_types) else: - expressions = self._parse_csv(self._parse_type) + expressions = self._parse_csv(self._parse_conjunction) if not expressions: self._retreat(index) @@ -1677,8 +1757,17 @@ class Parser: this = self._parse_bracket(this) while self._match_set(self.COLUMN_OPERATORS): - op = self.COLUMN_OPERATORS.get(self._prev.token_type) - field = self._parse_star() or self._parse_function() or self._parse_id_var() + op_token = self._prev.token_type + op = self.COLUMN_OPERATORS.get(op_token) + + if op_token == TokenType.DCOLON: + field = self._parse_types() + if not field: + self.raise_error("Expected type") + elif op: + field = exp.Literal.string(self._advance() or self._prev.text) + else: + field = self._parse_star() or self._parse_function() or self._parse_id_var() if isinstance(field, exp.Func): # bigquery allows function calls like x.y.count(...) @@ -1687,7 +1776,7 @@ class Parser: this = self._replace_columns_with_dots(this) if op: - this = op(self, this, exp.Literal.string(field.name)) + this = op(self, this, field) elif isinstance(this, exp.Column) and not this.table: this = self.expression(exp.Column, this=field, table=this.this) else: @@ -1808,11 +1897,10 @@ class Parser: if not self._match(TokenType.ARROW): self._retreat(index) - distinct = self._match(TokenType.DISTINCT) - this = self._parse_conjunction() - - if distinct: - this = self.expression(exp.Distinct, this=this) + if self._match(TokenType.DISTINCT): + this = self.expression(exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)) + else: + this = self._parse_conjunction() if self._match(TokenType.IGNORE_NULLS): this = self.expression(exp.IgnoreNulls, this=this) @@ -2112,6 +2200,8 @@ class Parser: this = self.expression(exp.Filter, this=this, expression=self._parse_where()) self._match_r_paren() + # T-SQL allows the OVER (...) syntax after WITHIN GROUP. + # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16 if self._match(TokenType.WITHIN_GROUP): self._match_l_paren() this = self.expression( @@ -2120,7 +2210,6 @@ class Parser: expression=self._parse_order(), ) self._match_r_paren() - return this # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER # Some dialects choose to implement and some do not. @@ -2366,6 +2455,16 @@ class Parser: if not self._match(TokenType.R_PAREN): self.raise_error("Expecting )") + def _match_text(self, *texts): + index = self._index + for text in texts: + if self._curr and self._curr.text.upper() == text: + self._advance() + else: + self._retreat(index) + return False + return True + def _replace_columns_with_dots(self, this): if isinstance(this, exp.Dot): exp.replace_children(this, self._replace_columns_with_dots) diff --git a/sqlglot/planner.py b/sqlglot/planner.py index ed0b66c..efabc15 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -3,7 +3,7 @@ import math from sqlglot import alias, exp from sqlglot.errors import UnsupportedError -from sqlglot.optimizer.simplify import simplify +from sqlglot.optimizer.eliminate_joins import join_condition class Plan: @@ -236,40 +236,12 @@ class Join(Step): step = Join() for join in joins: - name = join.this.alias - on = join.args.get("on") or exp.TRUE - source_key = [] - join_key = [] - - # find the join keys - # SELECT - # FROM x - # JOIN y - # ON x.a = y.b AND y.b > 1 - # - # should pull y.b as the join key and x.a as the source key - for condition in on.flatten() if isinstance(on, exp.And) else [on]: - if isinstance(condition, exp.EQ): - left, right = condition.unnest_operands() - left_tables = exp.column_table_names(left) - right_tables = exp.column_table_names(right) - - if name in left_tables and name not in right_tables: - join_key.append(left) - source_key.append(right) - condition.replace(exp.TRUE) - elif name in right_tables and name not in left_tables: - join_key.append(right) - source_key.append(left) - condition.replace(exp.TRUE) - - on = simplify(on) - - step.joins[name] = { + source_key, join_key, condition = join_condition(join) + step.joins[join.this.alias_or_name] = { "side": join.side, "join_key": join_key, "source_key": source_key, - "condition": None if on == exp.TRUE else on, + "condition": condition, } step.add_dependency(Scan.from_expression(join.this, ctes)) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 17c038c..fc8e6e7 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -123,6 +123,7 @@ class TokenType(AutoName): CLUSTER_BY = auto() COLLATE = auto() COMMENT = auto() + COMMENT_ON = auto() COMMIT = auto() CONSTRAINT = auto() CREATE = auto() @@ -133,13 +134,14 @@ class TokenType(AutoName): CURRENT_ROW = auto() CURRENT_TIME = auto() CURRENT_TIMESTAMP = auto() - DIV = auto() DEFAULT = auto() DELETE = auto() DESC = auto() + DESCRIBE = auto() DETERMINISTIC = auto() DISTINCT = auto() DISTRIBUTE_BY = auto() + DIV = auto() DROP = auto() ELSE = auto() END = auto() @@ -189,6 +191,8 @@ class TokenType(AutoName): LEFT = auto() LIKE = auto() LIMIT = auto() + LOAD_DATA = auto() + LOCAL = auto() LOCATION = auto() MAP = auto() MATERIALIZED = auto() @@ -196,6 +200,7 @@ class TokenType(AutoName): NATURAL = auto() NEXT = auto() NO_ACTION = auto() + NOTNULL = auto() NULL = auto() NULLS_FIRST = auto() NULLS_LAST = auto() @@ -436,13 +441,14 @@ class Tokenizer(metaclass=_Tokenizer): "CURRENT_DATE": TokenType.CURRENT_DATE, "CURRENT ROW": TokenType.CURRENT_ROW, "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP, - "DIV": TokenType.DIV, "DEFAULT": TokenType.DEFAULT, "DELETE": TokenType.DELETE, "DESC": TokenType.DESC, + "DESCRIBE": TokenType.DESCRIBE, "DETERMINISTIC": TokenType.DETERMINISTIC, "DISTINCT": TokenType.DISTINCT, "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, + "DIV": TokenType.DIV, "DROP": TokenType.DROP, "ELSE": TokenType.ELSE, "END": TokenType.END, @@ -487,12 +493,15 @@ class Tokenizer(metaclass=_Tokenizer): "LEFT": TokenType.LEFT, "LIKE": TokenType.LIKE, "LIMIT": TokenType.LIMIT, + "LOAD DATA": TokenType.LOAD_DATA, + "LOCAL": TokenType.LOCAL, "LOCATION": TokenType.LOCATION, "MATERIALIZED": TokenType.MATERIALIZED, "NATURAL": TokenType.NATURAL, "NEXT": TokenType.NEXT, "NO ACTION": TokenType.NO_ACTION, "NOT": TokenType.NOT, + "NOTNULL": TokenType.NOTNULL, "NULL": TokenType.NULL, "NULLS FIRST": TokenType.NULLS_FIRST, "NULLS LAST": TokenType.NULLS_LAST, @@ -530,6 +539,7 @@ class Tokenizer(metaclass=_Tokenizer): "ROLLUP": TokenType.ROLLUP, "ROW": TokenType.ROW, "ROWS": TokenType.ROWS, + "SCHEMA": TokenType.SCHEMA, "SEED": TokenType.SEED, "SELECT": TokenType.SELECT, "SEMI": TokenType.SEMI, @@ -629,6 +639,7 @@ class Tokenizer(metaclass=_Tokenizer): TokenType.ANALYZE, TokenType.BEGIN, TokenType.CALL, + TokenType.COMMENT_ON, TokenType.COMMIT, TokenType.EXPLAIN, TokenType.OPTIMIZE, -- cgit v1.2.3