author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-20 09:22:25 +0000
---|---|---
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-20 09:22:25 +0000
commit | a45bbbb6f2fbd117d5d314e34e85afc2b48ad677
tree | 35b576637338ae7cef217ddab721ad81aeb3f78c /sqlglot
parent | Releasing debian version 18.4.1-1.
Merging upstream version 18.5.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 42
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 3
-rw-r--r-- | sqlglot/dialects/dialect.py | 28
-rw-r--r-- | sqlglot/dialects/drill.py | 2
-rw-r--r-- | sqlglot/dialects/duckdb.py | 10
-rw-r--r-- | sqlglot/dialects/hive.py | 5
-rw-r--r-- | sqlglot/dialects/mysql.py | 18
-rw-r--r-- | sqlglot/dialects/oracle.py | 1
-rw-r--r-- | sqlglot/dialects/postgres.py | 12
-rw-r--r-- | sqlglot/dialects/presto.py | 3
-rw-r--r-- | sqlglot/dialects/redshift.py | 4
-rw-r--r-- | sqlglot/dialects/snowflake.py | 2
-rw-r--r-- | sqlglot/dialects/spark2.py | 6
-rw-r--r-- | sqlglot/dialects/trino.py | 5
-rw-r--r-- | sqlglot/dialects/tsql.py | 1
-rw-r--r-- | sqlglot/expressions.py | 19
-rw-r--r-- | sqlglot/generator.py | 10
-rw-r--r-- | sqlglot/optimizer/qualify.py | 2
-rw-r--r-- | sqlglot/parser.py | 28
-rw-r--r-- | sqlglot/schema.py | 3
-rw-r--r-- | sqlglot/transforms.py | 37
21 files changed, 168 insertions, 73 deletions
```diff
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index d763ed0..6c71624 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -9,6 +9,7 @@ from sqlglot._typing import E
 from sqlglot.dialects.dialect import (
     Dialect,
     binary_from_function,
+    date_add_interval_sql,
     datestrtodate_sql,
     format_time_lambda,
     inline_array_sql,
@@ -28,19 +29,6 @@ from sqlglot.tokens import TokenType
 logger = logging.getLogger("sqlglot")
 
 
-def _date_add_sql(
-    data_type: str, kind: str
-) -> t.Callable[[BigQuery.Generator, exp.Expression], str]:
-    def func(self: BigQuery.Generator, expression: exp.Expression) -> str:
-        this = self.sql(expression, "this")
-        unit = expression.args.get("unit")
-        unit = exp.var(unit.name.upper() if unit else "DAY")
-        interval = exp.Interval(this=expression.expression.copy(), unit=unit)
-        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
-
-    return func
-
-
 def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str:
     if not expression.find_ancestor(exp.From, exp.Join):
         return self.values_sql(expression)
@@ -187,6 +175,7 @@ def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
 
 class BigQuery(Dialect):
     UNNEST_COLUMN_ONLY = True
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -278,8 +267,6 @@ class BigQuery(Dialect):
         LOG_BASE_FIRST = False
         LOG_DEFAULTS_TO_LN = True
 
-        SUPPORTS_USER_DEFINED_TYPES = False
-
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "DATE": _parse_date,
@@ -436,13 +423,13 @@ class BigQuery(Dialect):
             exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
             exp.Create: _create_sql,
             exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
-            exp.DateAdd: _date_add_sql("DATE", "ADD"),
+            exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
             exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
             exp.DateFromParts: rename_func("DATE"),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DateSub: _date_add_sql("DATE", "SUB"),
-            exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
-            exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
+            exp.DateSub: date_add_interval_sql("DATE", "SUB"),
+            exp.DatetimeAdd: date_add_interval_sql("DATETIME", "ADD"),
+            exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"),
             exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
             exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
             exp.GroupConcat: rename_func("STRING_AGG"),
@@ -484,13 +471,13 @@ class BigQuery(Dialect):
             exp.StrToTime: lambda self, e: self.func(
                 "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
             ),
-            exp.TimeAdd: _date_add_sql("TIME", "ADD"),
-            exp.TimeSub: _date_add_sql("TIME", "SUB"),
-            exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
-            exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
+            exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
+            exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
+            exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
+            exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
             exp.TimeStrToTime: timestrtotime_sql,
             exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
-            exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
+            exp.TsOrDsAdd: date_add_interval_sql("DATE", "ADD"),
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
             exp.Unhex: rename_func("FROM_HEX"),
             exp.Values: _derived_table_values_to_unnest,
@@ -640,13 +627,6 @@ class BigQuery(Dialect):
 
             return super().attimezone_sql(expression)
 
-        def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
-            # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#json_literals
-            if expression.is_type("json"):
-                return f"JSON {self.sql(expression, 'this')}"
-
-            return super().cast_sql(expression, safe_prefix=safe_prefix)
-
         def trycast_sql(self, expression: exp.TryCast) -> str:
             return self.cast_sql(expression, safe_prefix="SAFE_")
```
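The BigQuery-private `_date_add_sql` removed above reappears as the shared `date_add_interval_sql` helper in `dialect.py` (next file). A minimal sketch, not part of the patch, of what the relocated helper emits; the column name `d` is illustrative only:

```python
from sqlglot import exp
from sqlglot.dialects.bigquery import BigQuery

# DateAdd carries the delta in `expression` and the unit in `unit`; the helper
# upper-cases the unit (defaulting to DAY) and wraps the delta in an INTERVAL.
node = exp.DateAdd(
    this=exp.column("d"),
    expression=exp.Literal.number(5),
    unit=exp.var("day"),
)
print(BigQuery().generate(node))  # expected: DATE_ADD(d, INTERVAL 5 DAY)
```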
```diff
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 4b36663..d552f4c 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -25,6 +25,7 @@ class ClickHouse(Dialect):
     NORMALIZE_FUNCTIONS: bool | str = False
     NULL_ORDERING = "nulls_are_last"
     STRICT_STRING_CONCAT = True
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     class Tokenizer(tokens.Tokenizer):
         COMMENTS = ["--", "#", "#!", ("/*", "*/")]
@@ -64,8 +65,6 @@ class ClickHouse(Dialect):
         }
 
     class Parser(parser.Parser):
-        SUPPORTS_USER_DEFINED_TYPES = False
-
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ANY": exp.AnyValue.from_arg_list,
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index ff22547..d4811c5 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -153,6 +153,9 @@ class Dialect(metaclass=_Dialect):
     # Determines whether or not CONCAT's arguments must be strings
     STRICT_STRING_CONCAT = False
 
+    # Determines whether or not user-defined data types are supported
+    SUPPORTS_USER_DEFINED_TYPES = True
+
     # Determines how function names are going to be normalized
     NORMALIZE_FUNCTIONS: bool | str = "upper"
 
@@ -546,6 +549,19 @@ def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
     return exp.TimestampTrunc(this=this, unit=unit)
 
 
+def date_add_interval_sql(
+    data_type: str, kind: str
+) -> t.Callable[[Generator, exp.Expression], str]:
+    def func(self: Generator, expression: exp.Expression) -> str:
+        this = self.sql(expression, "this")
+        unit = expression.args.get("unit")
+        unit = exp.var(unit.name.upper() if unit else "DAY")
+        interval = exp.Interval(this=expression.expression.copy(), unit=unit)
+        return f"{data_type}_{kind}({this}, {self.sql(interval)})"
+
+    return func
+
+
 def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
     return self.func(
         "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
@@ -736,5 +752,15 @@ def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
 
 # Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
-def json_keyvalue_comma_sql(self, expression: exp.JSONKeyValue) -> str:
+def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
     return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
+
+
+def is_parse_json(expression: exp.Expression) -> bool:
+    return isinstance(expression, exp.ParseJSON) or (
+        isinstance(expression, exp.Cast) and expression.is_type("json")
+    )
+
+
+def isnull_to_is_null(args: t.List) -> exp.Expression:
+    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index c811c86..87fb9b5 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -39,6 +39,7 @@ class Drill(Dialect):
     DATE_FORMAT = "'yyyy-MM-dd'"
     DATEINT_FORMAT = "'yyyyMMdd'"
     TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'"
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     TIME_MAPPING = {
         "y": "%Y",
@@ -80,7 +81,6 @@ class Drill(Dialect):
     class Parser(parser.Parser):
         STRICT_CAST = False
         CONCAT_NULL_OUTPUTS_STRING = True
-        SUPPORTS_USER_DEFINED_TYPES = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
```
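A quick sketch (not part of the patch) exercising the new `is_parse_json` predicate through the default parser; the printed values are expectations, not captured output:

```python
from sqlglot import parse_one
from sqlglot.dialects.dialect import is_parse_json

# Both an explicit PARSE_JSON(...) call and a CAST to JSON now count as
# "already-parsed JSON", which Hive and Spark2 key their cast logic on.
print(is_parse_json(parse_one("PARSE_JSON('{}')")))  # expected: True
print(is_parse_json(parse_one("CAST(x AS JSON)")))   # expected: True
print(is_parse_json(parse_one("x")))                 # expected: False
```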
```diff
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index bf657ed..ab7a26a 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -105,6 +105,7 @@ def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str:
 
 class DuckDB(Dialect):
     NULL_ORDERING = "nulls_are_last"
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     # https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -135,7 +136,6 @@ class DuckDB(Dialect):
 
     class Parser(parser.Parser):
         CONCAT_NULL_OUTPUTS_STRING = True
-        SUPPORTS_USER_DEFINED_TYPES = False
 
         BITWISE = {
             **parser.Parser.BITWISE,
@@ -158,6 +158,11 @@ class DuckDB(Dialect):
             "LIST_REVERSE_SORT": _sort_array_reverse,
             "LIST_SORT": exp.SortArray.from_arg_list,
             "LIST_VALUE": exp.Array.from_arg_list,
+            "MEDIAN": lambda args: exp.PercentileCont(
+                this=seq_get(args, 0), expression=exp.Literal.number(0.5)
+            ),
+            "QUANTILE_CONT": exp.PercentileCont.from_arg_list,
+            "QUANTILE_DISC": exp.PercentileDisc.from_arg_list,
             "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
                 this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
             ),
@@ -266,6 +271,9 @@ class DuckDB(Dialect):
                 exp.cast(e.expression, "timestamp", copy=True),
                 exp.cast(e.this, "timestamp", copy=True),
             ),
+            exp.ParseJSON: rename_func("JSON"),
+            exp.PercentileCont: rename_func("QUANTILE_CONT"),
+            exp.PercentileDisc: rename_func("QUANTILE_DISC"),
             exp.Properties: no_properties_sql,
             exp.RegexpExtract: regexp_extract_sql,
             exp.RegexpReplace: regexp_replace_sql,
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 8b17c06..bec27d3 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
     create_with_partitions_sql,
     format_time_lambda,
     if_sql,
+    is_parse_json,
     left_to_substring_sql,
     locate_to_strposition,
     max_or_greatest,
@@ -89,7 +90,7 @@ def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:
 
 def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str:
     this = expression.this
-    if isinstance(this, exp.Cast) and this.is_type("json") and this.this.is_string:
+    if is_parse_json(this) and this.this.is_string:
         # Since FROM_JSON requires a nested type, we always wrap the json string with
         # an array to ensure that "naked" strings like "'a'" will be handled correctly
         wrapped_json = exp.Literal.string(f"[{this.this.name}]")
@@ -150,6 +151,7 @@ def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str:
 
 class Hive(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     IDENTIFIERS_CAN_START_WITH_DIGIT = True
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     # https://spark.apache.org/docs/latest/sql-ref-identifier.html#description
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
@@ -222,7 +224,6 @@ class Hive(Dialect):
     class Parser(parser.Parser):
         LOG_DEFAULTS_TO_LN = True
         STRICT_CAST = False
-        SUPPORTS_USER_DEFINED_TYPES = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
```
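With DuckDB's `MEDIAN`, `QUANTILE_CONT` and `QUANTILE_DISC` now mapped onto the generic percentile expressions, a round trip canonicalizes `MEDIAN` away. A sketch; the comment shows the expected rendering, not captured output:

```python
import sqlglot

# MEDIAN(x) parses to PercentileCont(x, 0.5) and is generated back through
# the QUANTILE_CONT rename, so even DuckDB -> DuckDB normalizes it.
print(sqlglot.transpile("SELECT MEDIAN(x) FROM t", read="duckdb", write="duckdb")[0])
# expected: SELECT QUANTILE_CONT(x, 0.5) FROM t
```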
```diff
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 46e3f19..75660f8 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -6,8 +6,10 @@ from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
+    date_add_interval_sql,
     datestrtodate_sql,
     format_time_lambda,
+    isnull_to_is_null,
     json_keyvalue_comma_sql,
     locate_to_strposition,
     max_or_greatest,
@@ -99,6 +101,7 @@ class MySQL(Dialect):
     TIME_FORMAT = "'%Y-%m-%d %T'"
 
     DPIPE_IS_STRING_CONCAT = False
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
     TIME_MAPPING = {
@@ -129,6 +132,7 @@ class MySQL(Dialect):
         "ENUM": TokenType.ENUM,
         "FORCE": TokenType.FORCE,
         "IGNORE": TokenType.IGNORE,
+        "LOCK TABLES": TokenType.COMMAND,
         "LONGBLOB": TokenType.LONGBLOB,
         "LONGTEXT": TokenType.LONGTEXT,
         "MEDIUMBLOB": TokenType.MEDIUMBLOB,
@@ -141,6 +145,7 @@ class MySQL(Dialect):
         "START": TokenType.BEGIN,
         "SIGNED": TokenType.BIGINT,
         "SIGNED INTEGER": TokenType.BIGINT,
+        "UNLOCK TABLES": TokenType.COMMAND,
         "UNSIGNED": TokenType.UBIGINT,
         "UNSIGNED INTEGER": TokenType.UBIGINT,
         "YEAR": TokenType.YEAR,
@@ -193,8 +198,6 @@ class MySQL(Dialect):
         COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
 
     class Parser(parser.Parser):
-        SUPPORTS_USER_DEFINED_TYPES = False
-
         FUNC_TOKENS = {
             *parser.Parser.FUNC_TOKENS,
             TokenType.DATABASE,
@@ -233,7 +236,12 @@ class MySQL(Dialect):
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
             "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
             "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
+            "ISNULL": isnull_to_is_null,
             "LOCATE": locate_to_strposition,
+            "MONTHNAME": lambda args: exp.TimeToStr(
+                this=seq_get(args, 0),
+                format=exp.Literal.string("%B"),
+            ),
             "STR_TO_DATE": _str_to_date,
         }
 
@@ -374,7 +382,7 @@ class MySQL(Dialect):
             self._match_texts({"INDEX", "KEY"})
 
             this = self._parse_id_var(any_token=False)
-            type_ = self._match(TokenType.USING) and self._advance_any() and self._prev.text
+            index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text
             schema = self._parse_schema()
 
             options = []
@@ -414,7 +422,7 @@ class MySQL(Dialect):
                 this=this,
                 schema=schema,
                 kind=kind,
-                type=type_,
+                index_type=index_type,
                 options=options,
             )
@@ -558,6 +566,8 @@ class MySQL(Dialect):
             exp.StrToTime: _str_to_date_sql,
             exp.Stuff: rename_func("INSERT"),
             exp.TableSample: no_tablesample_sql,
+            exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
+            exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
             exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 378df49..0a4926d 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -32,6 +32,7 @@ def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable:
 
 class Oracle(Dialect):
     ALIAS_POST_TABLESAMPLE = True
+    LOCKING_READS_SUPPORTED = True
 
     # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 5027013..d049d8e 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -381,6 +381,9 @@ class Postgres(Dialect):
             **generator.Generator.TRANSFORMS,
             exp.AnyValue: any_value_to_max_sql,
             exp.ArrayConcat: rename_func("ARRAY_CAT"),
+            exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
+            exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
+            exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
             exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
             exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
             exp.Explode: rename_func("UNNEST"),
@@ -401,10 +404,13 @@ class Postgres(Dialect):
             exp.Max: max_or_greatest,
             exp.MapFromEntries: no_map_from_entries_sql,
             exp.Min: min_or_least,
-            exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
-            exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
-            exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
             exp.Merge: transforms.preprocess([_remove_target_from_merge]),
+            exp.PercentileCont: transforms.preprocess(
+                [transforms.add_within_group_for_percentiles]
+            ),
+            exp.PercentileDisc: transforms.preprocess(
+                [transforms.add_within_group_for_percentiles]
+            ),
             exp.Pivot: no_pivot_sql,
             exp.RegexpLike: lambda self, e: self.binary(e, "~"),
             exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
```
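The Postgres side of the percentile work, plus the new MySQL `ISNULL` parsing, sketched below; outputs are expectations, not captured:

```python
import sqlglot

# add_within_group_for_percentiles moves the column into an ORDER BY and the
# quantile into the call, producing Postgres' ordered-set aggregate syntax.
print(sqlglot.transpile("SELECT MEDIAN(x) FROM t", read="duckdb", write="postgres")[0])
# expected: SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t

# isnull_to_is_null turns MySQL's ISNULL(x) into a parenthesized predicate.
print(sqlglot.transpile("SELECT ISNULL(x)", read="mysql", write="mysql")[0])
# expected: SELECT (x IS NULL)
```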
```diff
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 4b54e95..9ae4c32 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -237,6 +237,7 @@ class Presto(Dialect):
                 this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
             ),
         }
+
         FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
         FUNCTION_PARSERS.pop("TRIM")
 
@@ -310,6 +311,7 @@ class Presto(Dialect):
             exp.If: if_sql,
             exp.ILike: no_ilike_sql,
             exp.Initcap: _initcap_sql,
+            exp.ParseJSON: rename_func("JSON_PARSE"),
             exp.Last: _first_last_sql,
             exp.Lateral: _explode_to_unnest_sql,
             exp.Left: left_to_substring_sql,
@@ -360,6 +362,7 @@ class Presto(Dialect):
             exp.WithinGroup: transforms.preprocess(
                 [transforms.remove_within_group_for_percentiles]
             ),
+            exp.Timestamp: transforms.preprocess([transforms.timestamp_to_cast]),
         }
 
     def interval_sql(self, expression: exp.Interval) -> str:
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 554cbd3..b4c7664 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -30,6 +30,8 @@ class Redshift(Postgres):
     # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
 
+    SUPPORTS_USER_DEFINED_TYPES = False
+
     TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
     TIME_MAPPING = {
         **Postgres.TIME_MAPPING,
@@ -38,8 +40,6 @@ class Redshift(Postgres):
     }
 
     class Parser(Postgres.Parser):
-        SUPPORTS_USER_DEFINED_TYPES = False
-
         FUNCTIONS = {
             **Postgres.Parser.FUNCTIONS,
             "ADD_MONTHS": lambda args: exp.DateAdd(
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 8d8183c..5aa946e 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -202,6 +202,7 @@ class Snowflake(Dialect):
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
     NULL_ORDERING = "nulls_are_large"
     TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
+    SUPPORTS_USER_DEFINED_TYPES = False
 
     TIME_MAPPING = {
         "YYYY": "%Y",
@@ -234,7 +235,6 @@ class Snowflake(Dialect):
 
     class Parser(parser.Parser):
         IDENTIFY_PIVOT_STRINGS = True
-        SUPPORTS_USER_DEFINED_TYPES = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 4489b6b..56d33ba 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
     binary_from_function,
     create_with_partitions_sql,
     format_time_lambda,
+    is_parse_json,
     pivot_column_names,
     rename_func,
     trim_sql,
@@ -242,10 +243,11 @@ class Spark2(Hive):
         CREATE_FUNCTION_RETURN_AS = False
 
         def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
-            if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"):
+            if is_parse_json(expression.this):
                 schema = f"'{self.sql(expression, 'to')}'"
                 return self.func("FROM_JSON", expression.this.this, schema)
-            if expression.is_type("json"):
+
+            if is_parse_json(expression):
                 return self.func("TO_JSON", expression.this)
 
             return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix)
```
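The new `exp.ParseJSON` node (added in `expressions.py` below) gives the PARSE_JSON/JSON_PARSE family a single home, so the per-dialect spellings reduce to renames. A sketch with expected outputs:

```python
import sqlglot

# Snowflake's PARSE_JSON maps to Presto's JSON_PARSE and DuckDB's JSON.
print(sqlglot.transpile("SELECT PARSE_JSON('[1]')", read="snowflake", write="presto")[0])
# expected: SELECT JSON_PARSE('[1]')
print(sqlglot.transpile("SELECT PARSE_JSON('[1]')", read="snowflake", write="duckdb")[0])
# expected: SELECT JSON('[1]')
```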
and expression.this.is_type("json"): + if is_parse_json(expression.this): schema = f"'{self.sql(expression, 'to')}'" return self.func("FROM_JSON", expression.this.this, schema) - if expression.is_type("json"): + + if is_parse_json(expression): return self.func("TO_JSON", expression.this) return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix) diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index 0c953a1..3682ac7 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -5,6 +5,8 @@ from sqlglot.dialects.presto import Presto class Trino(Presto): + SUPPORTS_USER_DEFINED_TYPES = False + class Generator(Presto.Generator): TRANSFORMS = { **Presto.Generator.TRANSFORMS, @@ -13,6 +15,3 @@ class Trino(Presto): class Tokenizer(Presto.Tokenizer): HEX_STRINGS = [("X'", "'")] - - class Parser(Presto.Parser): - SUPPORTS_USER_DEFINED_TYPES = False diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 19c586e..2299310 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -580,7 +580,6 @@ class TSQL(Dialect): ) class Generator(generator.Generator): - LOCKING_READS_SUPPORTED = True LIMIT_IS_TOP = True QUERY_HINTS = False RETURNING_END = False diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 98afddc..1c3d42a 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1321,7 +1321,13 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): # https://dev.mysql.com/doc/refman/8.0/en/create-table.html class IndexColumnConstraint(ColumnConstraintKind): - arg_types = {"this": False, "schema": True, "kind": False, "type": False, "options": False} + arg_types = { + "this": False, + "schema": True, + "kind": False, + "index_type": False, + "options": False, + } class InlineLengthColumnConstraint(ColumnConstraintKind): @@ -1354,7 +1360,7 @@ class TitleColumnConstraint(ColumnConstraintKind): class UniqueColumnConstraint(ColumnConstraintKind): - arg_types = {"this": False} + arg_types = {"this": False, "index_type": False} class UppercaseColumnConstraint(ColumnConstraintKind): @@ -4366,6 +4372,10 @@ class Extract(Func): arg_types = {"this": True, "expression": True} +class Timestamp(Func): + arg_types = {"this": False, "expression": False} + + class TimestampAdd(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} @@ -4579,6 +4589,11 @@ class JSONArrayContains(Binary, Predicate, Func): _sql_names = ["JSON_ARRAY_CONTAINS"] +class ParseJSON(Func): + # BigQuery, Snowflake have PARSE_JSON, Presto has JSON_PARSE + _sql_names = ["PARSE_JSON", "JSON_PARSE"] + + class Least(Func): arg_types = {"this": True, "expressions": False} is_var_len_args = True diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 399b48b..d086e8a 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -705,7 +705,9 @@ class Generator: def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str: this = self.sql(expression, "this") this = f" {this}" if this else "" - return f"UNIQUE{this}" + index_type = expression.args.get("index_type") + index_type = f" USING {index_type}" if index_type else "" + return f"UNIQUE{this}{index_type}" def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: return self.sql(expression, "this") @@ -2740,13 +2742,13 @@ class Generator: kind = f"{kind} INDEX" if kind else "INDEX" this = self.sql(expression, "this") this = f" {this}" if this else "" - type_ = self.sql(expression, "type") - type_ = f" USING 
{type_}" if type_ else "" + index_type = self.sql(expression, "index_type") + index_type = f" USING {index_type}" if index_type else "" schema = self.sql(expression, "schema") schema = f" {schema}" if schema else "" options = self.expressions(expression, key="options", sep=" ") options = f" {options}" if options else "" - return f"{kind}{this}{type_}{schema}{options}" + return f"{kind}{this}{index_type}{schema}{options}" def nvl2_sql(self, expression: exp.Nvl2) -> str: if self.NVL2_SUPPORTED: diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py index 6e15c6a..5fdbde8 100644 --- a/sqlglot/optimizer/qualify.py +++ b/sqlglot/optimizer/qualify.py @@ -60,8 +60,8 @@ def qualify( The qualified expression. """ schema = ensure_schema(schema, dialect=dialect) - expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema) expression = normalize_identifiers(expression, dialect=dialect) + expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema) if isolate_tables: expression = isolate_table_selects(expression, schema=schema) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index f721582..06bc1eb 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -820,7 +820,9 @@ class Parser(metaclass=_Parser): SHOW_PARSERS: t.Dict[str, t.Callable] = {} - TYPE_LITERAL_PARSERS: t.Dict[exp.DataType.Type, t.Callable] = {} + TYPE_LITERAL_PARSERS = { + exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this), + } MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) @@ -848,6 +850,8 @@ class Parser(metaclass=_Parser): WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER} WINDOW_SIDES = {"FOLLOWING", "PRECEDING"} + FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT} + ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY} DISTINCT_TOKENS = {TokenType.DISTINCT} @@ -863,8 +867,6 @@ class Parser(metaclass=_Parser): LOG_BASE_FIRST = True LOG_DEFAULTS_TO_LN = False - SUPPORTS_USER_DEFINED_TYPES = True - # Whether or not ADD is present for each column added by ALTER TABLE ALTER_TABLE_ADD_COLUMN_KEYWORD = True @@ -892,6 +894,7 @@ class Parser(metaclass=_Parser): UNNEST_COLUMN_ONLY: bool = False ALIAS_POST_TABLESAMPLE: bool = False STRICT_STRING_CONCAT = False + SUPPORTS_USER_DEFINED_TYPES = True NORMALIZE_FUNCTIONS = "upper" NULL_ORDERING: str = "nulls_are_small" SHOW_TRIE: t.Dict = {} @@ -2692,7 +2695,7 @@ class Parser(metaclass=_Parser): expressions = self._parse_csv(self._parse_primary) else: expressions = None - num = self._parse_number() + num = self._parse_primary() if self._match_text_seq("BUCKET"): bucket_numerator = self._parse_number() @@ -2914,6 +2917,10 @@ class Parser(metaclass=_Parser): ) connect = self._parse_conjunction() self.NO_PAREN_FUNCTION_PARSERS.pop("PRIOR") + + if not start and self._match(TokenType.START_WITH): + start = self._parse_conjunction() + return self.expression(exp.Connect, start=start, connect=connect) def _parse_order( @@ -2985,7 +2992,7 @@ class Parser(metaclass=_Parser): direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) direction = self._prev.text if direction else "FIRST" - count = self._parse_number() + count = self._parse_field(tokens=self.FETCH_TOKENS) percent = self._match(TokenType.PERCENT) self._match_set((TokenType.ROW, TokenType.ROWS)) @@ -3272,7 +3279,12 @@ class Parser(metaclass=_Parser): if tokens[0].token_type in self.TYPE_TOKENS: self._prev = tokens[0] elif self.SUPPORTS_USER_DEFINED_TYPES: - return 
```diff
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index f028f5a..f0b279b 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -398,9 +398,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
         """
         if schema_type not in self._type_mapping_cache:
             dialect = dialect or self.dialect
+            udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES
 
             try:
-                expression = exp.DataType.build(schema_type, dialect=dialect)
+                expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt)
                 self._type_mapping_cache[schema_type] = expression
             except AttributeError:
                 in_dialect = f" in dialect {dialect}" if dialect else ""
```
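`MappingSchema` now forwards the dialect's UDT capability when building column types, so a schema declared with a custom type name can resolve for dialects that allow user-defined types. A sketch; `mytype` is a hypothetical type name:

```python
from sqlglot.schema import MappingSchema

# With dialect="postgres" (SUPPORTS_USER_DEFINED_TYPES = True), "mytype" is
# built via exp.DataType.build(..., udt=True) as a user-defined type node.
schema = MappingSchema({"t": {"x": "mytype"}}, dialect="postgres")
print(schema.get_column_type("t", "x"))
```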
```diff
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 66ab884..70b9a31 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -224,10 +224,27 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
+PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
+
+
+def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
+    if (
+        isinstance(expression, PERCENTILES)
+        and not isinstance(expression.parent, exp.WithinGroup)
+        and expression.expression
+    ):
+        column = expression.this.pop()
+        expression.set("this", expression.expression.pop())
+        order = exp.Order(expressions=[exp.Ordered(this=column)])
+        expression = exp.WithinGroup(this=expression, expression=order)
+
+    return expression
+
+
 def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
     if (
         isinstance(expression, exp.WithinGroup)
-        and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
+        and isinstance(expression.this, PERCENTILES)
         and isinstance(expression.expression, exp.Order)
     ):
         quantile = expression.this.this
@@ -294,10 +311,13 @@ def preprocess(
         transforms_handler = self.TRANSFORMS.get(type(expression))
 
         if transforms_handler:
-            # Ensures we don't enter an infinite loop. This can happen when the original expression
-            # has the same type as the final expression and there's no _sql method available for it,
-            # because then it'd re-enter _to_sql.
             if expression_type is type(expression):
+                if isinstance(expression, exp.Func):
+                    return self.function_fallback_sql(expression)
+
+                # Ensures we don't enter an infinite loop. This can happen when the original expression
+                # has the same type as the final expression and there's no _sql method available for it,
+                # because then it'd re-enter _to_sql.
                 raise ValueError(
                     f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                 )
@@ -307,3 +327,12 @@ def preprocess(
         raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
 
     return _to_sql
+
+
+def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
+    if isinstance(expression, exp.Timestamp) and not expression.expression:
+        return exp.cast(
+            expression.this,
+            to=exp.DataType.Type.TIMESTAMP,
+        )
+    return expression
```
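Finally, `timestamp_to_cast` rewrites a bare `exp.Timestamp` (one with no time-zone `expression` arg) into a plain cast, which Presto/Trino pick up via the new `exp.Timestamp` transform. A sketch built directly on the node; the output comment is an expectation:

```python
from sqlglot import exp
from sqlglot.dialects.presto import Presto

# exp.Timestamp without an `expression` arg becomes CAST(... AS TIMESTAMP).
node = exp.Timestamp(this=exp.Literal.string("2023-01-01 00:00:00"))
print(Presto().generate(node))
# expected: CAST('2023-01-01 00:00:00' AS TIMESTAMP)
```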