diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-23 05:06:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-23 05:06:14 +0000 |
commit | 38e6461a8afbd7cb83709ddb998f03d40ba87755 (patch) | |
tree | 64b68a893a3b946111b9cab69503f83ca233c335 /sqlglot/dialects/dialect.py | |
parent | Releasing debian version 20.4.0-1. (diff) | |
download | sqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.tar.xz sqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.zip |
Merging upstream version 20.9.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- | sqlglot/dialects/dialect.py | 119 |
1 file changed, 94 insertions, 25 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index b7eef45..7664c40 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -153,6 +153,9 @@ class Dialect(metaclass=_Dialect): ALIAS_POST_TABLESAMPLE = False """Determines whether or not the table alias comes after tablesample.""" + TABLESAMPLE_SIZE_IS_PERCENT = False + """Determines whether or not a size in the table sample clause represents percentage.""" + NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE """Specifies the strategy according to which identifiers should be normalized.""" @@ -220,6 +223,24 @@ class Dialect(metaclass=_Dialect): For example, such columns may be excluded from `SELECT *` queries. """ + PREFER_CTE_ALIAS_COLUMN = False + """ + Some dialects, such as Snowflake, allow you to reference a CTE column alias in the + HAVING clause of the CTE. This flag will cause the CTE alias columns to override + any projection aliases in the subquery. + + For example, + WITH y(c) AS ( + SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 + ) SELECT c FROM y; + + will be rewritten as + + WITH y(c) AS ( + SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 + ) SELECT c FROM y; + """ + # --- Autofilled --- tokenizer_class = Tokenizer @@ -287,7 +308,13 @@ class Dialect(metaclass=_Dialect): result = cls.get(dialect_name.strip()) if not result: - raise ValueError(f"Unknown dialect '{dialect_name}'.") + from difflib import get_close_matches + + similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" + if similar: + similar = f" Did you mean {similar}?" 
+ + raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") return result(**kwargs) @@ -506,7 +533,7 @@ def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: n = self.sql(expression, "this") d = self.sql(expression, "expression") - return f"IF({d} <> 0, {n} / {d}, NULL)" + return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: @@ -695,7 +722,7 @@ def date_add_interval_sql( def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: return self.func( - "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this + "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this ) @@ -801,22 +828,6 @@ def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: return self.func("STRPTIME", expression.this, self.format_time(expression)) -def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: - def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: - _dialect = Dialect.get_or_raise(dialect) - time_format = self.format_time(expression) - if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): - return self.sql( - exp.cast( - exp.StrToTime(this=expression.this, format=expression.args["format"]), - "date", - ) - ) - return self.sql(exp.cast(expression.this, "date")) - - return _ts_or_ds_to_date_sql - - def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) @@ -894,11 +905,6 @@ def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" -# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon -def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str: - 
return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}" - - def is_parse_json(expression: exp.Expression) -> bool: return isinstance(expression, exp.ParseJSON) or ( isinstance(expression, exp.Cast) and expression.is_type("json") @@ -946,7 +952,70 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE expression = ts_or_ds_add_cast(expression) return self.func( - name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this + name, + exp.var(expression.text("unit").upper() or "DAY"), + expression.expression, + expression.this, ) return _delta_sql + + +def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath: + from sqlglot.optimizer.simplify import simplify + + # Makes sure the path will be evaluated correctly at runtime to include the path root. + # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`. + path = expression.expression + path = exp.func( + "if", + exp.func("startswith", path, "'['"), + exp.func("concat", "'$'", path), + exp.func("concat", "'$.'", path), + ) + + expression.expression.replace(simplify(path)) + return expression + + +def path_to_jsonpath( + name: str = "JSON_EXTRACT", +) -> t.Callable[[Generator, exp.GetPath], str]: + def _transform(self: Generator, expression: exp.GetPath) -> str: + return rename_func(name)(self, prepend_dollar_to_path(expression)) + + return _transform + + +def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: + trunc_curr_date = exp.func("date_trunc", "month", expression.this) + plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") + minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") + + return self.sql(exp.cast(minus_one_day, "date")) + + +def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: + """Remove table refs from columns in when statements.""" + alias = expression.this.args.get("alias") + + normalize = ( + lambda identifier: 
self.dialect.normalize_identifier(identifier).name + if identifier + else None + ) + + targets = {normalize(expression.this.this)} + + if alias: + targets.add(normalize(alias.this)) + + for when in expression.expressions: + when.transform( + lambda node: exp.column(node.this) + if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets + else node, + copy=False, + ) + + return self.merge_sql(expression) |