diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-02-08 05:38:42 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-02-08 05:38:42 +0000 |
commit | c66e4a33e1a07c439f03fe47f146a6c6482bf6df (patch) | |
tree | cfdf01111c063b3e50841695e6c2768833aea4dc /sqlglot/dialects/dialect.py | |
parent | Releasing debian version 20.11.0-1. (diff) | |
download | sqlglot-c66e4a33e1a07c439f03fe47f146a6c6482bf6df.tar.xz sqlglot-c66e4a33e1a07c439f03fe47f146a6c6482bf6df.zip |
Merging upstream version 21.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- | sqlglot/dialects/dialect.py | 132 |
1 files changed, 87 insertions, 45 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 6be991b..6e2d190 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import typing as t from enum import Enum, auto from functools import reduce @@ -7,7 +8,8 @@ from functools import reduce from sqlglot import exp from sqlglot.errors import ParseError from sqlglot.generator import Generator -from sqlglot.helper import AutoName, flatten, seq_get +from sqlglot.helper import AutoName, flatten, is_int, seq_get +from sqlglot.jsonpath import parse as parse_json_path from sqlglot.parser import Parser from sqlglot.time import TIMEZONES, format_time from sqlglot.tokens import Token, Tokenizer, TokenType @@ -17,7 +19,11 @@ DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsD DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] if t.TYPE_CHECKING: - from sqlglot._typing import B, E + from sqlglot._typing import B, E, F + + JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] + +logger = logging.getLogger("sqlglot") class Dialects(str, Enum): @@ -256,7 +262,7 @@ class Dialect(metaclass=_Dialect): INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} - # Delimiters for quotes, identifiers and the corresponding escape characters + # Delimiters for string literals and identifiers QUOTE_START = "'" QUOTE_END = "'" IDENTIFIER_START = '"' @@ -373,7 +379,7 @@ class Dialect(metaclass=_Dialect): """ if ( isinstance(expression, exp.Identifier) - and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE + and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE and ( not expression.quoted or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE @@ -440,6 +446,19 @@ class Dialect(metaclass=_Dialect): return expression + def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + if isinstance(path, exp.Literal): + path_text = path.name + if path.is_number: + path_text = f"[{path_text}]" + + try: + return parse_json_path(path_text) + except ParseError as e: + logger.warning(f"Invalid JSON path syntax. {str(e)}") + + return path + def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: return self.parser(**opts).parse(self.tokenize(sql), sql) @@ -500,14 +519,12 @@ def if_sql( return _if_sql -def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: - return self.binary(expression, "->") - +def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: + this = expression.this + if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: + this.replace(exp.cast(this, "json")) -def arrow_json_extract_scalar_sql( - self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar -) -> str: - return self.binary(expression, "->>") + return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") def inline_array_sql(self: Generator, expression: exp.Array) -> str: @@ -552,11 +569,6 @@ def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: return self.cast_sql(expression) -def no_properties_sql(self: Generator, expression: exp.Properties) -> str: - self.unsupported("Properties unsupported") - return "" - - def no_comment_column_constraint_sql( self: Generator, expression: exp.CommentColumnConstraint ) -> str: @@ -965,32 +977,6 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE return _delta_sql -def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath: - from sqlglot.optimizer.simplify import simplify - - # Makes sure the path will be evaluated correctly at runtime to include the path root. - # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`. - path = expression.expression - path = exp.func( - "if", - exp.func("startswith", path, "'['"), - exp.func("concat", "'$'", path), - exp.func("concat", "'$.'", path), - ) - - expression.expression.replace(simplify(path)) - return expression - - -def path_to_jsonpath( - name: str = "JSON_EXTRACT", -) -> t.Callable[[Generator, exp.GetPath], str]: - def _transform(self: Generator, expression: exp.GetPath) -> str: - return rename_func(name)(self, prepend_dollar_to_path(expression)) - - return _transform - - def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: trunc_curr_date = exp.func("date_trunc", "month", expression.this) plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") @@ -1003,9 +989,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: """Remove table refs from columns in when statements.""" alias = expression.this.args.get("alias") - normalize = lambda identifier: ( - self.dialect.normalize_identifier(identifier).name if identifier else None - ) + def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: + return self.dialect.normalize_identifier(identifier).name if identifier else None targets = {normalize(expression.this.this)} @@ -1023,3 +1008,60 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: ) return self.merge_sql(expression) + + +def parse_json_extract_path( + expr_type: t.Type[F], zero_based_indexing: bool = True +) -> t.Callable[[t.List], F]: + def _parse_json_extract_path(args: t.List) -> F: + segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] + for arg in args[1:]: + if not isinstance(arg, exp.Literal): + # We use the fallback parser because we can't really transpile non-literals safely + return expr_type.from_arg_list(args) + + text = arg.name + if is_int(text): + index = int(text) + segments.append( + exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) + ) + else: + segments.append(exp.JSONPathKey(this=text)) + + # This is done to avoid failing in the expression validator due to the arg count + del args[2:] + return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments)) + + return _parse_json_extract_path + + +def json_extract_segments( + name: str, quoted_index: bool = True +) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: + def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: + path = expression.expression + if not isinstance(path, exp.JSONPath): + return rename_func(name)(self, expression) + + segments = [] + for segment in path.expressions: + path = self.sql(segment) + if path: + if isinstance(segment, exp.JSONPathPart) and ( + quoted_index or not isinstance(segment, exp.JSONPathSubscript) + ): + path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" + + segments.append(path) + + return self.func(name, expression.this, *segments) + + return _json_extract_segments + + +def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: + if isinstance(expression.this, exp.JSONPathWildcard): + self.unsupported("Unsupported wildcard in JSONPathKey expression") + + return expression.name |