diff options
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- | sqlglot/dialects/dialect.py | 88 |
1 files changed, 56 insertions, 32 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 0440a99..b0a78d2 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -17,12 +17,12 @@ from sqlglot.trie import new_trie DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] +JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] + if t.TYPE_CHECKING: from sqlglot._typing import B, E, F - JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] - logger = logging.getLogger("sqlglot") @@ -148,47 +148,53 @@ class _Dialect(type): class Dialect(metaclass=_Dialect): INDEX_OFFSET = 0 - """Determines the base index offset for arrays.""" + """The base index offset for arrays.""" WEEK_OFFSET = 0 - """Determines the day of week of DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" + """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" UNNEST_COLUMN_ONLY = False - """Determines whether or not `UNNEST` table aliases are treated as column aliases.""" + """Whether `UNNEST` table aliases are treated as column aliases.""" ALIAS_POST_TABLESAMPLE = False - """Determines whether or not the table alias comes after tablesample.""" + """Whether the table alias comes after tablesample.""" TABLESAMPLE_SIZE_IS_PERCENT = False - """Determines whether or not a size in the table sample clause represents percentage.""" + """Whether a size in the table sample clause represents percentage.""" NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE """Specifies the strategy according to which identifiers should be normalized.""" IDENTIFIERS_CAN_START_WITH_DIGIT = False - """Determines whether or not an unquoted identifier can start with a digit.""" + """Whether an unquoted identifier can start with a digit.""" DPIPE_IS_STRING_CONCAT = True - """Determines whether or not the DPIPE token (`||`) is a string concatenation operator.""" + """Whether the DPIPE token (`||`) is a string concatenation operator.""" STRICT_STRING_CONCAT = False - """Determines whether or not `CONCAT`'s arguments must be strings.""" + """Whether `CONCAT`'s arguments must be strings.""" SUPPORTS_USER_DEFINED_TYPES = True - """Determines whether or not user-defined data types are supported.""" + """Whether user-defined data types are supported.""" SUPPORTS_SEMI_ANTI_JOIN = True - """Determines whether or not `SEMI` or `ANTI` joins are supported.""" + """Whether `SEMI` or `ANTI` joins are supported.""" NORMALIZE_FUNCTIONS: bool | str = "upper" - """Determines how function names are going to be normalized.""" + """ + Determines how function names are going to be normalized. + Possible values: + "upper" or True: Convert names to uppercase. + "lower": Convert names to lowercase. + False: Disables function name normalization. + """ LOG_BASE_FIRST = True - """Determines whether the base comes first in the `LOG` function.""" + """Whether the base comes first in the `LOG` function.""" NULL_ORDERING = "nulls_are_small" """ - Indicates the default `NULL` ordering method to use if not explicitly set. + Default `NULL` ordering method to use if not explicitly set. Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` """ @@ -200,7 +206,7 @@ class Dialect(metaclass=_Dialect): """ SAFE_DIVISION = False - """Determines whether division by zero throws an error (`False`) or returns NULL (`True`).""" + """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" CONCAT_COALESCE = False """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" @@ -210,7 +216,7 @@ class Dialect(metaclass=_Dialect): TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" TIME_MAPPING: t.Dict[str, str] = {} - """Associates this dialect's time formats with their equivalent Python `strftime` format.""" + """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE @@ -418,7 +424,7 @@ class Dialect(metaclass=_Dialect): `"safe"`: Only returns `True` if the identifier is case-insensitive. Returns: - Whether or not the given text can be identified. + Whether the given text can be identified. """ if identify is True or identify == "always": return True @@ -614,7 +620,7 @@ def var_map_sql( return self.func(map_func_name, *args) -def format_time_lambda( +def build_formatted_time( exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None ) -> t.Callable[[t.List], E]: """Helper used for time expressions. @@ -628,7 +634,7 @@ def format_time_lambda( A callable that can be used to return the appropriately formatted time expression. """ - def _format_time(args: t.List): + def _builder(args: t.List): return exp_class( this=seq_get(args, 0), format=Dialect[dialect].format_time( @@ -637,7 +643,7 @@ def format_time_lambda( ), ) - return _format_time + return _builder def time_format( @@ -654,23 +660,23 @@ def time_format( return _time_format -def parse_date_delta( +def build_date_delta( exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None ) -> t.Callable[[t.List], E]: - def inner_func(args: t.List) -> E: + def _builder(args: t.List) -> E: unit_based = len(args) == 3 this = args[2] if unit_based else seq_get(args, 0) unit = args[0] if unit_based else exp.Literal.string("DAY") unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit return exp_class(this=this, expression=seq_get(args, 1), unit=unit) - return inner_func + return _builder -def parse_date_delta_with_interval( +def build_date_delta_with_interval( expression_class: t.Type[E], ) -> t.Callable[[t.List], t.Optional[E]]: - def func(args: t.List) -> t.Optional[E]: + def _builder(args: t.List) -> t.Optional[E]: if len(args) < 2: return None @@ -687,7 +693,7 @@ def parse_date_delta_with_interval( this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit")) ) - return func + return _builder def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: @@ -888,7 +894,7 @@ def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: # Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects -def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: +def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) @@ -991,10 +997,10 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: return self.merge_sql(expression) -def parse_json_extract_path( +def build_json_extract_path( expr_type: t.Type[F], zero_based_indexing: bool = True ) -> t.Callable[[t.List], F]: - def _parse_json_extract_path(args: t.List) -> F: + def _builder(args: t.List) -> F: segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] for arg in args[1:]: if not isinstance(arg, exp.Literal): @@ -1014,11 +1020,11 @@ def parse_json_extract_path( del args[2:] return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments)) - return _parse_json_extract_path + return _builder def json_extract_segments( - name: str, quoted_index: bool = True + name: str, quoted_index: bool = True, op: t.Optional[str] = None ) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: path = expression.expression @@ -1036,6 +1042,8 @@ def json_extract_segments( segments.append(path) + if op: + return f" {op} ".join([self.sql(expression.this), *segments]) return self.func(name, expression.this, *segments) return _json_extract_segments @@ -1046,3 +1054,19 @@ def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str self.unsupported("Unsupported wildcard in JSONPathKey expression") return expression.name + + +def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: + cond = expression.expression + if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: + alias = cond.expressions[0] + cond = cond.this + elif isinstance(cond, exp.Predicate): + alias = "_u" + else: + self.unsupported("Unsupported filter condition") + return "" + + unnest = exp.Unnest(expressions=[expression.this]) + filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) + return self.sql(exp.Array(expressions=[filtered])) |