summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/dialects/dialect.py
diff options
context:
space:
mode:
author Daniel Baumann <daniel.baumann@progress-linux.org> 2024-01-23 05:06:14 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org> 2024-01-23 05:06:14 +0000
commit 38e6461a8afbd7cb83709ddb998f03d40ba87755 (patch)
tree 64b68a893a3b946111b9cab69503f83ca233c335 /sqlglot/dialects/dialect.py
parent Releasing debian version 20.4.0-1. (diff)
download sqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.tar.xz
sqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.zip
Merging upstream version 20.9.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- sqlglot/dialects/dialect.py 119
1 files changed, 94 insertions, 25 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index b7eef45..7664c40 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -153,6 +153,9 @@ class Dialect(metaclass=_Dialect):
ALIAS_POST_TABLESAMPLE = False
"""Determines whether or not the table alias comes after tablesample."""
+ TABLESAMPLE_SIZE_IS_PERCENT = False
+ """Determines whether or not a size in the table sample clause represents percentage."""
+
NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
"""Specifies the strategy according to which identifiers should be normalized."""
@@ -220,6 +223,24 @@ class Dialect(metaclass=_Dialect):
For example, such columns may be excluded from `SELECT *` queries.
"""
+ PREFER_CTE_ALIAS_COLUMN = False
+ """
+ Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
+ HAVING clause of the CTE. This flag will cause the CTE alias columns to override
+ any projection aliases in the subquery.
+
+ For example,
+ WITH y(c) AS (
+ SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
+ ) SELECT c FROM y;
+
+ will be rewritten as
+
+ WITH y(c) AS (
+ SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
+ ) SELECT c FROM y;
+ """
+
# --- Autofilled ---
tokenizer_class = Tokenizer
@@ -287,7 +308,13 @@ class Dialect(metaclass=_Dialect):
result = cls.get(dialect_name.strip())
if not result:
- raise ValueError(f"Unknown dialect '{dialect_name}'.")
+ from difflib import get_close_matches
+
+ similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
+ if similar:
+ similar = f" Did you mean {similar}?"
+
+ raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
return result(**kwargs)
@@ -506,7 +533,7 @@ def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
n = self.sql(expression, "this")
d = self.sql(expression, "expression")
- return f"IF({d} <> 0, {n} / {d}, NULL)"
+ return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
@@ -695,7 +722,7 @@ def date_add_interval_sql(
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
return self.func(
- "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
+ "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
)
@@ -801,22 +828,6 @@ def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
return self.func("STRPTIME", expression.this, self.format_time(expression))
-def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
- def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
- _dialect = Dialect.get_or_raise(dialect)
- time_format = self.format_time(expression)
- if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
- return self.sql(
- exp.cast(
- exp.StrToTime(this=expression.this, format=expression.args["format"]),
- "date",
- )
- )
- return self.sql(exp.cast(expression.this, "date"))
-
- return _ts_or_ds_to_date_sql
-
-
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
@@ -894,11 +905,6 @@ def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
-# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
-def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
- return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
-
-
def is_parse_json(expression: exp.Expression) -> bool:
return isinstance(expression, exp.ParseJSON) or (
isinstance(expression, exp.Cast) and expression.is_type("json")
@@ -946,7 +952,70 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
expression = ts_or_ds_add_cast(expression)
return self.func(
- name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
+ name,
+ exp.var(expression.text("unit").upper() or "DAY"),
+ expression.expression,
+ expression.this,
)
return _delta_sql
+
+
+def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
+ from sqlglot.optimizer.simplify import simplify
+
+ # Makes sure the path will be evaluated correctly at runtime to include the path root.
+ # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
+ path = expression.expression
+ path = exp.func(
+ "if",
+ exp.func("startswith", path, "'['"),
+ exp.func("concat", "'$'", path),
+ exp.func("concat", "'$.'", path),
+ )
+
+ expression.expression.replace(simplify(path))
+ return expression
+
+
+def path_to_jsonpath(
+ name: str = "JSON_EXTRACT",
+) -> t.Callable[[Generator, exp.GetPath], str]:
+ def _transform(self: Generator, expression: exp.GetPath) -> str:
+ return rename_func(name)(self, prepend_dollar_to_path(expression))
+
+ return _transform
+
+
+def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
+ trunc_curr_date = exp.func("date_trunc", "month", expression.this)
+ plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
+ minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
+
+ return self.sql(exp.cast(minus_one_day, "date"))
+
+
+def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
+ """Remove table refs from columns in when statements."""
+ alias = expression.this.args.get("alias")
+
+ normalize = (
+ lambda identifier: self.dialect.normalize_identifier(identifier).name
+ if identifier
+ else None
+ )
+
+ targets = {normalize(expression.this.this)}
+
+ if alias:
+ targets.add(normalize(alias.this))
+
+ for when in expression.expressions:
+ when.transform(
+ lambda node: exp.column(node.this)
+ if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
+ else node,
+ copy=False,
+ )
+
+ return self.merge_sql(expression)