summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/dialects/dialect.py
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2024-02-08 05:38:42 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2024-02-08 05:38:42 +0000
commit: c66e4a33e1a07c439f03fe47f146a6c6482bf6df (patch)
tree: cfdf01111c063b3e50841695e6c2768833aea4dc /sqlglot/dialects/dialect.py
parent: Releasing debian version 20.11.0-1. (diff)
download: sqlglot-c66e4a33e1a07c439f03fe47f146a6c6482bf6df.tar.xz
          sqlglot-c66e4a33e1a07c439f03fe47f146a6c6482bf6df.zip
Merging upstream version 21.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- sqlglot/dialects/dialect.py 132
1 files changed, 87 insertions, 45 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 6be991b..6e2d190 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import logging
import typing as t
from enum import Enum, auto
from functools import reduce
@@ -7,7 +8,8 @@ from functools import reduce
from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
-from sqlglot.helper import AutoName, flatten, seq_get
+from sqlglot.helper import AutoName, flatten, is_int, seq_get
+from sqlglot.jsonpath import parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
@@ -17,7 +19,11 @@ DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsD
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
if t.TYPE_CHECKING:
- from sqlglot._typing import B, E
+ from sqlglot._typing import B, E, F
+
+ JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]
+
+logger = logging.getLogger("sqlglot")
class Dialects(str, Enum):
@@ -256,7 +262,7 @@ class Dialect(metaclass=_Dialect):
INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
- # Delimiters for quotes, identifiers and the corresponding escape characters
+ # Delimiters for string literals and identifiers
QUOTE_START = "'"
QUOTE_END = "'"
IDENTIFIER_START = '"'
@@ -373,7 +379,7 @@ class Dialect(metaclass=_Dialect):
"""
if (
isinstance(expression, exp.Identifier)
- and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE
+ and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
and (
not expression.quoted
or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
@@ -440,6 +446,19 @@ class Dialect(metaclass=_Dialect):
return expression
+ def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ if isinstance(path, exp.Literal):
+ path_text = path.name
+ if path.is_number:
+ path_text = f"[{path_text}]"
+
+ try:
+ return parse_json_path(path_text)
+ except ParseError as e:
+ logger.warning(f"Invalid JSON path syntax. {str(e)}")
+
+ return path
+
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse(self.tokenize(sql), sql)
@@ -500,14 +519,12 @@ def if_sql(
return _if_sql
-def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str:
- return self.binary(expression, "->")
-
+def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
+ this = expression.this
+ if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
+ this.replace(exp.cast(this, "json"))
-def arrow_json_extract_scalar_sql(
- self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar
-) -> str:
- return self.binary(expression, "->>")
+ return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
def inline_array_sql(self: Generator, expression: exp.Array) -> str:
@@ -552,11 +569,6 @@ def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str:
return self.cast_sql(expression)
-def no_properties_sql(self: Generator, expression: exp.Properties) -> str:
- self.unsupported("Properties unsupported")
- return ""
-
-
def no_comment_column_constraint_sql(
self: Generator, expression: exp.CommentColumnConstraint
) -> str:
@@ -965,32 +977,6 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return _delta_sql
-def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
- from sqlglot.optimizer.simplify import simplify
-
- # Makes sure the path will be evaluated correctly at runtime to include the path root.
- # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
- path = expression.expression
- path = exp.func(
- "if",
- exp.func("startswith", path, "'['"),
- exp.func("concat", "'$'", path),
- exp.func("concat", "'$.'", path),
- )
-
- expression.expression.replace(simplify(path))
- return expression
-
-
-def path_to_jsonpath(
- name: str = "JSON_EXTRACT",
-) -> t.Callable[[Generator, exp.GetPath], str]:
- def _transform(self: Generator, expression: exp.GetPath) -> str:
- return rename_func(name)(self, prepend_dollar_to_path(expression))
-
- return _transform
-
-
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@@ -1003,9 +989,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
"""Remove table refs from columns in when statements."""
alias = expression.this.args.get("alias")
- normalize = lambda identifier: (
- self.dialect.normalize_identifier(identifier).name if identifier else None
- )
+ def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
+ return self.dialect.normalize_identifier(identifier).name if identifier else None
targets = {normalize(expression.this.this)}
@@ -1023,3 +1008,60 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
)
return self.merge_sql(expression)
+
+
+def parse_json_extract_path(
+ expr_type: t.Type[F], zero_based_indexing: bool = True
+) -> t.Callable[[t.List], F]:
+ def _parse_json_extract_path(args: t.List) -> F:
+ segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
+ for arg in args[1:]:
+ if not isinstance(arg, exp.Literal):
+ # We use the fallback parser because we can't really transpile non-literals safely
+ return expr_type.from_arg_list(args)
+
+ text = arg.name
+ if is_int(text):
+ index = int(text)
+ segments.append(
+ exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
+ )
+ else:
+ segments.append(exp.JSONPathKey(this=text))
+
+ # This is done to avoid failing in the expression validator due to the arg count
+ del args[2:]
+ return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
+
+ return _parse_json_extract_path
+
+
+def json_extract_segments(
+ name: str, quoted_index: bool = True
+) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
+ def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
+ path = expression.expression
+ if not isinstance(path, exp.JSONPath):
+ return rename_func(name)(self, expression)
+
+ segments = []
+ for segment in path.expressions:
+ path = self.sql(segment)
+ if path:
+ if isinstance(segment, exp.JSONPathPart) and (
+ quoted_index or not isinstance(segment, exp.JSONPathSubscript)
+ ):
+ path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"
+
+ segments.append(path)
+
+ return self.func(name, expression.this, *segments)
+
+ return _json_extract_segments
+
+
+def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
+ if isinstance(expression.this, exp.JSONPathWildcard):
+ self.unsupported("Unsupported wildcard in JSONPathKey expression")
+
+ return expression.name