summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/dialect.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-08 08:11:53 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-08 08:12:02 +0000
commit8d36f5966675e23bee7026ba37ae0647fbf47300 (patch)
treedf4227bbb3b07cb70df87237bcff03c8efd7822d /sqlglot/dialects/dialect.py
parentReleasing debian version 22.2.0-1. (diff)
downloadsqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.tar.xz
sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.zip
Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r--sqlglot/dialects/dialect.py115
1 files changed, 91 insertions, 24 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 599505c..81057c2 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -31,6 +31,7 @@ class Dialects(str, Enum):
DIALECT = ""
+ ATHENA = "athena"
BIGQUERY = "bigquery"
CLICKHOUSE = "clickhouse"
DATABRICKS = "databricks"
@@ -42,6 +43,7 @@ class Dialects(str, Enum):
ORACLE = "oracle"
POSTGRES = "postgres"
PRESTO = "presto"
+ PRQL = "prql"
REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
@@ -108,11 +110,18 @@ class _Dialect(type):
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
- klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
+ base = seq_get(bases, 0)
+ base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
+ base_parser = (getattr(base, "parser_class", Parser),)
+ base_generator = (getattr(base, "generator_class", Generator),)
- klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
- klass.parser_class = getattr(klass, "Parser", Parser)
- klass.generator_class = getattr(klass, "Generator", Generator)
+ klass.tokenizer_class = klass.__dict__.get(
+ "Tokenizer", type("Tokenizer", base_tokenizer, {})
+ )
+ klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
+ klass.generator_class = klass.__dict__.get(
+ "Generator", type("Generator", base_generator, {})
+ )
klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
@@ -134,9 +143,31 @@ class _Dialect(type):
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
+ if "\\" in klass.tokenizer_class.STRING_ESCAPES:
+ klass.UNESCAPED_SEQUENCES = {
+ "\\a": "\a",
+ "\\b": "\b",
+ "\\f": "\f",
+ "\\n": "\n",
+ "\\r": "\r",
+ "\\t": "\t",
+ "\\v": "\v",
+ "\\\\": "\\",
+ **klass.UNESCAPED_SEQUENCES,
+ }
+
+ klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
+
if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()
+ if enum not in ("", "databricks", "hive", "spark", "spark2"):
+ modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
+ for modifier in ("cluster", "distribute", "sort"):
+ modifier_transforms.pop(modifier, None)
+
+ klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
+
if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.ANTI,
@@ -189,8 +220,11 @@ class Dialect(metaclass=_Dialect):
False: Disables function name normalization.
"""
- LOG_BASE_FIRST = True
- """Whether the base comes first in the `LOG` function."""
+ LOG_BASE_FIRST: t.Optional[bool] = True
+ """
+ Whether the base comes first in the `LOG` function.
+ Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
+ """
NULL_ORDERING = "nulls_are_small"
"""
@@ -226,8 +260,8 @@ class Dialect(metaclass=_Dialect):
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
"""
- ESCAPE_SEQUENCES: t.Dict[str, str] = {}
- """Mapping of an unescaped escape sequence to the corresponding character."""
+ UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
+ """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
PSEUDOCOLUMNS: t.Set[str] = set()
"""
@@ -266,7 +300,7 @@ class Dialect(metaclass=_Dialect):
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
- INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+ ESCAPED_SEQUENCES: t.Dict[str, str] = {}
# Delimiters for string literals and identifiers
QUOTE_START = "'"
@@ -587,13 +621,21 @@ def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) ->
return ""
-def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
+def str_position_sql(
+ self: Generator, expression: exp.StrPosition, generate_instance: bool = False
+) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
position = self.sql(expression, "position")
+ instance = expression.args.get("instance") if generate_instance else None
+ position_offset = ""
+
if position:
- return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
- return f"STRPOS({this}, {substr})"
+ # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
+ this = self.func("SUBSTR", this, position)
+ position_offset = f" + {position} - 1"
+
+ return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
@@ -689,9 +731,7 @@ def build_date_delta_with_interval(
if expression and expression.is_string:
expression = exp.Literal.number(expression.this)
- return expression_class(
- this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
- )
+ return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
return _builder
@@ -710,18 +750,14 @@ def date_add_interval_sql(
) -> t.Callable[[Generator, exp.Expression], str]:
def func(self: Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
- unit = expression.args.get("unit")
- unit = exp.var(unit.name.upper() if unit else "DAY")
- interval = exp.Interval(this=expression.expression, unit=unit)
+ interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
return f"{data_type}_{kind}({this}, {self.sql(interval)})"
return func
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
- return self.func(
- "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
- )
+ return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
@@ -956,7 +992,7 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return self.func(
name,
- exp.var(expression.text("unit").upper() or "DAY"),
+ unit_to_var(expression),
expression.expression,
expression.this,
)
@@ -964,6 +1000,24 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return _delta_sql
+def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+ unit = expression.args.get("unit")
+
+ if isinstance(unit, exp.Placeholder):
+ return unit
+ if unit:
+ return exp.Literal.string(unit.name)
+ return exp.Literal.string(default) if default else None
+
+
+def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+ unit = expression.args.get("unit")
+
+ if isinstance(unit, (exp.Var, exp.Placeholder)):
+ return unit
+ return exp.Var(this=default) if default else None
+
+
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@@ -998,7 +1052,7 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
def build_json_extract_path(
- expr_type: t.Type[F], zero_based_indexing: bool = True
+ expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
def _builder(args: t.List) -> F:
segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
@@ -1018,7 +1072,11 @@ def build_json_extract_path(
# This is done to avoid failing in the expression validator due to the arg count
del args[2:]
- return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
+ return expr_type(
+ this=seq_get(args, 0),
+ expression=exp.JSONPath(expressions=segments),
+ only_json_types=arrow_req_json_type,
+ )
return _builder
@@ -1070,3 +1128,12 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> s
unnest = exp.Unnest(expressions=[expression.this])
filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
return self.sql(exp.Array(expressions=[filtered]))
+
+
+def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
+ return self.func(
+ "TO_NUMBER",
+ expression.this,
+ expression.args.get("format"),
+ expression.args.get("nlsparam"),
+ )