diff options
Diffstat (limited to '')
30 files changed, 394 insertions, 155 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 29e7c55..133979a 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -148,7 +148,7 @@ def atanh(col: ColumnOrName) -> Column: def cbrt(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "CBRT") + return Column.invoke_expression_over_column(col, expression.Cbrt) def ceil(col: ColumnOrName) -> Column: diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index f518ac2..bfc022b 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -70,12 +70,10 @@ class SparkSession: column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)} data_expressions = [ - exp.Tuple( - expressions=list( - map( - lambda x: F.lit(x).expression, - row if not isinstance(row, dict) else row.values(), - ) + exp.tuple_( + *map( + lambda x: F.lit(x).expression, + row if not isinstance(row, dict) else row.values(), ) ) for row in data diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 9068235..c0191b2 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -39,24 +39,31 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va alias = expression.args.get("alias") - structs = [ - exp.Struct( + return self.unnest_sql( + exp.Unnest( expressions=[ - exp.alias_(value, column_name) - for value, column_name in zip( - t.expressions, - ( - alias.columns - if alias and alias.columns - else (f"_c{i}" for i in range(len(t.expressions))) + exp.array( + *( + exp.Struct( + expressions=[ + exp.alias_(value, column_name) + for value, column_name in zip( + t.expressions, + ( + alias.columns + if alias and alias.columns + else (f"_c{i}" for i in range(len(t.expressions))) + ), + ) + ] + ) + for t in expression.find_all(exp.Tuple) ), + copy=False, ) ] ) - for t in expression.find_all(exp.Tuple) - ] - - return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)])) + ) def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str: @@ -161,12 +168,18 @@ def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression: return expression -def _parse_timestamp(args: t.List) -> exp.StrToTime: +def _parse_parse_timestamp(args: t.List) -> exp.StrToTime: this = format_time_lambda(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)]) this.set("zone", seq_get(args, 2)) return this +def _parse_timestamp(args: t.List) -> exp.Timestamp: + timestamp = exp.Timestamp.from_arg_list(args) + timestamp.set("with_tz", True) + return timestamp + + def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts: expr_type = exp.DateFromParts if len(args) == 3 else exp.Date return expr_type.from_arg_list(args) @@ -318,6 +331,7 @@ class BigQuery(Dialect): "TIMESTAMP": TokenType.TIMESTAMPTZ, } KEYWORDS.pop("DIV") + KEYWORDS.pop("VALUES") class Parser(parser.Parser): PREFIXED_PIVOT_COLUMNS = True @@ -348,7 +362,7 @@ class BigQuery(Dialect): "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")( [seq_get(args, 1), seq_get(args, 0)] ), - "PARSE_TIMESTAMP": _parse_timestamp, + "PARSE_TIMESTAMP": _parse_parse_timestamp, "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( this=seq_get(args, 0), @@ -367,6 +381,7 @@ class BigQuery(Dialect): "TIME": _parse_time, "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd), "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub), + "TIMESTAMP": _parse_timestamp, "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd), "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub), "TIMESTAMP_MICROS": lambda args: exp.UnixToTime( @@ -395,11 +410,6 @@ class BigQuery(Dialect): TokenType.TABLE, } - ID_VAR_TOKENS = { - *parser.Parser.ID_VAR_TOKENS, - TokenType.VALUES, - } - PROPERTY_PARSERS = { **parser.Parser.PROPERTY_PARSERS, "NOT DETERMINISTIC": lambda self: self.expression( diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 1ec15c5..d7be64c 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -93,6 +93,7 @@ class ClickHouse(Dialect): "IPV6": TokenType.IPV6, "AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION, "SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION, + "SYSTEM": TokenType.COMMAND, } SINGLE_TOKENS = { diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 6e2d190..0440a99 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -654,28 +654,6 @@ def time_format( return _time_format -def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: - """ - In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the - PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding - columns are removed from the create statement. - """ - has_schema = isinstance(expression.this, exp.Schema) - is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW") - - if has_schema and is_partitionable: - prop = expression.find(exp.PartitionedByProperty) - if prop and prop.this and not isinstance(prop.this, exp.Schema): - schema = expression.this - columns = {v.name.upper() for v in prop.this.expressions} - partitions = [col for col in schema.expressions if col.name.upper() in columns] - schema.set("expressions", [e for e in schema.expressions if e not in partitions]) - prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) - expression.set("this", schema) - - return self.create_sql(expression) - - def parse_date_delta( exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None ) -> t.Callable[[t.List], E]: @@ -742,7 +720,10 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: if not expression.expression: - return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)) + from sqlglot.optimizer.annotate_types import annotate_types + + target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP + return self.sql(exp.cast(expression.this, to=target_type)) if expression.text("expression").lower() in TIMEZONES: return self.sql( exp.AtTimeZone( @@ -750,7 +731,7 @@ def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: zone=expression.expression, ) ) - return self.function_fallback_sql(expression) + return self.func("TIMESTAMP", expression.this, expression.expression) def locate_to_strposition(args: t.List) -> exp.Expression: diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index be23355..409e260 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -5,7 +5,6 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, - create_with_partitions_sql, datestrtodate_sql, format_time_lambda, no_trycast_sql, @@ -13,6 +12,7 @@ from sqlglot.dialects.dialect import ( str_position_sql, timestrtotime_sql, ) +from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]: @@ -125,7 +125,7 @@ class Drill(Dialect): exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ArrayContains: rename_func("REPEATED_CONTAINS"), exp.ArraySize: rename_func("REPEATED_COUNT"), - exp.Create: create_with_partitions_sql, + exp.Create: preprocess([move_schema_columns_to_partitioned_by]), exp.DateAdd: _date_add_sql("ADD"), exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _date_add_sql("SUB"), diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 6337ffd..b1540bb 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -9,7 +9,6 @@ from sqlglot.dialects.dialect import ( NormalizationStrategy, approx_count_distinct_sql, arg_max_or_min_no_count, - create_with_partitions_sql, datestrtodate_sql, format_time_lambda, if_sql, @@ -32,6 +31,12 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, var_map_sql, ) +from sqlglot.transforms import ( + remove_unique_constraints, + ctas_with_tmp_tables_to_create_tmp_view, + preprocess, + move_schema_columns_to_partitioned_by, +) from sqlglot.helper import seq_get from sqlglot.parser import parse_var_map from sqlglot.tokens import TokenType @@ -55,30 +60,6 @@ TIME_DIFF_FACTOR = { DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") -def _create_sql(self, expression: exp.Create) -> str: - # remove UNIQUE column constraints - for constraint in expression.find_all(exp.UniqueColumnConstraint): - if constraint.parent: - constraint.parent.pop() - - properties = expression.args.get("properties") - temporary = any( - isinstance(prop, exp.TemporaryProperty) - for prop in (properties.expressions if properties else []) - ) - - # CTAS with temp tables map to CREATE TEMPORARY VIEW - kind = expression.args["kind"] - if kind.upper() == "TABLE" and temporary: - if expression.expression: - return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}" - else: - # CREATE TEMPORARY TABLE may require storage provider - expression = self.temporary_storage_provider(expression) - - return create_with_partitions_sql(self, expression) - - def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str: if isinstance(expression, exp.TsOrDsAdd) and not expression.unit: return self.func("DATE_ADD", expression.this, expression.expression) @@ -285,6 +266,7 @@ class Hive(Dialect): class Parser(parser.Parser): LOG_DEFAULTS_TO_LN = True STRICT_CAST = False + VALUES_FOLLOWED_BY_PAREN = False FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -518,7 +500,13 @@ class Hive(Dialect): "" if e.args.get("allow_null") else "NOT NULL" ), exp.VarMap: var_map_sql, - exp.Create: _create_sql, + exp.Create: preprocess( + [ + remove_unique_constraints, + ctas_with_tmp_tables_to_create_tmp_view, + move_schema_columns_to_partitioned_by, + ] + ), exp.Quantile: rename_func("PERCENTILE"), exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), exp.RegexpExtract: regexp_extract_sql, @@ -581,10 +569,6 @@ class Hive(Dialect): return super()._jsonpathkey_sql(expression) - def temporary_storage_provider(self, expression: exp.Create) -> exp.Create: - # Hive has no temporary storage provider (there are hive settings though) - return expression - def parameter_sql(self, expression: exp.Parameter) -> str: this = self.sql(expression, "this") expression_sql = self.sql(expression, "expression") diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 661ef7d..97c891d 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -445,6 +445,7 @@ class MySQL(Dialect): LOG_DEFAULTS_TO_LN = True STRING_ALIASES = True + VALUES_FOLLOWED_BY_PAREN = False def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 0c0d750..de693b9 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -88,6 +88,7 @@ class Oracle(Dialect): class Parser(parser.Parser): ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP} + VALUES_FOLLOWED_BY_PAREN = False FUNCTIONS = { **parser.Parser.FUNCTIONS, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 68e2c6d..126261e 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -244,6 +244,8 @@ class Postgres(Dialect): "@@": TokenType.DAT, "@>": TokenType.AT_GT, "<@": TokenType.LT_AT, + "|/": TokenType.PIPE_SLASH, + "||/": TokenType.DPIPE_SLASH, "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "BIGSERIAL": TokenType.BIGSERIAL, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 609103e..1e0e7e9 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -225,6 +225,8 @@ class Presto(Dialect): } class Parser(parser.Parser): + VALUES_FOLLOWED_BY_PAREN = False + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARBITRARY": exp.AnyValue.from_arg_list, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index a64c1d4..135ffc6 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -136,11 +136,11 @@ class Redshift(Postgres): refs.add( ( this.args["from"] if i == 0 else this.args["joins"][i - 1] - ).alias_or_name.lower() + ).this.alias.lower() ) - table = join.this - if isinstance(table, exp.Table): + table = join.this + if isinstance(table, exp.Table) and not join.args.get("on"): if table.parts[0].name.lower() in refs: table.replace(table.to_column()) return this @@ -158,6 +158,7 @@ class Redshift(Postgres): "UNLOAD": TokenType.COMMAND, "VARBYTE": TokenType.VARBINARY, } + KEYWORDS.pop("VALUES") # Redshift allows # to appear as a table identifier prefix SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy() diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 37f9761..b4275ea 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -477,6 +477,8 @@ class Snowflake(Dialect): "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "COLUMNS": _show_parser("COLUMNS"), + "USERS": _show_parser("USERS"), + "TERSE USERS": _show_parser("USERS"), } STAGED_FILE_SINGLE_TOKENS = { diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 44bd12d..c662ab5 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -5,8 +5,14 @@ import typing as t from sqlglot import exp from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.hive import _parse_ignore_nulls -from sqlglot.dialects.spark2 import Spark2 +from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider from sqlglot.helper import seq_get +from sqlglot.transforms import ( + ctas_with_tmp_tables_to_create_tmp_view, + remove_unique_constraints, + preprocess, + move_partitioned_by_to_schema_columns, +) def _parse_datediff(args: t.List) -> exp.Expression: @@ -35,6 +41,15 @@ def _parse_datediff(args: t.List) -> exp.Expression: ) +def _normalize_partition(e: exp.Expression) -> exp.Expression: + """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)""" + if isinstance(e, str): + return exp.to_identifier(e) + if isinstance(e, exp.Literal): + return exp.to_identifier(e.name) + return e + + class Spark(Spark2): class Tokenizer(Spark2.Tokenizer): RAW_STRINGS = [ @@ -72,6 +87,17 @@ class Spark(Spark2): TRANSFORMS = { **Spark2.Generator.TRANSFORMS, + exp.Create: preprocess( + [ + remove_unique_constraints, + lambda e: ctas_with_tmp_tables_to_create_tmp_view( + e, temporary_storage_provider + ), + move_partitioned_by_to_schema_columns, + ] + ), + exp.PartitionedByProperty: lambda self, + e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}", exp.StartsWith: rename_func("STARTSWITH"), exp.TimestampAdd: lambda self, e: self.func( "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 9378d99..fa55b51 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -13,6 +13,12 @@ from sqlglot.dialects.dialect import ( ) from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get +from sqlglot.transforms import ( + preprocess, + remove_unique_constraints, + ctas_with_tmp_tables_to_create_tmp_view, + move_schema_columns_to_partitioned_by, +) def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str: @@ -95,6 +101,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: return expression +def temporary_storage_provider(expression: exp.Expression) -> exp.Expression: + # spark2, spark, Databricks require a storage provider for temporary tables + provider = exp.FileFormatProperty(this=exp.Literal.string("parquet")) + expression.args["properties"].append("expressions", provider) + return expression + + class Spark2(Hive): class Parser(Hive.Parser): TRIM_PATTERN_FIRST = True @@ -121,7 +134,6 @@ class Spark2(Hive): ), zone=seq_get(args, 1), ), - "IIF": exp.If.from_arg_list, "INT": _parse_as_cast("int"), "MAP_FROM_ARRAYS": exp.Map.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, @@ -193,6 +205,15 @@ class Spark2(Hive): e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), + exp.Create: preprocess( + [ + remove_unique_constraints, + lambda e: ctas_with_tmp_tables_to_create_tmp_view( + e, temporary_storage_provider + ), + move_schema_columns_to_partitioned_by, + ] + ), exp.DateFromParts: rename_func("MAKE_DATE"), exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")), exp.DayOfMonth: rename_func("DAYOFMONTH"), @@ -251,12 +272,6 @@ class Spark2(Hive): return self.func("STRUCT", *args) - def temporary_storage_provider(self, expression: exp.Create) -> exp.Create: - # spark2, spark, Databricks require a storage provider for temporary tables - provider = exp.FileFormatProperty(this=exp.Literal.string("parquet")) - expression.args["properties"].append("expressions", provider) - return expression - def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: if is_parse_json(expression.this): schema = f"'{self.sql(expression, 'to')}'" diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index b292c81..6596c5b 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -132,6 +132,7 @@ class SQLite(Dialect): exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql, exp.DateStrToDate: lambda self, e: self.sql(e, "this"), + exp.If: rename_func("IIF"), exp.ILike: no_ilike_sql, exp.JSONExtract: _json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_sql, diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 3795045..e8ff249 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -1,10 +1,14 @@ from __future__ import annotations -from sqlglot import exp, generator, parser, transforms +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import Dialect, rename_func class Tableau(Dialect): + class Tokenizer(tokens.Tokenizer): + IDENTIFIERS = [("[", "]")] + QUOTES = ["'", '"'] + class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 7f9a11a..5b30cd4 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -74,6 +74,7 @@ class Teradata(Dialect): class Parser(parser.Parser): TABLESAMPLE_CSV = True + VALUES_FOLLOWED_BY_PAREN = False CHARSET_TRANSLATORS = { "GRAPHIC_TO_KANJISJIS", diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 70ea97e..85b2e12 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -457,7 +457,6 @@ class TSQL(Dialect): "FORMAT": _parse_format, "GETDATE": exp.CurrentTimestamp.from_arg_list, "HASHBYTES": _parse_hashbytes, - "IIF": exp.If.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, "JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract), "JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar), diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 11ebbaf..8ef750e 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1090,6 +1090,11 @@ class Create(DDL): "clone": False, } + @property + def kind(self) -> t.Optional[str]: + kind = self.args.get("kind") + return kind and kind.upper() + # https://docs.snowflake.com/en/sql-reference/sql/create-clone # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement @@ -4626,6 +4631,11 @@ class CountIf(AggFunc): _sql_names = ["COUNT_IF", "COUNTIF"] +# cube root +class Cbrt(Func): + pass + + class CurrentDate(Func): arg_types = {"this": False} @@ -4728,7 +4738,7 @@ class Extract(Func): class Timestamp(Func): - arg_types = {"this": False, "expression": False} + arg_types = {"this": False, "expression": False, "with_tz": False} class TimestampAdd(Func, TimeUnit): @@ -4833,7 +4843,7 @@ class Posexplode(Explode): pass -class PosexplodeOuter(Posexplode): +class PosexplodeOuter(Posexplode, ExplodeOuter): pass @@ -4868,6 +4878,7 @@ class Xor(Connector, Func): class If(Func): arg_types = {"this": True, "true": True, "false": False} + _sql_names = ["IF", "IIF"] class Nullif(Func): @@ -6883,6 +6894,7 @@ def replace_tables( table = to_table( new_name, **{k: v for k, v in node.args.items() if k not in TABLE_PARTS}, + dialect=dialect, ) table.add_comments([original]) return table @@ -7072,6 +7084,60 @@ def cast_unless( return cast(expr, to, **opts) +def array( + *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs +) -> Array: + """ + Returns an array. + + Examples: + >>> array(1, 'x').sql() + 'ARRAY(1, x)' + + Args: + expressions: the expressions to add to the array. + copy: whether or not to copy the argument expressions. + dialect: the source dialect. + kwargs: the kwargs used to instantiate the function of interest. + + Returns: + An array expression. + """ + return Array( + expressions=[ + maybe_parse(expression, copy=copy, dialect=dialect, **kwargs) + for expression in expressions + ] + ) + + +def tuple_( + *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs +) -> Tuple: + """ + Returns an tuple. + + Examples: + >>> tuple_(1, 'x').sql() + '(1, x)' + + Args: + expressions: the expressions to add to the tuple. + copy: whether or not to copy the argument expressions. + dialect: the source dialect. + kwargs: the kwargs used to instantiate the function of interest. + + Returns: + A tuple expression. + """ + return Tuple( + expressions=[ + maybe_parse(expression, copy=copy, dialect=dialect, **kwargs) + for expression in expressions + ] + ) + + def true() -> Boolean: """ Returns a true Boolean expression. diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 318d782..4ff5a0e 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -124,6 +124,7 @@ class Generator(metaclass=_Generator): exp.StabilityProperty: lambda self, e: e.name, exp.TemporaryProperty: lambda self, e: "TEMPORARY", exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", + exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression), exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}", exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions), exp.TransientProperty: lambda self, e: "TRANSIENT", @@ -3360,7 +3361,7 @@ class Generator(metaclass=_Generator): return self.sql(arg) cond_for_null = arg.is_(exp.null()) - return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg]))) + return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.array(arg, copy=False))) def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str: this = expression.this diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 9799fe2..35a4586 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -6,7 +6,7 @@ import logging import re import sys import typing as t -from collections.abc import Collection +from collections.abc import Collection, Set from contextlib import contextmanager from copy import copy from enum import Enum @@ -496,3 +496,31 @@ DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"} def is_date_unit(expression: t.Optional[exp.Expression]) -> bool: return expression is not None and expression.name.lower() in DATE_UNITS + + +K = t.TypeVar("K") +V = t.TypeVar("V") + + +class SingleValuedMapping(t.Mapping[K, V]): + """ + Mapping where all keys return the same value. + + This rigamarole is meant to avoid copying keys, which was originally intended + as an optimization while qualifying columns for tables with lots of columns. + """ + + def __init__(self, keys: t.Collection[K], value: V): + self._keys = keys if isinstance(keys, Set) else set(keys) + self._value = value + + def __getitem__(self, key: K) -> V: + if key in self._keys: + return self._value + raise KeyError(key) + + def __len__(self) -> int: + return len(self._keys) + + def __iter__(self) -> t.Iterator[K]: + return iter(self._keys) diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index bdd1d14..f10fbb9 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -153,7 +153,7 @@ def lineage( raise ValueError(f"Could not find {column} in {scope.expression}") for s in scope.union_scopes: - to_node(index, scope=s, upstream=upstream) + to_node(index, scope=s, upstream=upstream, alias=alias) return upstream @@ -209,7 +209,11 @@ def lineage( if isinstance(source, Scope): # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. to_node( - c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table) + c.name, + scope=source, + scope_name=table, + upstream=node, + alias=aliases.get(table) or alias, ) else: # The source is not a scope - we've reached the end of the line. At this point, if a source is not found diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index cb9312c..ce274bb 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -204,7 +204,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.TimeAdd, exp.TimeStrToTime, exp.TimeSub, - exp.Timestamp, exp.TimestampAdd, exp.TimestampSub, exp.UnixToTime, @@ -276,6 +275,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), + exp.Timestamp: lambda self, e: self._annotate_with_type( + e, + exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, + ), exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True), diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index faf18c6..0aa8134 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -38,7 +38,12 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression: if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"): return exp.cast(node.this, to=exp.DataType.Type.DATE) if isinstance(node, exp.Timestamp) and not node.expression: - return exp.cast(node.this, to=exp.DataType.Type.TIMESTAMP) + if not node.type: + from sqlglot.optimizer.annotate_types import annotate_types + + node = annotate_types(node) + return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP) + return node @@ -76,9 +81,8 @@ def coerce_type(node: exp.Expression) -> exp.Expression: def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: if ( isinstance(expression, exp.Cast) - and expression.to.type and expression.this.type - and expression.to.type.this == expression.this.type.this + and expression.to.this == expression.this.type.this ): return expression.this return expression diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 1656727..5c27bc3 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -6,7 +6,7 @@ import typing as t from sqlglot import alias, exp from sqlglot.dialects.dialect import Dialect, DialectType from sqlglot.errors import OptimizeError -from sqlglot.helper import seq_get +from sqlglot.helper import seq_get, SingleValuedMapping from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope from sqlglot.optimizer.simplify import simplify_parens from sqlglot.schema import Schema, ensure_schema @@ -586,8 +586,8 @@ class Resolver: def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): self.scope = scope self.schema = schema - self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None - self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None + self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None + self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None self._all_columns: t.Optional[t.Set[str]] = None self._infer_schema = infer_schema @@ -640,7 +640,7 @@ class Resolver: } return self._all_columns - def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]: + def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: """Resolve the source columns for a given source `name`.""" if name not in self.scope.sources: raise OptimizeError(f"Unknown table: {name}") @@ -662,10 +662,15 @@ class Resolver: else: column_aliases = [] - # If the source's columns are aliased, their aliases shadow the corresponding column names - return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)] + if column_aliases: + # If the source's columns are aliased, their aliases shadow the corresponding column names. + # This can be expensive if there are lots of columns, so only do this if column_aliases exist. + return [ + alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases) + ] + return columns - def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]: + def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]: if self._source_columns is None: self._source_columns = { source_name: self.get_source_columns(source_name) @@ -676,8 +681,8 @@ class Resolver: return self._source_columns def _get_unambiguous_columns( - self, source_columns: t.Dict[str, t.List[str]] - ) -> t.Dict[str, str]: + self, source_columns: t.Dict[str, t.Sequence[str]] + ) -> t.Mapping[str, str]: """ Find all the unambiguous columns in sources. @@ -693,12 +698,17 @@ class Resolver: source_columns_pairs = list(source_columns.items()) first_table, first_columns = source_columns_pairs[0] - unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)} + + if len(source_columns_pairs) == 1: + # Performance optimization - avoid copying first_columns if there is only one table. + return SingleValuedMapping(first_columns, first_table) + + unambiguous_columns = {col: first_table for col in first_columns} all_columns = set(unambiguous_columns) for table, columns in source_columns_pairs[1:]: - unique = self._find_unique_columns(columns) - ambiguous = set(all_columns).intersection(unique) + unique = set(columns) + ambiguous = all_columns.intersection(unique) all_columns.update(columns) for column in ambiguous: @@ -707,19 +717,3 @@ class Resolver: unambiguous_columns[column] = table return unambiguous_columns - - @staticmethod - def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]: - """ - Find the unique columns in a list of columns. - - Example: - >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"])) - ['a', 'c'] - - This is necessary because duplicate column names are ambiguous. - """ - counts: t.Dict[str, int] = {} - for column in columns: - counts[column] = counts.get(column, 0) + 1 - return {column for column, count in counts.items() if count == 1} diff --git a/sqlglot/parser.py b/sqlglot/parser.py index dfa3024..25c5789 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -29,8 +29,8 @@ def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap: values.append(args[i + 1]) return exp.VarMap( - keys=exp.Array(expressions=keys), - values=exp.Array(expressions=values), + keys=exp.array(*keys, copy=False), + values=exp.array(*values, copy=False), ) @@ -638,6 +638,8 @@ class Parser(metaclass=_Parser): TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()), TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, this=self._parse_unary()), TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()), + TokenType.PIPE_SLASH: lambda self: self.expression(exp.Sqrt, this=self._parse_unary()), + TokenType.DPIPE_SLASH: lambda self: self.expression(exp.Cbrt, this=self._parse_unary()), } PRIMARY_PARSERS = { @@ -1000,9 +1002,13 @@ class Parser(metaclass=_Parser): MODIFIERS_ATTACHED_TO_UNION = True UNION_MODIFIERS = {"order", "limit", "offset"} - # parses no parenthesis if statements as commands + # Parses no parenthesis if statements as commands NO_PAREN_IF_COMMANDS = True + # Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause. + # If this is True and '(' is not found, the keyword will be treated as an identifier + VALUES_FOLLOWED_BY_PAREN = True + __slots__ = ( "error_level", "error_message_context", @@ -2058,7 +2064,7 @@ class Parser(metaclass=_Parser): partition=self._parse_partition(), where=self._match_pair(TokenType.REPLACE, TokenType.WHERE) and self._parse_conjunction(), - expression=self._parse_ddl_select(), + expression=self._parse_derived_table_values() or self._parse_ddl_select(), conflict=self._parse_on_conflict(), returning=returning or self._parse_returning(), overwrite=overwrite, @@ -2267,8 +2273,7 @@ class Parser(metaclass=_Parser): self._match_r_paren() return self.expression(exp.Tuple, expressions=expressions) - # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows. - # https://prestodb.io/docs/current/sql/values.html + # In some dialects we can have VALUES 1, 2 which results in 1 column & 2 rows. return self.expression(exp.Tuple, expressions=[self._parse_expression()]) def _parse_projections(self) -> t.List[exp.Expression]: @@ -2367,12 +2372,8 @@ class Parser(metaclass=_Parser): # We return early here so that the UNION isn't attached to the subquery by the # following call to _parse_set_operations, but instead becomes the parent node return self._parse_subquery(this, parse_alias=parse_subquery_alias) - elif self._match(TokenType.VALUES): - this = self.expression( - exp.Values, - expressions=self._parse_csv(self._parse_value), - alias=self._parse_table_alias(), - ) + elif self._match(TokenType.VALUES, advance=False): + this = self._parse_derived_table_values() elif from_: this = exp.select("*").from_(from_.this, copy=False) else: @@ -2969,7 +2970,7 @@ class Parser(metaclass=_Parser): def _parse_derived_table_values(self) -> t.Optional[exp.Values]: is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES) - if not is_derived and not self._match(TokenType.VALUES): + if not is_derived and not self._match_text_seq("VALUES"): return None expressions = self._parse_csv(self._parse_value) @@ -3655,8 +3656,15 @@ class Parser(metaclass=_Parser): def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]: interval = parse_interval and self._parse_interval() if interval: - # Convert INTERVAL 'val_1' unit_1 ... 'val_n' unit_n into a sum of intervals - while self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False): + # Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals + while True: + index = self._index + self._match(TokenType.PLUS) + + if not self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False): + self._retreat(index) + break + interval = self.expression( # type: ignore exp.Add, this=interval, expression=self._parse_interval(match_interval=False) ) @@ -3872,9 +3880,15 @@ class Parser(metaclass=_Parser): def _parse_column_reference(self) -> t.Optional[exp.Expression]: this = self._parse_field() - if isinstance(this, exp.Identifier): - this = self.expression(exp.Column, this=this) - return this + if ( + not this + and self._match(TokenType.VALUES, advance=False) + and self.VALUES_FOLLOWED_BY_PAREN + and (not self._next or self._next.token_type != TokenType.L_PAREN) + ): + this = self._parse_id_var() + + return self.expression(exp.Column, this=this) if isinstance(this, exp.Identifier) else this def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: this = self._parse_bracket(this) @@ -5511,7 +5525,7 @@ class Parser(metaclass=_Parser): then = self.expression( exp.Insert, this=self._parse_value(), - expression=self._match(TokenType.VALUES) and self._parse_value(), + expression=self._match_text_seq("VALUES") and self._parse_value(), ) elif self._match(TokenType.UPDATE): expressions = self._parse_star() diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 1fd4025..dbd0caa 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -49,7 +49,7 @@ class Schema(abc.ABC): only_visible: bool = False, dialect: DialectType = None, normalize: t.Optional[bool] = None, - ) -> t.List[str]: + ) -> t.Sequence[str]: """ Get the column names for a table. @@ -60,7 +60,7 @@ class Schema(abc.ABC): normalize: whether to normalize identifiers according to the dialect of interest. Returns: - The list of column names. + The sequence of column names. """ @abc.abstractmethod diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index b064957..2cfcfa6 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -57,6 +57,8 @@ class TokenType(AutoName): AMP = auto() DPIPE = auto() PIPE = auto() + PIPE_SLASH = auto() + DPIPE_SLASH = auto() CARET = auto() TILDA = auto() ARROW = auto() diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index f13569f..4777609 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -213,6 +213,19 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp is_posexplode = isinstance(explode, exp.Posexplode) explode_arg = explode.this + if isinstance(explode, exp.ExplodeOuter): + bracket = explode_arg[0] + bracket.set("safe", True) + bracket.set("offset", True) + explode_arg = exp.func( + "IF", + exp.func( + "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array()) + ).eq(0), + exp.array(bracket, copy=False), + explode_arg, + ) + # This ensures that we won't use [POS]EXPLODE's argument as a new selection if isinstance(explode_arg, exp.Column): taken_select_names.add(explode_arg.output_name) @@ -466,6 +479,87 @@ def unqualify_columns(expression: exp.Expression) -> exp.Expression: return expression +def remove_unique_constraints(expression: exp.Expression) -> exp.Expression: + assert isinstance(expression, exp.Create) + for constraint in expression.find_all(exp.UniqueColumnConstraint): + if constraint.parent: + constraint.parent.pop() + + return expression + + +def ctas_with_tmp_tables_to_create_tmp_view( + expression: exp.Expression, + tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e, +) -> exp.Expression: + assert isinstance(expression, exp.Create) + properties = expression.args.get("properties") + temporary = any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ) + + # CTAS with temp tables map to CREATE TEMPORARY VIEW + if expression.kind == "TABLE" and temporary: + if expression.expression: + return exp.Create( + kind="TEMPORARY VIEW", + this=expression.this, + expression=expression.expression, + ) + return tmp_storage_provider(expression) + + return expression + + +def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression: + """ + In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the + PARTITIONED BY value is an array of column names, they are transformed into a schema. + The corresponding columns are removed from the create statement. + """ + assert isinstance(expression, exp.Create) + has_schema = isinstance(expression.this, exp.Schema) + is_partitionable = expression.kind in {"TABLE", "VIEW"} + + if has_schema and is_partitionable: + prop = expression.find(exp.PartitionedByProperty) + if prop and prop.this and not isinstance(prop.this, exp.Schema): + schema = expression.this + columns = {v.name.upper() for v in prop.this.expressions} + partitions = [col for col in schema.expressions if col.name.upper() in columns] + schema.set("expressions", [e for e in schema.expressions if e not in partitions]) + prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) + expression.set("this", schema) + + return expression + + +def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression: + """ + Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE. + + Currently, SQLGlot uses the DATASOURCE format for Spark 3. + """ + assert isinstance(expression, exp.Create) + prop = expression.find(exp.PartitionedByProperty) + if ( + prop + and prop.this + and isinstance(prop.this, exp.Schema) + and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions) + ): + prop_this = exp.Tuple( + expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] + ) + schema = expression.this + for e in prop.this.expressions: + schema.append("expressions", e) + prop.set("this", prop_this) + + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], ) -> t.Callable[[Generator, exp.Expression], str]: |