diff options
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- | sqlglot/dialects/dialect.py | 78 |
1 files changed, 70 insertions, 8 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index f5d523b..0e25b9b 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -4,6 +4,7 @@ import typing as t from enum import Enum from sqlglot import exp +from sqlglot._typing import E from sqlglot.generator import Generator from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser @@ -11,14 +12,6 @@ from sqlglot.time import format_time from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import new_trie -if t.TYPE_CHECKING: - from sqlglot._typing import E - - -# Only Snowflake is currently known to resolve unquoted identifiers as uppercase. -# https://docs.snowflake.com/en/sql-reference/identifiers-syntax -RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"} - class Dialects(str, Enum): DIALECT = "" @@ -117,6 +110,9 @@ class _Dialect(type): "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0], } + if enum not in ("", "bigquery"): + dialect_properties["SELECT_KINDS"] = () + # Pass required dialect properties to the tokenizer, parser and generator classes for subclass in (klass.tokenizer_class, klass.parser_class, klass.generator_class): for name, value in dialect_properties.items(): @@ -126,6 +122,8 @@ class _Dialect(type): if not klass.STRICT_STRING_CONCAT: klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe + klass.generator_class.can_identify = klass.can_identify + return klass @@ -139,6 +137,10 @@ class Dialect(metaclass=_Dialect): # Determines whether or not the table alias comes after tablesample ALIAS_POST_TABLESAMPLE = False + # Determines whether or not unquoted identifiers are resolved as uppercase + # When set to None, it means that the dialect treats all identifiers as case-insensitive + RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False + # Determines whether or not an unquoted identifier can start with a digit IDENTIFIERS_CAN_START_WITH_DIGIT = False @@ -213,6 +215,66 @@ class Dialect(metaclass=_Dialect): return expression + @classmethod + def normalize_identifier(cls, expression: E) -> E: + """ + Normalizes an unquoted identifier to either lower or upper case, thus essentially + making it case-insensitive. If a dialect treats all identifiers as case-insensitive, + they will be normalized regardless of being quoted or not. + """ + if isinstance(expression, exp.Identifier) and ( + not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None + ): + expression.set( + "this", + expression.this.upper() + if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE + else expression.this.lower(), + ) + + return expression + + @classmethod + def case_sensitive(cls, text: str) -> bool: + """Checks if text contains any case sensitive characters, based on the dialect's rules.""" + if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None: + return False + + unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper + return any(unsafe(char) for char in text) + + @classmethod + def can_identify(cls, text: str, identify: str | bool = "safe") -> bool: + """Checks if text can be identified given an identify option. + + Args: + text: The text to check. + identify: + "always" or `True`: Always returns true. + "safe": True if the identifier is case-insensitive. + + Returns: + Whether or not the given text can be identified. + """ + if identify is True or identify == "always": + return True + + if identify == "safe": + return not cls.case_sensitive(text) + + return False + + @classmethod + def quote_identifier(cls, expression: E, identify: bool = True) -> E: + if isinstance(expression, exp.Identifier): + name = expression.this + expression.set( + "quoted", + identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), + ) + + return expression + def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: return self.parser(**opts).parse(self.tokenize(sql), sql) |