summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/dialect.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-01-31 05:44:41 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-01-31 05:44:41 +0000
commit376de8b6892deca7dc5d83035c047f1e13eb67ea (patch)
tree334a1753cd914294aa99128fac3fb59bf14dc10f /sqlglot/dialects/dialect.py
parentReleasing debian version 20.9.0-1. (diff)
downloadsqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.tar.xz
sqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.zip
Merging upstream version 20.11.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r--sqlglot/dialects/dialect.py36
1 files changed, 20 insertions, 16 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 7664c40..6be991b 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -5,7 +5,6 @@ from enum import Enum, auto
from functools import reduce
from sqlglot import exp
-from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
@@ -14,11 +13,12 @@ from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
-B = t.TypeVar("B", bound=exp.Binary)
-
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
+if t.TYPE_CHECKING:
+ from sqlglot._typing import B, E
+
class Dialects(str, Enum):
"""Dialects supported by SQLGLot."""
@@ -381,9 +381,11 @@ class Dialect(metaclass=_Dialect):
):
expression.set(
"this",
- expression.this.upper()
- if self.normalization_strategy is NormalizationStrategy.UPPERCASE
- else expression.this.lower(),
+ (
+ expression.this.upper()
+ if self.normalization_strategy is NormalizationStrategy.UPPERCASE
+ else expression.this.lower()
+ ),
)
return expression
@@ -877,9 +879,11 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
"""
agg_all_unquoted = agg.transform(
- lambda node: exp.Identifier(this=node.name, quoted=False)
- if isinstance(node, exp.Identifier)
- else node
+ lambda node: (
+ exp.Identifier(this=node.name, quoted=False)
+ if isinstance(node, exp.Identifier)
+ else node
+ )
)
names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
@@ -999,10 +1003,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
"""Remove table refs from columns in when statements."""
alias = expression.this.args.get("alias")
- normalize = (
- lambda identifier: self.dialect.normalize_identifier(identifier).name
- if identifier
- else None
+ normalize = lambda identifier: (
+ self.dialect.normalize_identifier(identifier).name if identifier else None
)
targets = {normalize(expression.this.this)}
@@ -1012,9 +1014,11 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
for when in expression.expressions:
when.transform(
- lambda node: exp.column(node.this)
- if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
- else node,
+ lambda node: (
+ exp.column(node.this)
+ if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
+ else node
+ ),
copy=False,
)