Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--   sqlglot/dialects/bigquery.py     | 50
-rw-r--r--   sqlglot/dialects/clickhouse.py   |  1
-rw-r--r--   sqlglot/dialects/dialect.py      | 29
-rw-r--r--   sqlglot/dialects/drill.py        |  4
-rw-r--r--   sqlglot/dialects/hive.py         | 44
-rw-r--r--   sqlglot/dialects/mysql.py        |  1
-rw-r--r--   sqlglot/dialects/oracle.py       |  1
-rw-r--r--   sqlglot/dialects/postgres.py     |  2
-rw-r--r--   sqlglot/dialects/presto.py       |  2
-rw-r--r--   sqlglot/dialects/redshift.py     |  7
-rw-r--r--   sqlglot/dialects/snowflake.py    |  2
-rw-r--r--   sqlglot/dialects/spark.py        | 28
-rw-r--r--   sqlglot/dialects/spark2.py       | 29
-rw-r--r--   sqlglot/dialects/sqlite.py       |  1
-rw-r--r--   sqlglot/dialects/tableau.py      |  6
-rw-r--r--   sqlglot/dialects/teradata.py     |  1
-rw-r--r--   sqlglot/dialects/tsql.py         |  1
17 files changed, 120 insertions, 89 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 9068235..c0191b2 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -39,24 +39,31 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
     alias = expression.args.get("alias")
 
-    structs = [
-        exp.Struct(
+    return self.unnest_sql(
+        exp.Unnest(
             expressions=[
-                exp.alias_(value, column_name)
-                for value, column_name in zip(
-                    t.expressions,
-                    (
-                        alias.columns
-                        if alias and alias.columns
-                        else (f"_c{i}" for i in range(len(t.expressions)))
+                exp.array(
+                    *(
+                        exp.Struct(
+                            expressions=[
+                                exp.alias_(value, column_name)
+                                for value, column_name in zip(
+                                    t.expressions,
+                                    (
+                                        alias.columns
+                                        if alias and alias.columns
+                                        else (f"_c{i}" for i in range(len(t.expressions)))
+                                    ),
+                                )
+                            ]
+                        )
+                        for t in expression.find_all(exp.Tuple)
                     ),
+                    copy=False,
                 )
             ]
         )
-        for t in expression.find_all(exp.Tuple)
-    ]
-
-    return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)]))
+    )
 
 
 def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
@@ -161,12 +168,18 @@ def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
-def _parse_timestamp(args: t.List) -> exp.StrToTime:
+def _parse_parse_timestamp(args: t.List) -> exp.StrToTime:
     this = format_time_lambda(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
     this.set("zone", seq_get(args, 2))
     return this
 
 
+def _parse_timestamp(args: t.List) -> exp.Timestamp:
+    timestamp = exp.Timestamp.from_arg_list(args)
+    timestamp.set("with_tz", True)
+    return timestamp
+
+
 def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts:
     expr_type = exp.DateFromParts if len(args) == 3 else exp.Date
     return expr_type.from_arg_list(args)
@@ -318,6 +331,7 @@ class BigQuery(Dialect):
             "TIMESTAMP": TokenType.TIMESTAMPTZ,
         }
         KEYWORDS.pop("DIV")
+        KEYWORDS.pop("VALUES")
 
     class Parser(parser.Parser):
         PREFIXED_PIVOT_COLUMNS = True
@@ -348,7 +362,7 @@ class BigQuery(Dialect):
             "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
                 [seq_get(args, 1), seq_get(args, 0)]
             ),
-            "PARSE_TIMESTAMP": _parse_timestamp,
+            "PARSE_TIMESTAMP": _parse_parse_timestamp,
             "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
             "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
                 this=seq_get(args, 0),
@@ -367,6 +381,7 @@ class BigQuery(Dialect):
             "TIME": _parse_time,
             "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
             "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
+            "TIMESTAMP": _parse_timestamp,
             "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
             "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
             "TIMESTAMP_MICROS": lambda args: exp.UnixToTime(
@@ -395,11 +410,6 @@ class BigQuery(Dialect):
             TokenType.TABLE,
         }
 
-        ID_VAR_TOKENS = {
-            *parser.Parser.ID_VAR_TOKENS,
-            TokenType.VALUES,
-        }
-
         PROPERTY_PARSERS = {
             **parser.Parser.PROPERTY_PARSERS,
             "NOT DETERMINISTIC": lambda self: self.expression(
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 1ec15c5..d7be64c 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -93,6 +93,7 @@ class ClickHouse(Dialect):
             "IPV6": TokenType.IPV6,
             "AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION,
             "SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION,
+            "SYSTEM": TokenType.COMMAND,
         }
 
         SINGLE_TOKENS = {
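[note — not part of the diff] A minimal sketch of what the bigquery.py changes above affect, using sqlglot's public API around this revision; the exact output text is illustrative, not verbatim:

    import sqlglot

    # VALUES as a derived table is rewritten to UNNEST over an array of STRUCTs.
    sql = "SELECT t.x, t.y FROM (VALUES (1, 'a'), (2, 'b')) AS t(x, y)"
    print(sqlglot.transpile(sql, write="bigquery")[0])

    # TIMESTAMP(...) now parses to exp.Timestamp with with_tz=True instead of
    # falling through as an anonymous function (per _parse_timestamp above).
    node = sqlglot.parse_one("SELECT TIMESTAMP('2008-12-25 15:30:00+00')", read="bigquery")
    print(node.find(sqlglot.exp.Timestamp).args.get("with_tz"))  # expected: True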
+++ b/sqlglot/dialects/dialect.py
@@ -654,28 +654,6 @@ def time_format(
     return _time_format
 
 
-def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
-    """
-    In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the
-    PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding
-    columns are removed from the create statement.
-    """
-    has_schema = isinstance(expression.this, exp.Schema)
-    is_partitionable = expression.args.get("kind") in ("TABLE", "VIEW")
-
-    if has_schema and is_partitionable:
-        prop = expression.find(exp.PartitionedByProperty)
-        if prop and prop.this and not isinstance(prop.this, exp.Schema):
-            schema = expression.this
-            columns = {v.name.upper() for v in prop.this.expressions}
-            partitions = [col for col in schema.expressions if col.name.upper() in columns]
-            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
-            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
-            expression.set("this", schema)
-
-    return self.create_sql(expression)
-
-
 def parse_date_delta(
     exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
 ) -> t.Callable[[t.List], E]:
@@ -742,7 +720,10 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
 
 def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
     if not expression.expression:
-        return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
+        from sqlglot.optimizer.annotate_types import annotate_types
+
+        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
+        return self.sql(exp.cast(expression.this, to=target_type))
     if expression.text("expression").lower() in TIMEZONES:
         return self.sql(
             exp.AtTimeZone(
@@ -750,7 +731,7 @@ def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
                 zone=expression.expression,
             )
         )
-    return self.function_fallback_sql(expression)
+    return self.func("TIMESTAMP", expression.this, expression.expression)
 
 
 def locate_to_strposition(args: t.List) -> exp.Expression:
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index be23355..409e260 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -5,7 +5,6 @@ import typing as t
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
-    create_with_partitions_sql,
     datestrtodate_sql,
     format_time_lambda,
     no_trycast_sql,
@@ -13,6 +12,7 @@ from sqlglot.dialects.dialect import (
     str_position_sql,
     timestrtotime_sql,
 )
+from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by
 
 
 def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
@@ -125,7 +125,7 @@ class Drill(Dialect):
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
             exp.ArraySize: rename_func("REPEATED_COUNT"),
-            exp.Create: create_with_partitions_sql,
+            exp.Create: preprocess([move_schema_columns_to_partitioned_by]),
             exp.DateAdd: _date_add_sql("ADD"),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateSub: _date_add_sql("SUB"),
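[note — not part of the diff] The monolithic create_with_partitions_sql helper is replaced by composable AST transforms: transforms.preprocess takes a list of tree rewrites and returns a generator entry that applies them before delegating to the default SQL for that node. A sketch of the composition pattern; the uppercase_tables transform here is hypothetical, purely for illustration:

    from sqlglot import exp, transforms

    def uppercase_tables(expression: exp.Expression) -> exp.Expression:
        # hypothetical transform: upper-case every table name in the statement
        for table in expression.find_all(exp.Table):
            table.set("this", exp.to_identifier(table.name.upper()))
        return expression

    # same shape as the Drill TRANSFORMS entry above
    TRANSFORMS = {exp.Create: transforms.preprocess([uppercase_tables])}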
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 6337ffd..b1540bb 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -9,7 +9,6 @@ from sqlglot.dialects.dialect import (
     NormalizationStrategy,
     approx_count_distinct_sql,
     arg_max_or_min_no_count,
-    create_with_partitions_sql,
     datestrtodate_sql,
     format_time_lambda,
     if_sql,
@@ -32,6 +31,12 @@ from sqlglot.dialects.dialect import (
     timestrtotime_sql,
     var_map_sql,
 )
+from sqlglot.transforms import (
+    remove_unique_constraints,
+    ctas_with_tmp_tables_to_create_tmp_view,
+    preprocess,
+    move_schema_columns_to_partitioned_by,
+)
 from sqlglot.helper import seq_get
 from sqlglot.parser import parse_var_map
 from sqlglot.tokens import TokenType
@@ -55,30 +60,6 @@ TIME_DIFF_FACTOR = {
 DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
 
 
-def _create_sql(self, expression: exp.Create) -> str:
-    # remove UNIQUE column constraints
-    for constraint in expression.find_all(exp.UniqueColumnConstraint):
-        if constraint.parent:
-            constraint.parent.pop()
-
-    properties = expression.args.get("properties")
-    temporary = any(
-        isinstance(prop, exp.TemporaryProperty)
-        for prop in (properties.expressions if properties else [])
-    )
-
-    # CTAS with temp tables map to CREATE TEMPORARY VIEW
-    kind = expression.args["kind"]
-    if kind.upper() == "TABLE" and temporary:
-        if expression.expression:
-            return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}"
-        else:
-            # CREATE TEMPORARY TABLE may require storage provider
-            expression = self.temporary_storage_provider(expression)
-
-    return create_with_partitions_sql(self, expression)
-
 
 def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str:
     if isinstance(expression, exp.TsOrDsAdd) and not expression.unit:
         return self.func("DATE_ADD", expression.this, expression.expression)
@@ -285,6 +266,7 @@ class Hive(Dialect):
     class Parser(parser.Parser):
         LOG_DEFAULTS_TO_LN = True
         STRICT_CAST = False
+        VALUES_FOLLOWED_BY_PAREN = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -518,7 +500,13 @@ class Hive(Dialect):
                 "" if e.args.get("allow_null") else "NOT NULL"
             ),
             exp.VarMap: var_map_sql,
-            exp.Create: _create_sql,
+            exp.Create: preprocess(
+                [
+                    remove_unique_constraints,
+                    ctas_with_tmp_tables_to_create_tmp_view,
+                    move_schema_columns_to_partitioned_by,
+                ]
+            ),
             exp.Quantile: rename_func("PERCENTILE"),
             exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
             exp.RegexpExtract: regexp_extract_sql,
@@ -581,10 +569,6 @@ class Hive(Dialect):
 
             return super()._jsonpathkey_sql(expression)
 
-        def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
-            # Hive has no temporary storage provider (there are hive settings though)
-            return expression
-
         def parameter_sql(self, expression: exp.Parameter) -> str:
             this = self.sql(expression, "this")
             expression_sql = self.sql(expression, "expression")
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 661ef7d..97c891d 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -445,6 +445,7 @@ class MySQL(Dialect):
         LOG_DEFAULTS_TO_LN = True
 
         STRING_ALIASES = True
+        VALUES_FOLLOWED_BY_PAREN = False
 
         def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
             this = self._parse_id_var()
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 0c0d750..de693b9 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -88,6 +88,7 @@ class Oracle(Dialect):
     class Parser(parser.Parser):
         ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
         WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}
+        VALUES_FOLLOWED_BY_PAREN = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
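[note — not part of the diff] VALUES_FOLLOWED_BY_PAREN = False, together with the KEYWORDS.pop("VALUES") tokenizer changes, lets these dialects treat a bare VALUES as an ordinary identifier, only parsing a VALUES clause when "(" follows. A sketch, assuming sqlglot behavior around this revision:

    import sqlglot
    from sqlglot import exp

    # "values" with no parenthesis parses as a table name, not a VALUES clause.
    ast = sqlglot.parse_one("SELECT * FROM values", read="mysql")
    print(isinstance(ast.args["from"].this, exp.Table))  # expected: True

    # A parenthesized row constructor still parses as a VALUES clause.
    ast = sqlglot.parse_one("SELECT * FROM (VALUES (1)) AS t(x)", read="mysql")
    print(ast.find(exp.Values) is not None)  # expected: True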
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 68e2c6d..126261e 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -244,6 +244,8 @@ class Postgres(Dialect):
             "@@": TokenType.DAT,
             "@>": TokenType.AT_GT,
             "<@": TokenType.LT_AT,
+            "|/": TokenType.PIPE_SLASH,
+            "||/": TokenType.DPIPE_SLASH,
             "BEGIN": TokenType.COMMAND,
             "BEGIN TRANSACTION": TokenType.BEGIN,
             "BIGSERIAL": TokenType.BIGSERIAL,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 609103e..1e0e7e9 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -225,6 +225,8 @@ class Presto(Dialect):
         }
 
     class Parser(parser.Parser):
+        VALUES_FOLLOWED_BY_PAREN = False
+
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ARBITRARY": exp.AnyValue.from_arg_list,
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index a64c1d4..135ffc6 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -136,11 +136,11 @@ class Redshift(Postgres):
                     refs.add(
                         (
                             this.args["from"] if i == 0 else this.args["joins"][i - 1]
-                        ).alias_or_name.lower()
+                        ).this.alias.lower()
                     )
-                    table = join.this
 
-                    if isinstance(table, exp.Table):
+                    table = join.this
+                    if isinstance(table, exp.Table) and not join.args.get("on"):
                         if table.parts[0].name.lower() in refs:
                             table.replace(table.to_column())
                 return this
@@ -158,6 +158,7 @@ class Redshift(Postgres):
             "UNLOAD": TokenType.COMMAND,
             "VARBYTE": TokenType.VARBINARY,
         }
+        KEYWORDS.pop("VALUES")
 
         # Redshift allows # to appear as a table identifier prefix
         SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy()
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 37f9761..b4275ea 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -477,6 +477,8 @@ class Snowflake(Dialect):
             "PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
             "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
             "COLUMNS": _show_parser("COLUMNS"),
+            "USERS": _show_parser("USERS"),
+            "TERSE USERS": _show_parser("USERS"),
         }
 
         STAGED_FILE_SINGLE_TOKENS = {
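[note — not part of the diff] The Snowflake addition wires SHOW USERS and SHOW TERSE USERS into the dialect's SHOW parser, so they produce a structured exp.Show node instead of an opaque command. A sketch, assuming the public API around this revision:

    import sqlglot
    from sqlglot import exp

    show = sqlglot.parse_one("SHOW TERSE USERS", read="snowflake")
    print(isinstance(show, exp.Show))        # expected: True
    print(show.sql(dialect="snowflake"))     # should round-trip, roughly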
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 44bd12d..c662ab5 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -5,8 +5,14 @@ import typing as t
 from sqlglot import exp
 from sqlglot.dialects.dialect import rename_func
 from sqlglot.dialects.hive import _parse_ignore_nulls
-from sqlglot.dialects.spark2 import Spark2
+from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
 from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+    ctas_with_tmp_tables_to_create_tmp_view,
+    remove_unique_constraints,
+    preprocess,
+    move_partitioned_by_to_schema_columns,
+)
 
 
 def _parse_datediff(args: t.List) -> exp.Expression:
@@ -35,6 +41,15 @@ def _parse_datediff(args: t.List) -> exp.Expression:
     )
 
 
+def _normalize_partition(e: exp.Expression) -> exp.Expression:
+    """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
+    if isinstance(e, str):
+        return exp.to_identifier(e)
+    if isinstance(e, exp.Literal):
+        return exp.to_identifier(e.name)
+    return e
+
+
 class Spark(Spark2):
     class Tokenizer(Spark2.Tokenizer):
         RAW_STRINGS = [
@@ -72,6 +87,17 @@ class Spark(Spark2):
 
         TRANSFORMS = {
             **Spark2.Generator.TRANSFORMS,
+            exp.Create: preprocess(
+                [
+                    remove_unique_constraints,
+                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+                        e, temporary_storage_provider
+                    ),
+                    move_partitioned_by_to_schema_columns,
+                ]
+            ),
+            exp.PartitionedByProperty: lambda self,
+            e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
             exp.StartsWith: rename_func("STARTSWITH"),
             exp.TimestampAdd: lambda self, e: self.func(
                 "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 9378d99..fa55b51 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -13,6 +13,12 @@ from sqlglot.dialects.dialect import (
 )
 from sqlglot.dialects.hive import Hive
 from sqlglot.helper import seq_get
+from sqlglot.transforms import (
+    preprocess,
+    remove_unique_constraints,
+    ctas_with_tmp_tables_to_create_tmp_view,
+    move_schema_columns_to_partitioned_by,
+)
 
 
 def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
@@ -95,6 +101,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
+def temporary_storage_provider(expression: exp.Expression) -> exp.Expression:
+    # spark2, spark, Databricks require a storage provider for temporary tables
+    provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
+    expression.args["properties"].append("expressions", provider)
+    return expression
+
+
 class Spark2(Hive):
     class Parser(Hive.Parser):
         TRIM_PATTERN_FIRST = True
@@ -121,7 +134,6 @@ class Spark2(Hive):
                 ),
                 zone=seq_get(args, 1),
             ),
-            "IIF": exp.If.from_arg_list,
             "INT": _parse_as_cast("int"),
             "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
             "RLIKE": exp.RegexpLike.from_arg_list,
@@ -193,6 +205,15 @@ class Spark2(Hive):
             e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
             exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
             exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
+            exp.Create: preprocess(
+                [
+                    remove_unique_constraints,
+                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
+                        e, temporary_storage_provider
+                    ),
+                    move_schema_columns_to_partitioned_by,
+                ]
+            ),
             exp.DateFromParts: rename_func("MAKE_DATE"),
             exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
             exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -251,12 +272,6 @@ class Spark2(Hive):
 
             return self.func("STRUCT", *args)
 
-        def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
-            # spark2, spark, Databricks require a storage provider for temporary tables
-            provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
-            expression.args["properties"].append("expressions", provider)
-            return expression
-
        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
             if is_parse_json(expression.this):
                 schema = f"'{self.sql(expression, 'to')}'"
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index b292c81..6596c5b 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -132,6 +132,7 @@ class SQLite(Dialect):
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DateAdd: _date_add_sql,
             exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
+            exp.If: rename_func("IIF"),
             exp.ILike: no_ilike_sql,
             exp.JSONExtract: _json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_sql,
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 3795045..e8ff249 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -1,10 +1,14 @@
 from __future__ import annotations
 
-from sqlglot import exp, generator, parser, transforms
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import Dialect, rename_func
 
 
 class Tableau(Dialect):
+    class Tokenizer(tokens.Tokenizer):
+        IDENTIFIERS = [("[", "]")]
+        QUOTES = ["'", '"']
+
     class Generator(generator.Generator):
         JOIN_HINTS = False
         TABLE_HINTS = False
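[note — not part of the diff] The new Tableau tokenizer means [bracketed] names now tokenize as identifiers rather than failing or being misread. A sketch, with illustrative output:

    import sqlglot

    sql = "SELECT [Sales Amount] FROM [Extract].[Orders]"
    print(sqlglot.transpile(sql, read="tableau", write="duckdb")[0])
    # expected shape: SELECT "Sales Amount" FROM "Extract"."Orders"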
a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 7f9a11a..5b30cd4 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -74,6 +74,7 @@ class Teradata(Dialect):
     class Parser(parser.Parser):
         TABLESAMPLE_CSV = True
+        VALUES_FOLLOWED_BY_PAREN = False
 
         CHARSET_TRANSLATORS = {
             "GRAPHIC_TO_KANJISJIS",
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 70ea97e..85b2e12 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -457,7 +457,6 @@ class TSQL(Dialect):
             "FORMAT": _parse_format,
             "GETDATE": exp.CurrentTimestamp.from_arg_list,
             "HASHBYTES": _parse_hashbytes,
-            "IIF": exp.If.from_arg_list,
             "ISNULL": exp.Coalesce.from_arg_list,
             "JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract),
             "JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
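[note — not part of the diff] The two "IIF" removals (spark2.py above and tsql.py here) pair with the SQLite addition: IIF appears to be handled by the shared parser now, making the per-dialect FUNCTIONS entries redundant, while SQLite gains IIF on the generation side. A sketch, with illustrative output:

    import sqlglot

    # exp.If now renders as IIF for SQLite output.
    print(sqlglot.transpile("SELECT IF(x > 0, 'pos', 'neg') FROM t", write="sqlite")[0])
    # expected shape: SELECT IIF(x > 0, 'pos', 'neg') FROM t

    # IIF should still parse in T-SQL despite the FUNCTIONS entry removal.
    node = sqlglot.parse_one("SELECT IIF(1 = 1, 'y', 'n')", read="tsql")
    print(node.find(sqlglot.exp.If) is not None)  # expected: True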