diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-10-26 17:21:54 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-10-26 17:21:54 +0000 |
commit | c03ba18c491e52cc85d8aae1825dd9e0b4f75e32 (patch) | |
tree | f76d58b50900be4bfd2dc15f0ec38d1a70d8417b /sqlglot | |
parent | Releasing debian version 18.13.0-1. (diff) | |
download | sqlglot-c03ba18c491e52cc85d8aae1825dd9e0b4f75e32.tar.xz sqlglot-c03ba18c491e52cc85d8aae1825dd9e0b4f75e32.zip |
Merging upstream version 18.17.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/__main__.py | 6 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 6 | ||||
-rw-r--r-- | sqlglot/dialects/bigquery.py | 13 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 31 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 56 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 20 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 10 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 13 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 23 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 37 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 11 | ||||
-rw-r--r-- | sqlglot/expressions.py | 80 | ||||
-rw-r--r-- | sqlglot/generator.py | 55 | ||||
-rw-r--r-- | sqlglot/lineage.py | 36 | ||||
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 25 | ||||
-rw-r--r-- | sqlglot/optimizer/canonicalize.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 35 | ||||
-rw-r--r-- | sqlglot/parser.py | 66 | ||||
-rw-r--r-- | sqlglot/time.py | 603 | ||||
-rw-r--r-- | sqlglot/tokens.py | 4 | ||||
-rw-r--r-- | sqlglot/transforms.py | 17 |
24 files changed, 1058 insertions, 106 deletions
diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py index f3433d3..4a2820b 100644 --- a/sqlglot/__main__.py +++ b/sqlglot/__main__.py @@ -58,6 +58,12 @@ parser.add_argument( default="IMMEDIATE", help="IGNORE, WARN, RAISE, IMMEDIATE (default)", ) +parser.add_argument( + "--version", + action="version", + version=sqlglot.__version__, + help="Display the SQLGlot version", +) args = parser.parse_args() diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index d98feee..a424ea4 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -84,11 +84,11 @@ def min(col: ColumnOrName) -> Column: def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MAX_BY", ord) + return Column.invoke_expression_over_column(col, expression.ArgMax, expression=ord) def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MIN_BY", ord) + return Column.invoke_expression_over_column(col, expression.ArgMin, expression=ord) def count(col: ColumnOrName) -> Column: @@ -1113,7 +1113,7 @@ def reverse(col: ColumnOrName) -> Column: def flatten(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "FLATTEN") + return Column.invoke_expression_over_column(col, expression.Flatten) def map_keys(col: ColumnOrName) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 7f69dd9..51baba2 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot._typing import E from sqlglot.dialects.dialect import ( Dialect, + arg_max_or_min_no_count, binary_from_function, date_add_interval_sql, datestrtodate_sql, @@ -434,8 +435,13 @@ class BigQuery(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), + exp.ArgMax: arg_max_or_min_no_count("MAX_BY"), + exp.ArgMin: arg_max_or_min_no_count("MIN_BY"), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]), + exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}" + if e.args.get("default") + else f"COLLATE {self.sql(e, 'this')}", exp.Create: _create_sql, exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), exp.DateAdd: date_add_interval_sql("DATE", "ADD"), @@ -632,6 +638,13 @@ class BigQuery(Dialect): "within", } + def eq_sql(self, expression: exp.EQ) -> str: + # Operands of = cannot be NULL in BigQuery + if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null): + return "NULL" + + return self.binary(expression, "=") + def attimezone_sql(self, expression: exp.AtTimeZone) -> str: parent = expression.parent diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index e9d9326..30f728c 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + arg_max_or_min_no_count, inline_array_sql, no_pivot_sql, rename_func, @@ -373,8 +374,11 @@ class ClickHouse(Dialect): exp.Select: transforms.preprocess([transforms.eliminate_qualify]), exp.AnyValue: rename_func("any"), exp.ApproxDistinct: rename_func("uniq"), + exp.ArgMax: arg_max_or_min_no_count("argMax"), + exp.ArgMin: arg_max_or_min_no_count("argMin"), exp.Array: inline_array_sql, exp.CastToStrType: rename_func("CAST"), + exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"), exp.DateAdd: lambda self, e: self.func( "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this ), @@ -418,6 +422,33 @@ class ClickHouse(Dialect): "NAMED COLLECTION", } + def _any_to_has( + self, + expression: exp.EQ | exp.NEQ, + default: t.Callable[[t.Any], str], + prefix: str = "", + ) -> str: + if isinstance(expression.left, exp.Any): + arr = expression.left + this = expression.right + elif isinstance(expression.right, exp.Any): + arr = expression.right + this = expression.left + else: + return default(expression) + return prefix + self.func("has", arr.this.unnest(), this) + + def eq_sql(self, expression: exp.EQ) -> str: + return self._any_to_has(expression, super().eq_sql) + + def neq_sql(self, expression: exp.NEQ) -> str: + return self._any_to_has(expression, super().neq_sql, "NOT ") + + def regexpilike_sql(self, expression: exp.RegexpILike) -> str: + # Manually add a flag to make the search case-insensitive + regex = self.func("CONCAT", "'(?i)'", expression.expression) + return f"match({self.format_args(expression.this, regex)})" + def datatype_sql(self, expression: exp.DataType) -> str: # String is the standard ClickHouse type, every other variant is just an alias. # Additionally, any supplied length parameter will be ignored. diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index bd839af..739e8d7 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -10,7 +10,7 @@ from sqlglot.errors import ParseError from sqlglot.generator import Generator from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser -from sqlglot.time import format_time +from sqlglot.time import TIMEZONES, format_time from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import new_trie @@ -595,6 +595,19 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: ) +def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: + if not expression.expression: + return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)) + if expression.text("expression").lower() in TIMEZONES: + return self.sql( + exp.AtTimeZone( + this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP), + zone=expression.expression, + ) + ) + return self.function_fallback_sql(expression) + + def locate_to_strposition(args: t.List) -> exp.Expression: return exp.StrPosition( this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) @@ -691,9 +704,13 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: _dialect = Dialect.get_or_raise(dialect) time_format = self.format_time(expression) if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): - return self.sql(exp.cast(str_to_time_sql(self, expression), "date")) - - return self.sql(exp.cast(self.sql(expression, "this"), "date")) + return self.sql( + exp.cast( + exp.StrToTime(this=expression.this, format=expression.args["format"]), + "date", + ) + ) + return self.sql(exp.cast(expression.this, "date")) return _ts_or_ds_to_date_sql @@ -725,7 +742,9 @@ def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: - bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) + bad_args = list( + filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers")) + ) if bad_args: self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") @@ -756,15 +775,6 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp return names -def simplify_literal(expression: E) -> E: - if not isinstance(expression.expression, exp.Literal): - from sqlglot.optimizer.simplify import simplify - - simplify(expression.expression) - - return expression - - def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) @@ -804,3 +814,21 @@ def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str: expression = expression.copy() expression.set("with", expression.expression.args["with"].pop()) return self.insert_sql(expression) + + +def generatedasidentitycolumnconstraint_sql( + self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint +) -> str: + start = self.sql(expression, "start") or "1" + increment = self.sql(expression, "increment") or "1" + return f"IDENTITY({start}, {increment})" + + +def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: + def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: + if expression.args.get("count"): + self.unsupported(f"Only two arguments are supported in function {name}.") + + return self.func(name, expression.this, expression.expression) + + return _arg_max_or_min_sql diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 5b94bcb..287e03a 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, approx_count_distinct_sql, + arg_max_or_min_no_count, arrow_json_extract_scalar_sql, arrow_json_extract_sql, binary_from_function, @@ -18,9 +19,9 @@ from sqlglot.dialects.dialect import ( no_comment_column_constraint_sql, no_properties_sql, no_safe_divide_sql, + no_timestamp_sql, pivot_column_names, regexp_extract_sql, - regexp_replace_sql, rename_func, str_position_sql, str_to_time_sql, @@ -172,6 +173,12 @@ class DuckDB(Dialect): this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) ), "REGEXP_MATCHES": exp.RegexpLike.from_arg_list, + "REGEXP_REPLACE": lambda args: exp.RegexpReplace( + this=seq_get(args, 0), + expression=seq_get(args, 1), + replacement=seq_get(args, 2), + modifiers=seq_get(args, 3), + ), "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"), "STRING_SPLIT": exp.Split.from_arg_list, "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, @@ -243,6 +250,8 @@ class DuckDB(Dialect): if e.expressions and e.expressions[0].find(exp.Select) else inline_array_sql(self, e), exp.ArraySize: rename_func("ARRAY_LENGTH"), + exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"), + exp.ArgMin: arg_max_or_min_no_count("ARG_MIN"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), exp.BitwiseXor: rename_func("XOR"), @@ -287,7 +296,13 @@ class DuckDB(Dialect): exp.PercentileDisc: rename_func("QUANTILE_DISC"), exp.Properties: no_properties_sql, exp.RegexpExtract: regexp_extract_sql, - exp.RegexpReplace: regexp_replace_sql, + exp.RegexpReplace: lambda self, e: self.func( + "REGEXP_REPLACE", + e.this, + e.expression, + e.args.get("replacement"), + e.args.get("modifiers"), + ), exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), exp.SafeDivide: no_safe_divide_sql, @@ -298,6 +313,7 @@ class DuckDB(Dialect): exp.StrToTime: str_to_time_sql, exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", exp.Struct: _struct_sql, + exp.Timestamp: no_timestamp_sql, exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 3f925a7..7bff553 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, approx_count_distinct_sql, + arg_max_or_min_no_count, create_with_partitions_sql, format_time_lambda, if_sql, @@ -106,11 +107,16 @@ def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str: sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})" return f"({sec_diff}){factor}" if factor else sec_diff - sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF" + months_between = unit in DIFF_MONTH_SWITCH + sql_func = "MONTHS_BETWEEN" if months_between else "DATEDIFF" _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) multiplier_sql = f" / {multiplier}" if multiplier > 1 else "" diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})" + if months_between: + # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part + diff_sql = f"CAST({diff_sql} AS INT)" + return f"{diff_sql}{multiplier_sql}" @@ -426,6 +432,8 @@ class Hive(Dialect): exp.Property: _property_sql, exp.AnyValue: rename_func("FIRST"), exp.ApproxDistinct: approx_count_distinct_sql, + exp.ArgMax: arg_max_or_min_no_count("MAX_BY"), + exp.ArgMin: arg_max_or_min_no_count("MIN_BY"), exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this), exp.ArraySize: rename_func("SIZE"), diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 59a0a2a..2185a85 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -21,7 +21,6 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, parse_date_delta_with_interval, rename_func, - simplify_literal, strposition_to_locate_sql, ) from sqlglot.helper import seq_get @@ -689,6 +688,8 @@ class MySQL(Dialect): LIMIT_FETCH = "LIMIT" + LIMIT_ONLY_LITERALS = True + # MySQL doesn't support many datatypes in cast. # https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#function_cast CAST_MAPPING = { @@ -712,16 +713,6 @@ class MySQL(Dialect): result = f"{result} UNSIGNED" return result - def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: - # MySQL requires simple literal values for its LIMIT clause. - expression = simplify_literal(expression.copy()) - return super().limit_sql(expression, top=top) - - def offset_sql(self, expression: exp.Offset) -> str: - # MySQL requires simple literal values for its OFFSET clause. - expression = simplify_literal(expression.copy()) - return super().offset_sql(expression) - def xor_sql(self, expression: exp.Xor) -> str: if expression.expressions: return self.expressions(expression, sep=" XOR ") diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index c435309..086b278 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -20,7 +20,6 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, parse_timestamp_trunc, rename_func, - simplify_literal, str_position_sql, struct_extract_sql, timestamptrunc_sql, @@ -49,7 +48,7 @@ def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | ex this = self.sql(expression, "this") unit = expression.args.get("unit") - expression = simplify_literal(expression).expression + expression = self._simplify_unless_literal(expression.expression) if not isinstance(expression, exp.Literal): self.unsupported("Cannot add non literal") diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 88525a2..aac368c 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import ( no_ilike_sql, no_pivot_sql, no_safe_divide_sql, + no_timestamp_sql, regexp_extract_sql, rename_func, right_to_substring_sql, @@ -69,9 +70,10 @@ def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str: if expression.parent: for schema in expression.parent.find_all(exp.Schema): - if isinstance(schema.parent, exp.Property): + column_defs = schema.find_all(exp.ColumnDef) + if column_defs and isinstance(schema.parent, exp.Property): expression = expression.copy() - expression.expressions.extend(schema.expressions) + expression.expressions.extend(column_defs) return self.schema_sql(expression) @@ -252,6 +254,7 @@ class Presto(Dialect): TZ_TO_WITH_TIME_ZONE = True NVL2_SUPPORTED = False STRUCT_DELIMITER = ("(", ")") + LIMIT_ONLY_LITERALS = True PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, @@ -277,6 +280,8 @@ class Presto(Dialect): exp.AnyValue: rename_func("ARBITRARY"), exp.ApproxDistinct: _approx_distinct_sql, exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), + exp.ArgMax: rename_func("MAX_BY"), + exp.ArgMin: rename_func("MIN_BY"), exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayContains: rename_func("CONTAINS"), @@ -348,6 +353,7 @@ class Presto(Dialect): exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", exp.StructExtract: struct_extract_sql, exp.Table: transforms.preprocess([_unnest_sequence]), + exp.Timestamp: no_timestamp_sql, exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: timestrtotime_sql, exp.TimeStrToTime: timestrtotime_sql, @@ -367,7 +373,6 @@ class Presto(Dialect): exp.WithinGroup: transforms.preprocess( [transforms.remove_within_group_for_percentiles] ), - exp.Timestamp: transforms.preprocess([transforms.timestamp_to_cast]), exp.Xor: bool_xor_sql, } @@ -418,3 +423,15 @@ class Presto(Dialect): self.sql(expression, "offset"), self.sql(limit), ] + + def create_sql(self, expression: exp.Create) -> str: + """ + Presto doesn't support CREATE VIEW with expressions (ex: `CREATE VIEW x (cola)` then `(cola)` is the expression), + so we need to remove them + """ + kind = expression.args["kind"] + schema = expression.this + if kind == "VIEW" and schema.expressions: + expression = expression.copy() + expression.this.set("expressions", None) + return super().create_sql(expression) diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 04e78a5..df70aa7 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -6,6 +6,7 @@ from sqlglot import exp, transforms from sqlglot.dialects.dialect import ( concat_to_dpipe_sql, concat_ws_to_dpipe_sql, + generatedasidentitycolumnconstraint_sql, rename_func, ts_or_ds_to_date_sql, ) @@ -171,8 +172,10 @@ class Redshift(Postgres): exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), exp.FromBase: rename_func("STRTOL"), + exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, exp.JSONExtract: _json_sql, exp.JSONExtractScalar: _json_sql, + exp.ParseJSON: rename_func("JSON_PARSE"), exp.SafeConcat: concat_to_dpipe_sql, exp.Select: transforms.preprocess( [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index fc3e0fa..07be65b 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -262,6 +262,7 @@ class Snowflake(Dialect): ), "DATEDIFF": _parse_datediff, "DIV0": _div0_to_if, + "FLATTEN": exp.Explode.from_arg_list, "IFF": exp.If.from_arg_list, "LISTAGG": exp.GroupConcat.from_arg_list, "NULLIFZERO": _nullifzero_to_if, @@ -308,6 +309,7 @@ class Snowflake(Dialect): expressions=self._parse_csv(self._parse_id_var), unset=True, ), + "SWAP": lambda self: self._parse_alter_table_swap(), } STATEMENT_PARSERS = { @@ -325,6 +327,22 @@ class Snowflake(Dialect): TokenType.MOD, TokenType.SLASH, } + FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"] + + def _parse_lateral(self) -> t.Optional[exp.Lateral]: + lateral = super()._parse_lateral() + if not lateral: + return lateral + + if isinstance(lateral.this, exp.Explode): + table_alias = lateral.args.get("alias") + columns = [exp.to_identifier(col) for col in self.FLATTEN_COLUMNS] + if table_alias and not table_alias.args.get("columns"): + table_alias.set("columns", columns) + elif not table_alias: + exp.alias_(lateral, "_flattened", table=columns, copy=False) + + return lateral def _parse_table_parts(self, schema: bool = False) -> exp.Table: # https://docs.snowflake.com/en/user-guide/querying-stage @@ -389,6 +407,10 @@ class Snowflake(Dialect): return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind) + def _parse_alter_table_swap(self) -> exp.SwapTable: + self._match_text_seq("WITH") + return self.expression(exp.SwapTable, this=self._parse_table(schema=True)) + class Tokenizer(tokens.Tokenizer): STRING_ESCAPES = ["\\", "'"] HEX_STRINGS = [("x'", "'"), ("X'", "'")] @@ -438,6 +460,8 @@ class Snowflake(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.ArgMax: rename_func("MAX_BY"), + exp.ArgMin: rename_func("MIN_BY"), exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), @@ -451,7 +475,10 @@ class Snowflake(Dialect): ), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, + exp.DayOfMonth: rename_func("DAYOFMONTH"), exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.DayOfYear: rename_func("DAYOFYEAR"), + exp.Explode: rename_func("FLATTEN"), exp.Extract: rename_func("DATE_PART"), exp.GenerateSeries: lambda self, e: self.func( "ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step") @@ -520,6 +547,12 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def log_sql(self, expression: exp.Log) -> str: + if not expression.expression: + return self.func("LN", expression.this) + + return super().log_sql(expression) + def unnest_sql(self, expression: exp.Unnest) -> str: selects = ["value"] unnest_alias = expression.args.get("alias") @@ -596,3 +629,7 @@ class Snowflake(Dialect): increment = expression.args.get("increment") increment = f" INCREMENT {increment}" if increment else "" return f"AUTOINCREMENT{start}{increment}" + + def swaptable_sql(self, expression: exp.SwapTable) -> str: + this = self.sql(expression, "this") + return f"SWAP WITH {this}" diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index b9e925a..152afa6 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing as t from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least +from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least, rename_func from sqlglot.tokens import TokenType @@ -150,6 +150,7 @@ class Teradata(Dialect): return self.expression(exp.RangeN, this=this, expressions=expressions, each=each) class Generator(generator.Generator): + LIMIT_IS_TOP = True JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False @@ -168,6 +169,8 @@ class Teradata(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.ArgMax: rename_func("MAX_BY"), + exp.ArgMin: rename_func("MIN_BY"), exp.Max: max_or_greatest, exp.Min: min_or_least, exp.Select: transforms.preprocess( diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 69adb45..867e4e4 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, any_value_to_max_sql, + generatedasidentitycolumnconstraint_sql, max_or_greatest, min_or_least, move_insert_cte_sql, @@ -603,6 +604,7 @@ class TSQL(Dialect): exp.DataType.Type.DATETIME: "DATETIME2", exp.DataType.Type.DOUBLE: "FLOAT", exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.TEXT: "VARCHAR(MAX)", exp.DataType.Type.TIMESTAMP: "DATETIME2", exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET", exp.DataType.Type.VARIANT: "SQL_VARIANT", @@ -617,6 +619,7 @@ class TSQL(Dialect): exp.CurrentDate: rename_func("GETDATE"), exp.CurrentTimestamp: rename_func("GETDATE"), exp.Extract: rename_func("DATEPART"), + exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), exp.Insert: move_insert_cte_sql, @@ -778,11 +781,3 @@ class TSQL(Dialect): this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True, sep=" ") return f"CONSTRAINT {this} {expressions}" - - # https://learn.microsoft.com/en-us/answers/questions/448821/create-table-in-sql-server - def generatedasidentitycolumnconstraint_sql( - self, expression: exp.GeneratedAsIdentityColumnConstraint - ) -> str: - start = self.sql(expression, "start") or "1" - increment = self.sql(expression, "increment") or "1" - return f"IDENTITY({start}, {increment})" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index b94b1e1..5b012b1 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -23,7 +23,7 @@ from enum import auto from functools import reduce from sqlglot._typing import E -from sqlglot.errors import ParseError +from sqlglot.errors import ErrorLevel, ParseError from sqlglot.helper import ( AutoName, camel_to_snake_case, @@ -120,14 +120,14 @@ class Expression(metaclass=_Expression): return hash((self.__class__, self.hashable_args)) @property - def this(self): + def this(self) -> t.Any: """ Retrieves the argument with key "this". """ return self.args.get("this") @property - def expression(self): + def expression(self) -> t.Any: """ Retrieves the argument with key "expression". """ @@ -1235,6 +1235,10 @@ class RenameTable(Expression): pass +class SwapTable(Expression): + pass + + class Comment(Expression): arg_types = {"this": True, "kind": True, "expression": True, "exists": False} @@ -1979,7 +1983,7 @@ class ChecksumProperty(Property): class CollateProperty(Property): - arg_types = {"this": True} + arg_types = {"this": True, "default": False} class CopyGrantsProperty(Property): @@ -2607,11 +2611,11 @@ class Union(Subqueryable): return self.this.unnest().selects @property - def left(self): + def left(self) -> Expression: return self.this @property - def right(self): + def right(self) -> Expression: return self.expression @@ -3700,7 +3704,9 @@ class DataType(Expression): return DataType(this=DataType.Type.UNKNOWN, **kwargs) try: - data_type_exp = parse_one(dtype, read=dialect, into=DataType) + data_type_exp = parse_one( + dtype, read=dialect, into=DataType, error_level=ErrorLevel.IGNORE + ) except ParseError: if udt: return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs) @@ -3804,11 +3810,11 @@ class Binary(Condition): arg_types = {"this": True, "expression": True} @property - def left(self): + def left(self) -> Expression: return self.this @property - def right(self): + def right(self) -> Expression: return self.expression @@ -4063,10 +4069,25 @@ class TimeUnit(Expression): arg_types = {"unit": False} + UNABBREVIATED_UNIT_NAME = { + "d": "day", + "h": "hour", + "m": "minute", + "ms": "millisecond", + "ns": "nanosecond", + "q": "quarter", + "s": "second", + "us": "microsecond", + "w": "week", + "y": "year", + } + + VAR_LIKE = (Column, Literal, Var) + def __init__(self, **args): unit = args.get("unit") - if isinstance(unit, (Column, Literal)): - args["unit"] = Var(this=unit.name) + if isinstance(unit, self.VAR_LIKE): + args["unit"] = Var(this=self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name) elif isinstance(unit, Week): unit.set("this", Var(this=unit.this.name)) @@ -4168,6 +4189,24 @@ class Abs(Func): pass +class ArgMax(AggFunc): + arg_types = {"this": True, "expression": True, "count": False} + _sql_names = ["ARG_MAX", "ARGMAX", "MAX_BY"] + + +class ArgMin(AggFunc): + arg_types = {"this": True, "expression": True, "count": False} + _sql_names = ["ARG_MIN", "ARGMIN", "MIN_BY"] + + +class ApproxTopK(AggFunc): + arg_types = {"this": True, "expression": False, "counters": False} + + +class Flatten(Func): + pass + + # https://spark.apache.org/docs/latest/api/sql/index.html#transform class Transform(Func): arg_types = {"this": True, "expression": True} @@ -4540,8 +4579,10 @@ class Exp(Func): pass +# https://docs.snowflake.com/en/sql-reference/functions/flatten class Explode(Func): - pass + arg_types = {"this": True, "expressions": False} + is_var_len_args = True class ExplodeOuter(Explode): @@ -4698,6 +4739,8 @@ class JSONArrayContains(Binary, Predicate, Func): class ParseJSON(Func): # BigQuery, Snowflake have PARSE_JSON, Presto has JSON_PARSE _sql_names = ["PARSE_JSON", "JSON_PARSE"] + arg_types = {"this": True, "expressions": False} + is_var_len_args = True class Least(Func): @@ -4758,6 +4801,16 @@ class Lower(Func): class Map(Func): arg_types = {"keys": False, "values": False} + @property + def keys(self) -> t.List[Expression]: + keys = self.args.get("keys") + return keys.expressions if keys else [] + + @property + def values(self) -> t.List[Expression]: + values = self.args.get("values") + return values.expressions if values else [] + class MapFromEntries(Func): pass @@ -4870,6 +4923,7 @@ class RegexpReplace(Func): "position": False, "occurrence": False, "parameters": False, + "modifiers": False, } @@ -4877,7 +4931,7 @@ class RegexpLike(Binary, Func): arg_types = {"this": True, "expression": True, "flag": False} -class RegexpILike(Func): +class RegexpILike(Binary, Func): arg_types = {"this": True, "expression": True, "flag": False} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index b7e26bb..0d6778a 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -11,6 +11,9 @@ from sqlglot.helper import apply_index_offset, csv, seq_get from sqlglot.time import format_time from sqlglot.tokens import Tokenizer, TokenType +if t.TYPE_CHECKING: + from sqlglot._typing import E + logger = logging.getLogger("sqlglot") @@ -141,6 +144,9 @@ class Generator: # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") LIMIT_FETCH = "ALL" + # Whether or not limit and fetch allows expresions or just limits + LIMIT_ONLY_LITERALS = False + # Whether or not a table is allowed to be renamed with a db RENAME_TABLE_WITH_DB = True @@ -341,6 +347,12 @@ class Generator: exp.With, ) + # Expressions that should not have their comments generated in maybe_comment + EXCLUDE_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( + exp.Binary, + exp.Union, + ) + # Expressions that can remain unwrapped when appearing in the context of an INTERVAL UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = ( exp.Column, @@ -501,7 +513,7 @@ class Generator: else None ) - if not comments or isinstance(expression, exp.Binary): + if not comments or isinstance(expression, self.EXCLUDE_COMMENTS): return sql comments_sql = " ".join( @@ -879,6 +891,10 @@ class Generator: alias = self.sql(expression, "this") columns = self.expressions(expression, key="columns", flat=True) columns = f"({columns})" if columns else "" + + if not alias and not self.UNNEST_COLUMN_ONLY: + alias = "_t" + return f"{alias}{columns}" def bitstring_sql(self, expression: exp.BitString) -> str: @@ -1611,9 +1627,6 @@ class Generator: def lateral_sql(self, expression: exp.Lateral) -> str: this = self.sql(expression, "this") - if isinstance(expression.this, exp.Subquery): - return f"LATERAL {this}" - if expression.args.get("view"): alias = expression.args["alias"] columns = self.expressions(alias, key="columns", flat=True) @@ -1629,18 +1642,19 @@ class Generator: def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: this = self.sql(expression, "this") args = ", ".join( - sql - for sql in ( - self.sql(expression, "offset"), - self.sql(expression, "expression"), - ) - if sql + self.sql(self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e) + for e in (expression.args.get(k) for k in ("offset", "expression")) + if e ) return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args}" def offset_sql(self, expression: exp.Offset) -> str: this = self.sql(expression, "this") - return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" + expression = expression.expression + expression = ( + self._simplify_unless_literal(expression) if self.LIMIT_ONLY_LITERALS else expression + ) + return f"{this}{self.seg('OFFSET')} {self.sql(expression)}" def setitem_sql(self, expression: exp.SetItem) -> str: kind = self.sql(expression, "kind") @@ -1895,12 +1909,13 @@ class Generator: def schema_sql(self, expression: exp.Schema) -> str: this = self.sql(expression, "this") - this = f"{this} " if this else "" sql = self.schema_columns_sql(expression) - return f"{this}{sql}" + return f"{this} {sql}" if this and sql else this or sql def schema_columns_sql(self, expression: exp.Schema) -> str: - return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}" + if expression.expressions: + return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}" + return "" def star_sql(self, expression: exp.Star) -> str: except_ = self.expressions(expression, key="except", flat=True) @@ -2708,8 +2723,8 @@ class Generator: self.unsupported(f"Unsupported property {expression.__class__.__name__}") return f"{property_name} {self.sql(expression, 'this')}" - def set_operation(self, expression: exp.Expression, op: str) -> str: - this = self.sql(expression, "this") + def set_operation(self, expression: exp.Union, op: str) -> str: + this = self.maybe_comment(self.sql(expression, "this"), comments=expression.comments) op = self.seg(op) return self.query_modifiers( expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" @@ -2912,6 +2927,14 @@ class Generator: parameters = self.sql(expression, "params_struct") return self.func("PREDICT", model, table, parameters or None) + def _simplify_unless_literal(self, expression: E) -> E: + if not isinstance(expression, exp.Literal): + from sqlglot.optimizer.simplify import simplify + + expression = simplify(expression.copy()) + + return expression + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 113458f..011a6b8 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -112,17 +112,34 @@ def lineage( column if isinstance(column, int) else next( - i - for i, select in enumerate(scope.expression.selects) - if select.alias_or_name == column + ( + i + for i, select in enumerate(scope.expression.selects) + if select.alias_or_name == column or select.is_star + ), + -1, # mypy will not allow a None here, but a negative index should never be returned ) ) + if index == -1: + raise ValueError(f"Could not find {column} in {scope.expression}") + for s in scope.union_scopes: to_node(index, scope=s, upstream=upstream) return upstream + subquery = select.unalias() + + if isinstance(subquery, exp.Subquery): + upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select) + scope = t.cast(Scope, build_scope(subquery.unnest())) + + for select in subquery.named_selects: + to_node(select, scope=scope, upstream=upstream) + + return upstream + if isinstance(scope.expression, exp.Select): # For better ergonomics in our node labels, replace the full select with # a version that has only the column we care about. @@ -142,8 +159,19 @@ def lineage( if upstream: upstream.downstream.append(node) + # if the select is a star add all scope sources as downstreams + if select.is_star: + for source in scope.sources.values(): + node.downstream.append(Node(name=select.sql(), source=source, expression=source)) + # Find all columns that went into creating this one to list their lineage nodes. - for c in set(select.find_all(exp.Column)): + source_columns = set(select.find_all(exp.Column)) + + # If the source is a UDTF find columns used in the UTDF to generate the table + if isinstance(source, exp.UDTF): + source_columns |= set(source.find_all(exp.Column)) + + for c in source_columns: table = c.table source = scope.sources.get(table) diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 17af6ac..69d4567 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -6,7 +6,7 @@ import typing as t from sqlglot import exp from sqlglot._typing import E -from sqlglot.helper import ensure_list, subclasses +from sqlglot.helper import ensure_list, seq_get, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema, ensure_schema @@ -271,6 +271,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), + exp.Bracket: lambda self, e: self._annotate_bracket(e), exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), @@ -287,6 +288,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), + exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), @@ -524,3 +526,24 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._set_type(expression, datatype) return expression + + def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket: + self._annotate_args(expression) + + bracket_arg = expression.expressions[0] + this = expression.this + + if isinstance(bracket_arg, exp.Slice): + self._set_type(expression, this.type) + elif this.type.is_type(exp.DataType.Type.ARRAY): + contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN + self._set_type(expression, contained_type) + elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys: + index = this.keys.index(bracket_arg) + value = seq_get(this.values, index) + value_type = value.type if value else exp.DataType.Type.UNKNOWN + self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN) + else: + self._set_type(expression, exp.DataType.Type.UNKNOWN) + + return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index ec3b3af..fc5c348 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -69,7 +69,11 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression: _replace_int_predicate(expression.left) _replace_int_predicate(expression.right) - elif isinstance(expression, (exp.Where, exp.Having, exp.If)): + elif isinstance(expression, (exp.Where, exp.Having)) or ( + # We can't replace num in CASE x WHEN num ..., because it's not the full predicate + isinstance(expression, exp.If) + and not (isinstance(expression.parent, exp.Case) and expression.parent.this) + ): _replace_int_predicate(expression.this) return expression diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 849643c..30de75b 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -70,6 +70,7 @@ def simplify(expression, constant_propagation=False): node = uniq_sort(node, generate, root) node = absorb_and_eliminate(node, root) node = simplify_concat(node) + node = simplify_conditionals(node) if constant_propagation: node = propagate_constants(node, root) @@ -477,9 +478,11 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression: return expression if l.__class__ in INVERSE_DATE_OPS: + l = t.cast(exp.IntervalOp, l) a = l.this b = l.interval() else: + l = t.cast(exp.Binary, l) a, b = l.left, l.right if not a_predicate(a) and b_predicate(b): @@ -695,6 +698,32 @@ def simplify_concat(expression): return concat_type(expressions=new_args) +def simplify_conditionals(expression): + """Simplifies expressions like IF, CASE if their condition is statically known.""" + if isinstance(expression, exp.Case): + this = expression.this + for case in expression.args["ifs"]: + cond = case.this + if this: + # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... + cond = cond.replace(this.pop().eq(cond)) + + if always_true(cond): + return case.args["true"] + + if always_false(cond): + case.pop() + if not expression.args["ifs"]: + return expression.args.get("default") or exp.null() + elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): + if always_true(expression.this): + return expression.args["true"] + if always_false(expression.this): + return expression.args.get("false") or exp.null() + + return expression + + DateRange = t.Tuple[datetime.date, datetime.date] @@ -786,6 +815,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: else: return expression + l = t.cast(exp.DateTrunc, l) unit = l.unit.name.lower() date = extract_date(r) @@ -798,6 +828,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: rs = expression.expressions if rs and all(_is_datetrunc_predicate(l, r) for r in rs): + l = t.cast(exp.DateTrunc, l) unit = l.unit.name.lower() ranges = [] @@ -852,6 +883,10 @@ def always_true(expression): ) +def always_false(expression): + return is_false(expression) or is_null(expression) + + def is_complement(a, b): return isinstance(b, exp.Not) and b.this == a diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 8de76ca..b7f91ab 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -313,6 +313,7 @@ class Parser(metaclass=_Parser): TokenType.UNIQUE, TokenType.UNPIVOT, TokenType.UPDATE, + TokenType.USE, TokenType.VOLATILE, TokenType.WINDOW, *CREATABLES, @@ -629,11 +630,14 @@ class Parser(metaclass=_Parser): "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty), "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty), "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), - "CHARACTER SET": lambda self: self._parse_character_set(), + "CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs), + "CHARACTER SET": lambda self, **kwargs: self._parse_character_set(**kwargs), "CHECKSUM": lambda self: self._parse_checksum(), "CLUSTER BY": lambda self: self._parse_cluster(), "CLUSTERED": lambda self: self._parse_clustered_by(), - "COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty), + "COLLATE": lambda self, **kwargs: self._parse_property_assignment( + exp.CollateProperty, **kwargs + ), "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), "COPY": lambda self: self._parse_copy_property(), "DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs), @@ -1443,8 +1447,8 @@ class Parser(metaclass=_Parser): if self._match_texts(self.PROPERTY_PARSERS): return self.PROPERTY_PARSERS[self._prev.text.upper()](self) - if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET): - return self._parse_character_set(default=True) + if self._match(TokenType.DEFAULT) and self._match_texts(self.PROPERTY_PARSERS): + return self.PROPERTY_PARSERS[self._prev.text.upper()](self, default=True) if self._match_text_seq("COMPOUND", "SORTKEY"): return self._parse_sortkey(compound=True) @@ -1480,10 +1484,10 @@ class Parser(metaclass=_Parser): else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), ) - def _parse_property_assignment(self, exp_class: t.Type[E]) -> E: + def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E: self._match(TokenType.EQ) self._match(TokenType.ALIAS) - return self.expression(exp_class, this=self._parse_field()) + return self.expression(exp_class, this=self._parse_field(), **kwargs) def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Properties]: properties = [] @@ -2426,9 +2430,9 @@ class Parser(metaclass=_Parser): table_alias: t.Optional[exp.TableAlias] = self.expression( exp.TableAlias, this=table, columns=columns ) - elif isinstance(this, exp.Subquery) and this.alias: - # Ensures parity between the Subquery's and the Lateral's "alias" args - table_alias = this.args["alias"].copy() + elif isinstance(this, (exp.Subquery, exp.Unnest)) and this.alias: + # We move the alias from the lateral's child node to the lateral itself + table_alias = this.args["alias"].pop() else: table_alias = self._parse_table_alias() @@ -2952,6 +2956,7 @@ class Parser(metaclass=_Parser): cube = None totals = None + index = self._index with_ = self._match(TokenType.WITH) if self._match(TokenType.ROLLUP): rollup = with_ or self._parse_wrapped_csv(self._parse_column) @@ -2966,6 +2971,8 @@ class Parser(metaclass=_Parser): elements["totals"] = True # type: ignore if not (grouping_sets or rollup or cube or totals): + if with_: + self._retreat(index) break return self.expression(exp.Group, **elements) # type: ignore @@ -3157,6 +3164,7 @@ class Parser(metaclass=_Parser): return self.expression( expression, + comments=self._prev.comments, this=this, distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), by_name=self._match_text_seq("BY", "NAME"), @@ -3619,6 +3627,32 @@ class Parser(metaclass=_Parser): anonymous: bool = False, optional_parens: bool = True, ) -> t.Optional[exp.Expression]: + # This allows us to also parse {fn <function>} syntax (Snowflake, MySQL support this) + # See: https://community.snowflake.com/s/article/SQL-Escape-Sequences + fn_syntax = False + if ( + self._match(TokenType.L_BRACE, advance=False) + and self._next + and self._next.text.upper() == "FN" + ): + self._advance(2) + fn_syntax = True + + func = self._parse_function_call( + functions=functions, anonymous=anonymous, optional_parens=optional_parens + ) + + if fn_syntax: + self._match(TokenType.R_BRACE) + + return func + + def _parse_function_call( + self, + functions: t.Optional[t.Dict[str, t.Callable]] = None, + anonymous: bool = False, + optional_parens: bool = True, + ) -> t.Optional[exp.Expression]: if not self._curr: return None @@ -3856,6 +3890,10 @@ class Parser(metaclass=_Parser): if not identity: this.set("expression", self._parse_bitwise()) + elif not this.args.get("start") and self._match(TokenType.NUMBER, advance=False): + args = self._parse_csv(self._parse_bitwise) + this.set("start", seq_get(args, 0)) + this.set("increment", seq_get(args, 1)) self._match_r_paren() @@ -4039,6 +4077,11 @@ class Parser(metaclass=_Parser): ) ) + if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET: + self.raise_error("Expected ]") + elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE: + self.raise_error("Expected }") + # https://duckdb.org/docs/sql/data_types/struct.html#creating-structs if bracket_kind == TokenType.L_BRACE: this = self.expression(exp.Struct, expressions=expressions) @@ -4048,11 +4091,6 @@ class Parser(metaclass=_Parser): expressions = apply_index_offset(this, expressions, -self.INDEX_OFFSET) this = self.expression(exp.Bracket, this=this, expressions=expressions) - if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET: - self.raise_error("Expected ]") - elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE: - self.raise_error("Expected }") - self._add_comments(this) return self._parse_bracket(this) diff --git a/sqlglot/time.py b/sqlglot/time.py index 5f0848e..c286ec1 100644 --- a/sqlglot/time.py +++ b/sqlglot/time.py @@ -54,3 +54,606 @@ def format_time( chunks.append(chars) return "".join(mapping.get(chars, chars) for chars in chunks) + + +TIMEZONES = { + tz.lower() + for tz in ( + "Africa/Abidjan", + "Africa/Accra", + "Africa/Addis_Ababa", + "Africa/Algiers", + "Africa/Asmara", + "Africa/Asmera", + "Africa/Bamako", + "Africa/Bangui", + "Africa/Banjul", + "Africa/Bissau", + "Africa/Blantyre", + "Africa/Brazzaville", + "Africa/Bujumbura", + "Africa/Cairo", + "Africa/Casablanca", + "Africa/Ceuta", + "Africa/Conakry", + "Africa/Dakar", + "Africa/Dar_es_Salaam", + "Africa/Djibouti", + "Africa/Douala", + "Africa/El_Aaiun", + "Africa/Freetown", + "Africa/Gaborone", + "Africa/Harare", + "Africa/Johannesburg", + "Africa/Juba", + "Africa/Kampala", + "Africa/Khartoum", + "Africa/Kigali", + "Africa/Kinshasa", + "Africa/Lagos", + "Africa/Libreville", + "Africa/Lome", + "Africa/Luanda", + "Africa/Lubumbashi", + "Africa/Lusaka", + "Africa/Malabo", + "Africa/Maputo", + "Africa/Maseru", + "Africa/Mbabane", + "Africa/Mogadishu", + "Africa/Monrovia", + "Africa/Nairobi", + "Africa/Ndjamena", + "Africa/Niamey", + "Africa/Nouakchott", + "Africa/Ouagadougou", + "Africa/Porto-Novo", + "Africa/Sao_Tome", + "Africa/Timbuktu", + "Africa/Tripoli", + "Africa/Tunis", + "Africa/Windhoek", + "America/Adak", + "America/Anchorage", + "America/Anguilla", + "America/Antigua", + "America/Araguaina", + "America/Argentina/Buenos_Aires", + "America/Argentina/Catamarca", + "America/Argentina/ComodRivadavia", + "America/Argentina/Cordoba", + "America/Argentina/Jujuy", + "America/Argentina/La_Rioja", + "America/Argentina/Mendoza", + "America/Argentina/Rio_Gallegos", + "America/Argentina/Salta", + "America/Argentina/San_Juan", + "America/Argentina/San_Luis", + "America/Argentina/Tucuman", + "America/Argentina/Ushuaia", + "America/Aruba", + "America/Asuncion", + "America/Atikokan", + "America/Atka", + "America/Bahia", + "America/Bahia_Banderas", + "America/Barbados", + "America/Belem", + "America/Belize", + "America/Blanc-Sablon", + "America/Boa_Vista", + "America/Bogota", + "America/Boise", + "America/Buenos_Aires", + "America/Cambridge_Bay", + "America/Campo_Grande", + "America/Cancun", + "America/Caracas", + "America/Catamarca", + "America/Cayenne", + "America/Cayman", + "America/Chicago", + "America/Chihuahua", + "America/Ciudad_Juarez", + "America/Coral_Harbour", + "America/Cordoba", + "America/Costa_Rica", + "America/Creston", + "America/Cuiaba", + "America/Curacao", + "America/Danmarkshavn", + "America/Dawson", + "America/Dawson_Creek", + "America/Denver", + "America/Detroit", + "America/Dominica", + "America/Edmonton", + "America/Eirunepe", + "America/El_Salvador", + "America/Ensenada", + "America/Fort_Nelson", + "America/Fort_Wayne", + "America/Fortaleza", + "America/Glace_Bay", + "America/Godthab", + "America/Goose_Bay", + "America/Grand_Turk", + "America/Grenada", + "America/Guadeloupe", + "America/Guatemala", + "America/Guayaquil", + "America/Guyana", + "America/Halifax", + "America/Havana", + "America/Hermosillo", + "America/Indiana/Indianapolis", + "America/Indiana/Knox", + "America/Indiana/Marengo", + "America/Indiana/Petersburg", + "America/Indiana/Tell_City", + "America/Indiana/Vevay", + "America/Indiana/Vincennes", + "America/Indiana/Winamac", + "America/Indianapolis", + "America/Inuvik", + "America/Iqaluit", + "America/Jamaica", + "America/Jujuy", + "America/Juneau", + "America/Kentucky/Louisville", + "America/Kentucky/Monticello", + "America/Knox_IN", + "America/Kralendijk", + "America/La_Paz", + "America/Lima", + "America/Los_Angeles", + "America/Louisville", + "America/Lower_Princes", + "America/Maceio", + "America/Managua", + "America/Manaus", + "America/Marigot", + "America/Martinique", + "America/Matamoros", + "America/Mazatlan", + "America/Mendoza", + "America/Menominee", + "America/Merida", + "America/Metlakatla", + "America/Mexico_City", + "America/Miquelon", + "America/Moncton", + "America/Monterrey", + "America/Montevideo", + "America/Montreal", + "America/Montserrat", + "America/Nassau", + "America/New_York", + "America/Nipigon", + "America/Nome", + "America/Noronha", + "America/North_Dakota/Beulah", + "America/North_Dakota/Center", + "America/North_Dakota/New_Salem", + "America/Nuuk", + "America/Ojinaga", + "America/Panama", + "America/Pangnirtung", + "America/Paramaribo", + "America/Phoenix", + "America/Port-au-Prince", + "America/Port_of_Spain", + "America/Porto_Acre", + "America/Porto_Velho", + "America/Puerto_Rico", + "America/Punta_Arenas", + "America/Rainy_River", + "America/Rankin_Inlet", + "America/Recife", + "America/Regina", + "America/Resolute", + "America/Rio_Branco", + "America/Rosario", + "America/Santa_Isabel", + "America/Santarem", + "America/Santiago", + "America/Santo_Domingo", + "America/Sao_Paulo", + "America/Scoresbysund", + "America/Shiprock", + "America/Sitka", + "America/St_Barthelemy", + "America/St_Johns", + "America/St_Kitts", + "America/St_Lucia", + "America/St_Thomas", + "America/St_Vincent", + "America/Swift_Current", + "America/Tegucigalpa", + "America/Thule", + "America/Thunder_Bay", + "America/Tijuana", + "America/Toronto", + "America/Tortola", + "America/Vancouver", + "America/Virgin", + "America/Whitehorse", + "America/Winnipeg", + "America/Yakutat", + "America/Yellowknife", + "Antarctica/Casey", + "Antarctica/Davis", + "Antarctica/DumontDUrville", + "Antarctica/Macquarie", + "Antarctica/Mawson", + "Antarctica/McMurdo", + "Antarctica/Palmer", + "Antarctica/Rothera", + "Antarctica/South_Pole", + "Antarctica/Syowa", + "Antarctica/Troll", + "Antarctica/Vostok", + "Arctic/Longyearbyen", + "Asia/Aden", + "Asia/Almaty", + "Asia/Amman", + "Asia/Anadyr", + "Asia/Aqtau", + "Asia/Aqtobe", + "Asia/Ashgabat", + "Asia/Ashkhabad", + "Asia/Atyrau", + "Asia/Baghdad", + "Asia/Bahrain", + "Asia/Baku", + "Asia/Bangkok", + "Asia/Barnaul", + "Asia/Beirut", + "Asia/Bishkek", + "Asia/Brunei", + "Asia/Calcutta", + "Asia/Chita", + "Asia/Choibalsan", + "Asia/Chongqing", + "Asia/Chungking", + "Asia/Colombo", + "Asia/Dacca", + "Asia/Damascus", + "Asia/Dhaka", + "Asia/Dili", + "Asia/Dubai", + "Asia/Dushanbe", + "Asia/Famagusta", + "Asia/Gaza", + "Asia/Harbin", + "Asia/Hebron", + "Asia/Ho_Chi_Minh", + "Asia/Hong_Kong", + "Asia/Hovd", + "Asia/Irkutsk", + "Asia/Istanbul", + "Asia/Jakarta", + "Asia/Jayapura", + "Asia/Jerusalem", + "Asia/Kabul", + "Asia/Kamchatka", + "Asia/Karachi", + "Asia/Kashgar", + "Asia/Kathmandu", + "Asia/Katmandu", + "Asia/Khandyga", + "Asia/Kolkata", + "Asia/Krasnoyarsk", + "Asia/Kuala_Lumpur", + "Asia/Kuching", + "Asia/Kuwait", + "Asia/Macao", + "Asia/Macau", + "Asia/Magadan", + "Asia/Makassar", + "Asia/Manila", + "Asia/Muscat", + "Asia/Nicosia", + "Asia/Novokuznetsk", + "Asia/Novosibirsk", + "Asia/Omsk", + "Asia/Oral", + "Asia/Phnom_Penh", + "Asia/Pontianak", + "Asia/Pyongyang", + "Asia/Qatar", + "Asia/Qostanay", + "Asia/Qyzylorda", + "Asia/Rangoon", + "Asia/Riyadh", + "Asia/Saigon", + "Asia/Sakhalin", + "Asia/Samarkand", + "Asia/Seoul", + "Asia/Shanghai", + "Asia/Singapore", + "Asia/Srednekolymsk", + "Asia/Taipei", + "Asia/Tashkent", + "Asia/Tbilisi", + "Asia/Tehran", + "Asia/Tel_Aviv", + "Asia/Thimbu", + "Asia/Thimphu", + "Asia/Tokyo", + "Asia/Tomsk", + "Asia/Ujung_Pandang", + "Asia/Ulaanbaatar", + "Asia/Ulan_Bator", + "Asia/Urumqi", + "Asia/Ust-Nera", + "Asia/Vientiane", + "Asia/Vladivostok", + "Asia/Yakutsk", + "Asia/Yangon", + "Asia/Yekaterinburg", + "Asia/Yerevan", + "Atlantic/Azores", + "Atlantic/Bermuda", + "Atlantic/Canary", + "Atlantic/Cape_Verde", + "Atlantic/Faeroe", + "Atlantic/Faroe", + "Atlantic/Jan_Mayen", + "Atlantic/Madeira", + "Atlantic/Reykjavik", + "Atlantic/South_Georgia", + "Atlantic/St_Helena", + "Atlantic/Stanley", + "Australia/ACT", + "Australia/Adelaide", + "Australia/Brisbane", + "Australia/Broken_Hill", + "Australia/Canberra", + "Australia/Currie", + "Australia/Darwin", + "Australia/Eucla", + "Australia/Hobart", + "Australia/LHI", + "Australia/Lindeman", + "Australia/Lord_Howe", + "Australia/Melbourne", + "Australia/NSW", + "Australia/North", + "Australia/Perth", + "Australia/Queensland", + "Australia/South", + "Australia/Sydney", + "Australia/Tasmania", + "Australia/Victoria", + "Australia/West", + "Australia/Yancowinna", + "Brazil/Acre", + "Brazil/DeNoronha", + "Brazil/East", + "Brazil/West", + "CET", + "CST6CDT", + "Canada/Atlantic", + "Canada/Central", + "Canada/Eastern", + "Canada/Mountain", + "Canada/Newfoundland", + "Canada/Pacific", + "Canada/Saskatchewan", + "Canada/Yukon", + "Chile/Continental", + "Chile/EasterIsland", + "Cuba", + "EET", + "EST", + "EST5EDT", + "Egypt", + "Eire", + "Etc/GMT", + "Etc/GMT+0", + "Etc/GMT+1", + "Etc/GMT+10", + "Etc/GMT+11", + "Etc/GMT+12", + "Etc/GMT+2", + "Etc/GMT+3", + "Etc/GMT+4", + "Etc/GMT+5", + "Etc/GMT+6", + "Etc/GMT+7", + "Etc/GMT+8", + "Etc/GMT+9", + "Etc/GMT-0", + "Etc/GMT-1", + "Etc/GMT-10", + "Etc/GMT-11", + "Etc/GMT-12", + "Etc/GMT-13", + "Etc/GMT-14", + "Etc/GMT-2", + "Etc/GMT-3", + "Etc/GMT-4", + "Etc/GMT-5", + "Etc/GMT-6", + "Etc/GMT-7", + "Etc/GMT-8", + "Etc/GMT-9", + "Etc/GMT0", + "Etc/Greenwich", + "Etc/UCT", + "Etc/UTC", + "Etc/Universal", + "Etc/Zulu", + "Europe/Amsterdam", + "Europe/Andorra", + "Europe/Astrakhan", + "Europe/Athens", + "Europe/Belfast", + "Europe/Belgrade", + "Europe/Berlin", + "Europe/Bratislava", + "Europe/Brussels", + "Europe/Bucharest", + "Europe/Budapest", + "Europe/Busingen", + "Europe/Chisinau", + "Europe/Copenhagen", + "Europe/Dublin", + "Europe/Gibraltar", + "Europe/Guernsey", + "Europe/Helsinki", + "Europe/Isle_of_Man", + "Europe/Istanbul", + "Europe/Jersey", + "Europe/Kaliningrad", + "Europe/Kiev", + "Europe/Kirov", + "Europe/Kyiv", + "Europe/Lisbon", + "Europe/Ljubljana", + "Europe/London", + "Europe/Luxembourg", + "Europe/Madrid", + "Europe/Malta", + "Europe/Mariehamn", + "Europe/Minsk", + "Europe/Monaco", + "Europe/Moscow", + "Europe/Nicosia", + "Europe/Oslo", + "Europe/Paris", + "Europe/Podgorica", + "Europe/Prague", + "Europe/Riga", + "Europe/Rome", + "Europe/Samara", + "Europe/San_Marino", + "Europe/Sarajevo", + "Europe/Saratov", + "Europe/Simferopol", + "Europe/Skopje", + "Europe/Sofia", + "Europe/Stockholm", + "Europe/Tallinn", + "Europe/Tirane", + "Europe/Tiraspol", + "Europe/Ulyanovsk", + "Europe/Uzhgorod", + "Europe/Vaduz", + "Europe/Vatican", + "Europe/Vienna", + "Europe/Vilnius", + "Europe/Volgograd", + "Europe/Warsaw", + "Europe/Zagreb", + "Europe/Zaporozhye", + "Europe/Zurich", + "GB", + "GB-Eire", + "GMT", + "GMT+0", + "GMT-0", + "GMT0", + "Greenwich", + "HST", + "Hongkong", + "Iceland", + "Indian/Antananarivo", + "Indian/Chagos", + "Indian/Christmas", + "Indian/Cocos", + "Indian/Comoro", + "Indian/Kerguelen", + "Indian/Mahe", + "Indian/Maldives", + "Indian/Mauritius", + "Indian/Mayotte", + "Indian/Reunion", + "Iran", + "Israel", + "Jamaica", + "Japan", + "Kwajalein", + "Libya", + "MET", + "MST", + "MST7MDT", + "Mexico/BajaNorte", + "Mexico/BajaSur", + "Mexico/General", + "NZ", + "NZ-CHAT", + "Navajo", + "PRC", + "PST8PDT", + "Pacific/Apia", + "Pacific/Auckland", + "Pacific/Bougainville", + "Pacific/Chatham", + "Pacific/Chuuk", + "Pacific/Easter", + "Pacific/Efate", + "Pacific/Enderbury", + "Pacific/Fakaofo", + "Pacific/Fiji", + "Pacific/Funafuti", + "Pacific/Galapagos", + "Pacific/Gambier", + "Pacific/Guadalcanal", + "Pacific/Guam", + "Pacific/Honolulu", + "Pacific/Johnston", + "Pacific/Kanton", + "Pacific/Kiritimati", + "Pacific/Kosrae", + "Pacific/Kwajalein", + "Pacific/Majuro", + "Pacific/Marquesas", + "Pacific/Midway", + "Pacific/Nauru", + "Pacific/Niue", + "Pacific/Norfolk", + "Pacific/Noumea", + "Pacific/Pago_Pago", + "Pacific/Palau", + "Pacific/Pitcairn", + "Pacific/Pohnpei", + "Pacific/Ponape", + "Pacific/Port_Moresby", + "Pacific/Rarotonga", + "Pacific/Saipan", + "Pacific/Samoa", + "Pacific/Tahiti", + "Pacific/Tarawa", + "Pacific/Tongatapu", + "Pacific/Truk", + "Pacific/Wake", + "Pacific/Wallis", + "Pacific/Yap", + "Poland", + "Portugal", + "ROC", + "ROK", + "Singapore", + "Turkey", + "UCT", + "US/Alaska", + "US/Aleutian", + "US/Arizona", + "US/Central", + "US/East-Indiana", + "US/Eastern", + "US/Hawaii", + "US/Indiana-Starke", + "US/Michigan", + "US/Mountain", + "US/Pacific", + "US/Samoa", + "UTC", + "Universal", + "W-SU", + "WET", + "Zulu", + ) +} diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index c883858..9784c63 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -1077,10 +1077,10 @@ class Tokenizer(metaclass=_Tokenizer): literal = "" while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: - literal += self._peek.upper() + literal += self._peek self._advance() - token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal, "")) + token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal.upper(), "")) if token_type: self._add(TokenType.NUMBER, number_text) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 8feee52..e0fd68f 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -164,8 +164,9 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression: def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: + """Convert explode/posexplode into unnest (used in hive -> presto).""" + def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: - """Convert explode/posexplode into unnest (used in hive -> presto).""" if isinstance(expression, exp.Select): from sqlglot.optimizer.scope import Scope @@ -297,6 +298,7 @@ PERCENTILES = (exp.PercentileCont, exp.PercentileDisc) def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: + """Transforms percentiles by adding a WITHIN GROUP clause to them.""" if ( isinstance(expression, PERCENTILES) and not isinstance(expression.parent, exp.WithinGroup) @@ -311,6 +313,7 @@ def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expressi def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: + """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.""" if ( isinstance(expression, exp.WithinGroup) and isinstance(expression.this, PERCENTILES) @@ -324,6 +327,7 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: + """Uses projection output names in recursive CTE definitions to define the CTEs' columns.""" if isinstance(expression, exp.With) and expression.recursive: next_name = name_sequence("_c_") @@ -342,6 +346,7 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: + """Replace 'epoch' in casts by the equivalent date literal.""" if ( isinstance(expression, (exp.Cast, exp.TryCast)) and expression.name.lower() == "epoch" @@ -352,16 +357,8 @@ def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: return expression -def timestamp_to_cast(expression: exp.Expression) -> exp.Expression: - if isinstance(expression, exp.Timestamp) and not expression.expression: - return exp.cast( - expression.this, - to=exp.DataType.Type.TIMESTAMP, - ) - return expression - - def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: + """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.""" if isinstance(expression, exp.Select): for join in expression.args.get("joins") or []: on = join.args.get("on") |