author      Daniel Baumann <daniel.baumann@progress-linux.org>   2024-04-08 08:11:53 +0000
committer   Daniel Baumann <daniel.baumann@progress-linux.org>   2024-04-08 08:12:02 +0000
commit      8d36f5966675e23bee7026ba37ae0647fbf47300 (patch)
tree        df4227bbb3b07cb70df87237bcff03c8efd7822d /sqlglot/dialects/dialect.py
parent      Releasing debian version 22.2.0-1. (diff)
download    sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.tar.xz
            sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.zip
Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r--   sqlglot/dialects/dialect.py | 115
1 file changed, 91 insertions(+), 24 deletions(-)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 599505c..81057c2 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -31,6 +31,7 @@ class Dialects(str, Enum):
 
     DIALECT = ""
 
+    ATHENA = "athena"
     BIGQUERY = "bigquery"
     CLICKHOUSE = "clickhouse"
     DATABRICKS = "databricks"
@@ -42,6 +43,7 @@ class Dialects(str, Enum):
     ORACLE = "oracle"
     POSTGRES = "postgres"
     PRESTO = "presto"
+    PRQL = "prql"
     REDSHIFT = "redshift"
     SNOWFLAKE = "snowflake"
     SPARK = "spark"
@@ -108,11 +110,18 @@ class _Dialect(type):
         klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
         klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 
-        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
+        base = seq_get(bases, 0)
+        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
+        base_parser = (getattr(base, "parser_class", Parser),)
+        base_generator = (getattr(base, "generator_class", Generator),)
 
-        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
-        klass.parser_class = getattr(klass, "Parser", Parser)
-        klass.generator_class = getattr(klass, "Generator", Generator)
+        klass.tokenizer_class = klass.__dict__.get(
+            "Tokenizer", type("Tokenizer", base_tokenizer, {})
+        )
+        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
+        klass.generator_class = klass.__dict__.get(
+            "Generator", type("Generator", base_generator, {})
+        )
 
         klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
         klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
@@ -134,9 +143,31 @@ class _Dialect(type):
         klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
         klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 
+        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
+            klass.UNESCAPED_SEQUENCES = {
+                "\\a": "\a",
+                "\\b": "\b",
+                "\\f": "\f",
+                "\\n": "\n",
+                "\\r": "\r",
+                "\\t": "\t",
+                "\\v": "\v",
+                "\\\\": "\\",
+                **klass.UNESCAPED_SEQUENCES,
+            }
+
+        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
+
         if enum not in ("", "bigquery"):
             klass.generator_class.SELECT_KINDS = ()
 
+        if enum not in ("", "databricks", "hive", "spark", "spark2"):
+            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
+            for modifier in ("cluster", "distribute", "sort"):
+                modifier_transforms.pop(modifier, None)
+
+            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
+
         if not klass.SUPPORTS_SEMI_ANTI_JOIN:
             klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                 TokenType.ANTI,
@@ -189,8 +220,11 @@ class Dialect(metaclass=_Dialect):
         False: Disables function name normalization.
     """
 
-    LOG_BASE_FIRST = True
-    """Whether the base comes first in the `LOG` function."""
+    LOG_BASE_FIRST: t.Optional[bool] = True
+    """
+    Whether the base comes first in the `LOG` function.
+    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
+    """
 
     NULL_ORDERING = "nulls_are_small"
     """
@@ -226,8 +260,8 @@ class Dialect(metaclass=_Dialect):
     If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
     """
 
-    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
-    """Mapping of an unescaped escape sequence to the corresponding character."""
+    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
+    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 
     PSEUDOCOLUMNS: t.Set[str] = set()
     """
@@ -266,7 +300,7 @@ class Dialect(metaclass=_Dialect):
     INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
     INVERSE_TIME_TRIE: t.Dict = {}
 
-    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 
     # Delimiters for string literals and identifiers
     QUOTE_START = "'"
@@ -587,13 +621,21 @@ def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
     return ""
 
 
-def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
+def str_position_sql(
+    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
+) -> str:
     this = self.sql(expression, "this")
     substr = self.sql(expression, "substr")
     position = self.sql(expression, "position")
+    instance = expression.args.get("instance") if generate_instance else None
+    position_offset = ""
+
     if position:
-        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
-    return f"STRPOS({this}, {substr})"
+        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
+        this = self.func("SUBSTR", this, position)
+        position_offset = f" + {position} - 1"
+
+    return self.func("STRPOS", this, substr, instance) + position_offset
 
 
 def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
@@ -689,9 +731,7 @@ def build_date_delta_with_interval(
         if expression and expression.is_string:
             expression = exp.Literal.number(expression.this)
 
-        return expression_class(
-            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
-        )
+        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 
     return _builder
 
@@ -710,18 +750,14 @@ def date_add_interval_sql(
 ) -> t.Callable[[Generator, exp.Expression], str]:
     def func(self: Generator, expression: exp.Expression) -> str:
         this = self.sql(expression, "this")
-        unit = expression.args.get("unit")
-        unit = exp.var(unit.name.upper() if unit else "DAY")
-        interval = exp.Interval(this=expression.expression, unit=unit)
+        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
         return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 
     return func
 
 
 def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
-    return self.func(
-        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
-    )
+    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
 
 
 def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
@@ -956,7 +992,7 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
         return self.func(
             name,
-            exp.var(expression.text("unit").upper() or "DAY"),
+            unit_to_var(expression),
             expression.expression,
             expression.this,
         )
 
@@ -964,6 +1000,24 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
     return _delta_sql
 
 
+def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+    unit = expression.args.get("unit")
+
+    if isinstance(unit, exp.Placeholder):
+        return unit
+    if unit:
+        return exp.Literal.string(unit.name)
+    return exp.Literal.string(default) if default else None
+
+
+def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+    unit = expression.args.get("unit")
+
+    if isinstance(unit, (exp.Var, exp.Placeholder)):
+        return unit
+    return exp.Var(this=default) if default else None
+
+
 def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
     trunc_curr_date = exp.func("date_trunc", "month", expression.this)
     plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@@ -998,7 +1052,7 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 
 
 def build_json_extract_path(
-    expr_type: t.Type[F], zero_based_indexing: bool = True
+    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
 ) -> t.Callable[[t.List], F]:
     def _builder(args: t.List) -> F:
         segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
@@ -1018,7 +1072,11 @@ def build_json_extract_path(
         # This is done to avoid failing in the expression validator due to the arg count
         del args[2:]
 
-        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
+        return expr_type(
+            this=seq_get(args, 0),
+            expression=exp.JSONPath(expressions=segments),
+            only_json_types=arrow_req_json_type,
+        )
 
     return _builder
 
@@ -1070,3 +1128,12 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
     unnest = exp.Unnest(expressions=[expression.this])
     filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))
+
+
+def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
+    return self.func(
+        "TO_NUMBER",
+        expression.this,
+        expression.args.get("format"),
+        expression.args.get("nlsparam"),
+    )
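
The `_Dialect` metaclass hunk above changes how `tokenizer_class`, `parser_class` and `generator_class` are resolved: instead of `getattr(...)`, which walks the MRO and can hand a subclass the parent dialect's class object itself, each dialect now gets a freshly created subclass unless it declares its own. That keeps the later class-level tweaks in `__new__` (clearing `SELECT_KINDS`, trimming `AFTER_HAVING_MODIFIER_TRANSFORMS`) from leaking into the parent dialect. A minimal standalone sketch of the pattern (illustrative class names, not the real sqlglot classes):

# Standalone sketch of the subclassing pattern above; names are illustrative only.
class Tokenizer:
    STRING_ESCAPES = ["'"]


class _Dialect(type):
    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        base = bases[0] if bases else None
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        # Old behaviour: getattr(klass, "Tokenizer", Tokenizer) could return the
        # parent's class object. New behaviour: build a fresh subclass per dialect
        # unless the dialect declares its own Tokenizer.
        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        return klass


class Dialect(metaclass=_Dialect):
    pass


class Hive(Dialect):
    class Tokenizer(Tokenizer):
        STRING_ESCAPES = ["\\"]


class CustomHive(Hive):  # declares nothing, still gets its own derived Tokenizer
    pass


assert CustomHive.tokenizer_class is not Hive.tokenizer_class
assert issubclass(CustomHive.tokenizer_class, Hive.Tokenizer)
assert CustomHive.tokenizer_class.STRING_ESCAPES == ["\\"]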
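The `ESCAPE_SEQUENCES` attribute is also renamed to `UNESCAPED_SEQUENCES`, the backslash defaults move into the metaclass behind a check that the dialect's tokenizer really treats `\` as a string escape, and the inverse map becomes `ESCAPED_SEQUENCES`. A tiny sketch of the direction of each mapping, using a subset of the dict literal from the hunk above:

# Mapping direction after the rename (subset of the dict added in the metaclass).
UNESCAPED_SEQUENCES = {"\\n": "\n", "\\t": "\t", "\\\\": "\\"}
ESCAPED_SEQUENCES = {v: k for k, v in UNESCAPED_SEQUENCES.items()}

assert UNESCAPED_SEQUENCES["\\n"] == "\n"  # the two characters '\' 'n' -> a real newline
assert ESCAPED_SEQUENCES["\n"] == "\\n"    # a real newline -> its escaped spelling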
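In `str_position_sql`, when a start `position` is supplied the needle is searched inside `SUBSTR(haystack, position)`, so the result is relative to that substring and the trailing `+ {position} - 1` re-bases it onto the original string. A quick worked check of that arithmetic in plain Python (1-based indexing, as in SQL):

# Worked check of the "+ {position} - 1" re-basing used above.
haystack, needle, position = "abcabc", "b", 3
sub = haystack[position - 1:]        # SUBSTR(haystack, 3) -> "cabc"
rel = sub.index(needle) + 1          # STRPOS over the substring -> 3
assert rel + position - 1 == haystack.index(needle, position - 1) + 1 == 5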
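The new `unit_to_str` and `unit_to_var` helpers replace the repeated `exp.Literal.string(expression.text("unit").upper() or "DAY")` and `exp.var(...)` patterns in `build_date_delta_with_interval`, `date_add_interval_sql`, `timestamptrunc_sql` and `date_delta_sql`. A hedged usage sketch, assuming the upgraded sqlglot (23.7.0) is installed; the expressions are built by hand rather than parsed so the example stays dialect independent, and the column name is made up:

from sqlglot import exp
from sqlglot.dialects.dialect import unit_to_str, unit_to_var

date_add = exp.DateAdd(
    this=exp.column("order_date"),     # illustrative column
    expression=exp.Literal.number(1),
    unit=exp.var("DAY"),
)

print(unit_to_str(date_add).sql())  # 'DAY'  (string literal, as DATE_TRUNC expects)
print(unit_to_var(date_add).sql())  # DAY    (bare var, as DATE_ADD-style functions expect)

# With no unit present, both helpers fall back to the provided default ("DAY").
no_unit = exp.DateAdd(this=exp.column("order_date"), expression=exp.Literal.number(7))
print(unit_to_str(no_unit).sql())   # 'DAY'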