From 8f88a01462641cbf930b3c43b780565d0fb7d37e Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 22 Jun 2023 20:53:34 +0200 Subject: Merging upstream version 16.4.0. Signed-off-by: Daniel Baumann --- sqlglot/dialects/bigquery.py | 22 ++++++++++++ sqlglot/dialects/clickhouse.py | 2 +- sqlglot/dialects/dialect.py | 78 +++++++++++++++++++++++++++++++++++++----- sqlglot/dialects/duckdb.py | 11 +++++- sqlglot/dialects/hive.py | 23 ++++++++----- sqlglot/dialects/mysql.py | 7 +++- sqlglot/dialects/presto.py | 20 +++++------ sqlglot/dialects/redshift.py | 3 ++ sqlglot/dialects/snowflake.py | 5 ++- sqlglot/dialects/sqlite.py | 3 ++ sqlglot/dialects/teradata.py | 13 +++---- 11 files changed, 149 insertions(+), 38 deletions(-) (limited to 'sqlglot/dialects') diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 2166e65..52d4a88 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -4,6 +4,7 @@ import re import typing as t from sqlglot import exp, generator, parser, tokens, transforms +from sqlglot._typing import E from sqlglot.dialects.dialect import ( Dialect, datestrtodate_sql, @@ -106,6 +107,9 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: class BigQuery(Dialect): UNNEST_COLUMN_ONLY = True + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity + RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + TIME_MAPPING = { "%D": "%m/%d/%y", } @@ -126,6 +130,20 @@ class BigQuery(Dialect): "TZH": "%z", } + @classmethod + def normalize_identifier(cls, expression: E) -> E: + # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least). + # The following check is essentially a heuristic to detect tables based on whether or + # not they're qualified. 
+ if ( + isinstance(expression, exp.Identifier) + and not (isinstance(expression.parent, exp.Table) and expression.parent.db) + and not expression.meta.get("is_table") + ): + expression.set("this", expression.this.lower()) + + return expression + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", '"', '"""', "'''"] COMMENTS = ["--", "#", ("/*", "*/")] @@ -176,6 +194,7 @@ class BigQuery(Dialect): "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd), "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub), "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)), + "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list, "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")( [seq_get(args, 1), seq_get(args, 0)] ), @@ -201,6 +220,7 @@ class BigQuery(Dialect): "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub), "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd), "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub), + "TO_JSON_STRING": exp.JSONFormat.from_arg_list, } FUNCTION_PARSERS = { @@ -289,6 +309,8 @@ class BigQuery(Dialect): exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", exp.DateStrToDate: datestrtodate_sql, exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")), + exp.JSONFormat: rename_func("TO_JSON_STRING"), + exp.GenerateSeries: rename_func("GENERATE_ARRAY"), exp.GroupConcat: rename_func("STRING_AGG"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index cfa9a7e..efaf34c 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -345,7 +345,7 @@ class ClickHouse(Dialect): "CONCAT", *[ exp.func("if", e.is_(exp.null()), e, exp.cast(e, "text")) - for e in expression.expressions + for e in t.cast(t.List[exp.Condition], expression.expressions) ], ) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index f5d523b..0e25b9b 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -4,6 +4,7 @@ import typing as t from enum import Enum from sqlglot import exp +from sqlglot._typing import E from sqlglot.generator import Generator from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser @@ -11,14 +12,6 @@ from sqlglot.time import format_time from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import new_trie -if t.TYPE_CHECKING: - from sqlglot._typing import E - - -# Only Snowflake is currently known to resolve unquoted identifiers as uppercase. 
-# https://docs.snowflake.com/en/sql-reference/identifiers-syntax -RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"} - class Dialects(str, Enum): DIALECT = "" @@ -117,6 +110,9 @@ class _Dialect(type): "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0], } + if enum not in ("", "bigquery"): + dialect_properties["SELECT_KINDS"] = () + # Pass required dialect properties to the tokenizer, parser and generator classes for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class): for name, value in dialect_properties.items(): @@ -126,6 +122,8 @@ class _Dialect(type): if not klass.STRICT_STRING_CONCAT: klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe + klass.generator_class.can_identify = klass.can_identify + return klass @@ -139,6 +137,10 @@ class Dialect(metaclass=_Dialect): # Determines whether or not the table alias comes after tablesample ALIAS_POST_TABLESAMPLE = False + # Determines whether or not unquoted identifiers are resolved as uppercase + # When set to None, it means that the dialect treats all identifiers as case-insensitive + RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False + # Determines whether or not an unquoted identifier can start with a digit IDENTIFIERS_CAN_START_WITH_DIGIT = False @@ -213,6 +215,66 @@ class Dialect(metaclass=_Dialect): return expression + @classmethod + def normalize_identifier(cls, expression: E) -> E: + """ + Normalizes an unquoted identifier to either lower or upper case, thus essentially + making it case-insensitive. If a dialect treats all identifiers as case-insensitive, + they will be normalized regardless of being quoted or not. + """ + if isinstance(expression, exp.Identifier) and ( + not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None + ): + expression.set( + "this", + expression.this.upper() + if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE + else expression.this.lower(), + ) + + return expression + + @classmethod + def case_sensitive(cls, text: str) -> bool: + """Checks if text contains any case sensitive characters, based on the dialect's rules.""" + if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None: + return False + + unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper + return any(unsafe(char) for char in text) + + @classmethod + def can_identify(cls, text: str, identify: str | bool = "safe") -> bool: + """Checks if text can be identified given an identify option. + + Args: + text: The text to check. + identify: + "always" or `True`: Always returns true. + "safe": True if the identifier is case-insensitive. + + Returns: + Whether or not the given text can be identified. 
+ """ + if identify is True or identify == "always": + return True + + if identify == "safe": + return not cls.case_sensitive(text) + + return False + + @classmethod + def quote_identifier(cls, expression: E, identify: bool = True) -> E: + if isinstance(expression, exp.Identifier): + name = expression.this + expression.set( + "quoted", + identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), + ) + + return expression + def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: return self.parser(**opts).parse(self.tokenize(sql), sql) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index f0c1820..093a01c 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -85,9 +85,17 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract ) +def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: + sql = self.func("TO_JSON", expression.this, expression.args.get("options")) + return f"CAST({sql} AS TEXT)" + + class DuckDB(Dialect): NULL_ORDERING = "nulls_are_last" + # https://duckdb.org/docs/sql/introduction.html#creating-a-new-table + RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -167,7 +175,7 @@ class DuckDB(Dialect): **generator.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0]) - if isinstance(seq_get(e.expressions, 0), exp.Select) + if e.expressions and e.expressions[0].find(exp.Select) else rename_func("LIST_VALUE")(self, e), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, @@ -192,6 +200,7 @@ class DuckDB(Dialect): exp.IntDiv: lambda self, e: self.binary(e, "//"), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONFormat: _json_format_sql, exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.LogicalOr: rename_func("BOOL_OR"), diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 8847119..6bca610 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -86,13 +86,17 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: this = expression.this - if not this.type: - from sqlglot.optimizer.annotate_types import annotate_types + if isinstance(this, exp.Cast) and this.is_type("json") and this.this.is_string: + # Since FROM_JSON requires a nested type, we always wrap the json string with + # an array to ensure that "naked" strings like "'a'" will be handled correctly + wrapped_json = exp.Literal.string(f"[{this.this.name}]") - annotate_types(this) + from_json = self.func("FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json)) + to_json = self.func("TO_JSON", from_json) + + # This strips the [, ] delimiters of the dummy array printed by TO_JSON + return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1") - if this.type.is_type("json"): - return self.sql(this) return self.func("TO_JSON", this, expression.args.get("options")) @@ -153,6 +157,9 @@ class Hive(Dialect): ALIAS_POST_TABLESAMPLE = True IDENTIFIERS_CAN_START_WITH_DIGIT = True + # https://spark.apache.org/docs/latest/sql-ref-identifier.html#description + RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + TIME_MAPPING = { "y": "%Y", "Y": "%Y", @@ -268,9 +275,9 @@ class 
Hive(Dialect): QUERY_MODIFIER_PARSERS = { **parser.Parser.QUERY_MODIFIER_PARSERS, - "distribute": lambda self: self._parse_sort(exp.Distribute, "DISTRIBUTE", "BY"), - "sort": lambda self: self._parse_sort(exp.Sort, "SORT", "BY"), - "cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"), + "cluster": lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY), + "distribute": lambda self: self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY), + "sort": lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY), } def _parse_types( diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index d2462e1..1dd2096 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -123,6 +123,8 @@ class MySQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "CHARSET": TokenType.CHARACTER_SET, + "FORCE": TokenType.FORCE, + "IGNORE": TokenType.IGNORE, "LONGBLOB": TokenType.LONGBLOB, "LONGTEXT": TokenType.LONGTEXT, "MEDIUMBLOB": TokenType.MEDIUMBLOB, @@ -180,6 +182,9 @@ class MySQL(Dialect): class Parser(parser.Parser): FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} + TABLE_ALIAS_TOKENS = ( + parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS + ) FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -389,7 +394,7 @@ class MySQL(Dialect): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False JOIN_HINTS = False - TABLE_HINTS = False + TABLE_HINTS = True TRANSFORMS = { **generator.Generator.TRANSFORMS, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index a8a9884..265780e 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -103,24 +103,15 @@ def _str_to_time_sql( def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT): - return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" - return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)" + return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto") + return exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE").sql(dialect="presto") def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: this = expression.this if not isinstance(this, exp.CurrentDate): - this = self.func( - "DATE_PARSE", - self.func( - "SUBSTR", - this if this.is_string else exp.cast(this, "VARCHAR"), - exp.Literal.number(1), - exp.Literal.number(10), - ), - Presto.DATE_FORMAT, - ) + this = exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE") return self.func( "DATE_ADD", @@ -181,6 +172,11 @@ class Presto(Dialect): TIME_MAPPING = MySQL.TIME_MAPPING STRICT_STRING_CONCAT = True + # https://github.com/trinodb/trino/issues/17 + # https://github.com/trinodb/trino/issues/12289 + # https://github.com/prestodb/presto/issues/2863 + RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index a7e25fa..db6cc3f 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -14,6 +14,9 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx class Redshift(Postgres): + # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html + RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'" TIME_MAPPING = { 
**Postgres.TIME_MAPPING, diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 148b6d8..d488d7d 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -167,6 +167,8 @@ def _parse_convert_timezone(args: t.List) -> exp.Expression: class Snowflake(Dialect): + # https://docs.snowflake.com/en/sql-reference/identifiers-syntax + RESOLVES_IDENTIFIERS_AS_UPPERCASE = True NULL_ORDERING = "nulls_are_large" TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" @@ -283,11 +285,12 @@ class Snowflake(Dialect): "NCHAR VARYING": TokenType.VARCHAR, "PUT": TokenType.COMMAND, "RENAME": TokenType.REPLACE, + "SAMPLE": TokenType.TABLE_SAMPLE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TIMESTAMPNTZ": TokenType.TIMESTAMP, - "SAMPLE": TokenType.TABLE_SAMPLE, + "TOP": TokenType.TOP, } SINGLE_TOKENS = { diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 3b837ea..803f361 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -59,6 +59,9 @@ def _transform_create(expression: exp.Expression) -> exp.Expression: class SQLite(Dialect): + # https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw + RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")] diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index d5e5dd8..d9a5417 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -31,18 +31,19 @@ class Teradata(Dialect): # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "^=": TokenType.NEQ, "BYTEINT": TokenType.SMALLINT, - "SEL": TokenType.SELECT, + "GE": TokenType.GTE, + "GT": TokenType.GT, "INS": TokenType.INSERT, - "MOD": TokenType.MOD, - "LT": TokenType.LT, "LE": TokenType.LTE, - "GT": TokenType.GT, - "GE": TokenType.GTE, - "^=": TokenType.NEQ, + "LT": TokenType.LT, + "MOD": TokenType.MOD, "NE": TokenType.NEQ, "NOT=": TokenType.NEQ, + "SEL": TokenType.SELECT, "ST_GEOMETRY": TokenType.GEOMETRY, + "TOP": TokenType.TOP, } # Teradata does not support % as a modulo operator -- cgit v1.2.3
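
Note on the new identifier-resolution contract: the dialect.py hunk above replaces the old module-level RESOLVES_IDENTIFIERS_AS_UPPERCASE set with a tri-state class attribute (True = unquoted identifiers resolve to uppercase, as in Snowflake; False = lowercase, the default; None = the dialect treats every identifier as case-insensitive, as in DuckDB, Hive, Presto, Redshift and SQLite, with BigQuery additionally special-casing table names). The sketch below is not part of the patch; it mirrors that logic on plain strings so the normalize_identifier / case_sensitive / can_identify behaviour can be read in isolation, and the DialectRules / SnowflakeRules names are invented for illustration only.

# Standalone sketch of the identifier-resolution rules added in dialect.py above.
# Not part of the patch; DialectRules and SnowflakeRules are illustrative names.
import typing as t


class DialectRules:
    # True  -> unquoted identifiers resolve to uppercase (Snowflake)
    # False -> unquoted identifiers resolve to lowercase (the default)
    # None  -> every identifier is case-insensitive (DuckDB, Hive, Presto,
    #          Redshift, SQLite; BigQuery additionally special-cases tables)
    RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False

    @classmethod
    def normalize_identifier(cls, name: str, quoted: bool = False) -> str:
        # Quoted identifiers keep their case unless the dialect is fully
        # case-insensitive (RESOLVES_IDENTIFIERS_AS_UPPERCASE is None).
        if quoted and cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is not None:
            return name
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE:
            return name.upper()
        return name.lower()

    @classmethod
    def case_sensitive(cls, text: str) -> bool:
        # A character is "unsafe" when leaving it unquoted would change how the
        # engine resolves it.
        if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
            return False
        unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
        return any(unsafe(char) for char in text)

    @classmethod
    def can_identify(cls, text: str, identify: t.Union[str, bool] = "safe") -> bool:
        # "always" / True: always quote; "safe": quote only when it cannot
        # change the identifier's meaning.
        if identify is True or identify == "always":
            return True
        if identify == "safe":
            return not cls.case_sensitive(text)
        return False


class SnowflakeRules(DialectRules):
    RESOLVES_IDENTIFIERS_AS_UPPERCASE = True


if __name__ == "__main__":
    assert DialectRules.normalize_identifier("MyCol") == "mycol"
    assert SnowflakeRules.normalize_identifier("MyCol") == "MYCOL"
    assert SnowflakeRules.normalize_identifier("MyCol", quoted=True) == "MyCol"
    assert SnowflakeRules.can_identify("FOO") is True   # already uppercase, safe
    assert SnowflakeRules.can_identify("foo") is False  # lowercase chars are unsafe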
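
The patch also wires BigQuery's TO_JSON_STRING through exp.JSONFormat and gives DuckDB and Hive their own JSONFormat generators. A minimal usage sketch, assuming the public sqlglot entry points parse_one and Expression.sql as they exist in this 16.4.0 merge; the table and column names are placeholders, and the expected outputs in the comments follow from the transforms in the hunks above rather than from verified runs.

# Illustrative only; exact output strings depend on the sqlglot version in use.
import sqlglot

# BigQuery's TO_JSON_STRING(...) now parses into exp.JSONFormat, so the same
# expression can be re-rendered for any dialect that defines a JSONFormat transform.
expr = sqlglot.parse_one("SELECT TO_JSON_STRING(col) FROM t", read="bigquery")

print(expr.sql(dialect="bigquery"))  # round-trips as TO_JSON_STRING(col)
print(expr.sql(dialect="duckdb"))    # per _json_format_sql: CAST(TO_JSON(col) AS TEXT)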