diff options
Diffstat (limited to 'sqlglot')
41 files changed, 1009 insertions, 434 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index d71c06d..2207a28 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -1,3 +1,4 @@ +# ruff: noqa: F401 """ .. include:: ../README.md @@ -87,11 +88,13 @@ def parse( @t.overload -def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ... +def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: + ... @t.overload -def parse_one(sql: str, **opts) -> Expression: ... +def parse_one(sql: str, **opts) -> Expression: + ... def parse_one( diff --git a/sqlglot/_typing.py b/sqlglot/_typing.py index 65f307e..0415aa4 100644 --- a/sqlglot/_typing.py +++ b/sqlglot/_typing.py @@ -13,4 +13,5 @@ if t.TYPE_CHECKING: A = t.TypeVar("A", bound=t.Any) B = t.TypeVar("B", bound="sqlglot.exp.Binary") E = t.TypeVar("E", bound="sqlglot.exp.Expression") +F = t.TypeVar("F", bound="sqlglot.exp.Func") T = t.TypeVar("T") diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 0bacbf9..7e3f07b 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -140,10 +140,12 @@ class DataFrame: return cte, name @t.overload - def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ... + def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: + ... @t.overload - def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ... + def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: + ... def _ensure_list_of_columns(self, cols): return Column.ensure_cols(ensure_list(cols)) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index a388cb4..29e7c55 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -368,7 +368,10 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls) + this = Column.invoke_expression_over_column(col, expression.First) + if ignorenulls: + return Column.invoke_expression_over_column(this, expression.IgnoreNulls) + return this def grouping_id(*cols: ColumnOrName) -> Column: @@ -392,7 +395,10 @@ def isnull(col: ColumnOrName) -> Column: def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls) + this = Column.invoke_expression_over_column(col, expression.Last) + if ignorenulls: + return Column.invoke_expression_over_column(this, expression.IgnoreNulls) + return this def monotonically_increasing_id() -> Column: @@ -485,31 +491,28 @@ def factorial(col: ColumnOrName) -> Column: def lag( col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None ) -> Column: - if default is not None: - return Column.invoke_anonymous_function(col, "LAG", offset, default) - if offset != 1: - return Column.invoke_anonymous_function(col, "LAG", offset) - return Column.invoke_anonymous_function(col, "LAG") + return Column.invoke_expression_over_column( + col, expression.Lag, offset=None if offset == 1 else offset, default=default + ) def lead( col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None ) -> Column: - if default is not None: - return Column.invoke_anonymous_function(col, "LEAD", offset, default) - if offset != 1: - return Column.invoke_anonymous_function(col, "LEAD", offset) - return Column.invoke_anonymous_function(col, "LEAD") + return Column.invoke_expression_over_column( + col, expression.Lead, offset=None if offset == 1 else offset, default=default + ) def nth_value( col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None ) -> Column: + this = Column.invoke_expression_over_column( + col, expression.NthValue, offset=None if offset == 1 else offset + ) if ignoreNulls is not None: - raise NotImplementedError("There is currently not support for `ignoreNulls` parameter") - if offset != 1: - return Column.invoke_anonymous_function(col, "NTH_VALUE", offset) - return Column.invoke_anonymous_function(col, "NTH_VALUE") + return Column.invoke_expression_over_column(this, expression.IgnoreNulls) + return this def ntile(n: int) -> Column: diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 04990ac..82552c9 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -1,9 +1,10 @@ +# ruff: noqa: F401 """ ## Dialects While there is a SQL standard, most SQL engines support a variation of that standard. This makes it difficult to write portable SQL code. SQLGlot bridges all the different variations, called "dialects", with an extensible -SQL transpilation framework. +SQL transpilation framework. The base `sqlglot.dialects.dialect.Dialect` class implements a generic dialect that aims to be as universal as possible. diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 771ae1a..9068235 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -19,7 +19,6 @@ from sqlglot.dialects.dialect import ( min_or_least, no_ilike_sql, parse_date_delta_with_interval, - path_to_jsonpath, regexp_replace_sql, rename_func, timestrtotime_sql, @@ -458,8 +457,10 @@ class BigQuery(Dialect): return this - def _parse_table_parts(self, schema: bool = False) -> exp.Table: - table = super()._parse_table_parts(schema=schema) + def _parse_table_parts( + self, schema: bool = False, is_db_reference: bool = False + ) -> exp.Table: + table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference) if isinstance(table.this, exp.Identifier) and "." in table.name: catalog, db, this, *rest = ( t.cast(t.Optional[exp.Expression], exp.to_identifier(x)) @@ -474,10 +475,12 @@ class BigQuery(Dialect): return table @t.overload - def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... + def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: + ... @t.overload - def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ... + def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: + ... def _parse_json_object(self, agg=False): json_object = super()._parse_json_object() @@ -536,6 +539,8 @@ class BigQuery(Dialect): UNPIVOT_ALIASES_ARE_IDENTIFIERS = False JSON_KEY_VALUE_PAIR_SEP = "," NULL_ORDERING_SUPPORTED = False + IGNORE_NULLS_IN_FUNC = True + JSON_PATH_SINGLE_QUOTE_ESCAPE = True TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -554,7 +559,8 @@ class BigQuery(Dialect): exp.Create: _create_sql, exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), exp.DateAdd: date_add_interval_sql("DATE", "ADD"), - exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", + exp.DateDiff: lambda self, + e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", exp.DateFromParts: rename_func("DATE"), exp.DateStrToDate: datestrtodate_sql, exp.DateSub: date_add_interval_sql("DATE", "SUB"), @@ -565,7 +571,6 @@ class BigQuery(Dialect): "DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'" ), exp.GenerateSeries: rename_func("GENERATE_ARRAY"), - exp.GetPath: path_to_jsonpath(), exp.GroupConcat: rename_func("STRING_AGG"), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql(false_value="NULL"), @@ -597,12 +602,13 @@ class BigQuery(Dialect): ] ), exp.SHA2: lambda self, e: self.func( - f"SHA256" if e.text("length") == "256" else "SHA512", e.this + "SHA256" if e.text("length") == "256" else "SHA512", e.this ), exp.StabilityProperty: lambda self, e: ( - f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC" + "DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC" ), - exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})", + exp.StrToDate: lambda self, + e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})", exp.StrToTime: lambda self, e: self.func( "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone") ), @@ -610,9 +616,10 @@ class BigQuery(Dialect): exp.TimeFromParts: rename_func("TIME"), exp.TimeSub: date_add_interval_sql("TIME", "SUB"), exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"), + exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"), exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, - exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression), + exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsDiff: _ts_or_ds_diff_sql, exp.TsOrDsToTime: rename_func("TIME"), @@ -623,6 +630,12 @@ class BigQuery(Dialect): exp.VariancePop: rename_func("VAR_POP"), } + SUPPORTED_JSON_PATH_PARTS = { + exp.JSONPathKey, + exp.JSONPathRoot, + exp.JSONPathSubscript, + } + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC", diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 1248edc..1ec15c5 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -8,12 +8,15 @@ from sqlglot.dialects.dialect import ( arg_max_or_min_no_count, date_delta_sql, inline_array_sql, + json_extract_segments, + json_path_key_only_name, no_pivot_sql, + parse_json_extract_path, rename_func, var_map_sql, ) from sqlglot.errors import ParseError -from sqlglot.helper import seq_get +from sqlglot.helper import is_int, seq_get from sqlglot.parser import parse_var_map from sqlglot.tokens import Token, TokenType @@ -120,6 +123,9 @@ class ClickHouse(Dialect): "DATEDIFF": lambda args: exp.DateDiff( this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), + "JSONEXTRACTSTRING": parse_json_extract_path( + exp.JSONExtractScalar, zero_based_indexing=False + ), "MAP": parse_var_map, "MATCH": exp.RegexpLike.from_arg_list, "RANDCANONICAL": exp.Rand.from_arg_list, @@ -354,9 +360,14 @@ class ClickHouse(Dialect): joins: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None, parse_bracket: bool = False, + is_db_reference: bool = False, ) -> t.Optional[exp.Expression]: this = super()._parse_table( - schema=schema, joins=joins, alias_tokens=alias_tokens, parse_bracket=parse_bracket + schema=schema, + joins=joins, + alias_tokens=alias_tokens, + parse_bracket=parse_bracket, + is_db_reference=is_db_reference, ) if self._match(TokenType.FINAL): @@ -518,6 +529,12 @@ class ClickHouse(Dialect): exp.DataType.Type.VARCHAR: "String", } + SUPPORTED_JSON_PATH_PARTS = { + exp.JSONPathKey, + exp.JSONPathRoot, + exp.JSONPathSubscript, + } + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, **STRING_TYPE_MAPPING, @@ -570,6 +587,10 @@ class ClickHouse(Dialect): exp.Explode: rename_func("arrayJoin"), exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.IsNan: rename_func("isNaN"), + exp.JSONExtract: json_extract_segments("JSONExtractString", quoted_index=False), + exp.JSONExtractScalar: json_extract_segments("JSONExtractString", quoted_index=False), + exp.JSONPathKey: json_path_key_only_name, + exp.JSONPathRoot: lambda *_: "", exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), exp.Nullif: rename_func("nullIf"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", @@ -579,7 +600,8 @@ class ClickHouse(Dialect): exp.Rand: rename_func("randCanonical"), exp.Select: transforms.preprocess([transforms.eliminate_qualify]), exp.StartsWith: rename_func("startsWith"), - exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", + exp.StrPosition: lambda self, + e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions), } @@ -608,6 +630,13 @@ class ClickHouse(Dialect): "NAMED COLLECTION", } + def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str: + this = self.json_path_part(expression.this) + return str(int(this) + 1) if is_int(this) else this + + def likeproperty_sql(self, expression: exp.LikeProperty) -> str: + return f"AS {self.sql(expression, 'this')}" + def _any_to_has( self, expression: exp.EQ | exp.NEQ, diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 8e55b6a..20907db 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -22,6 +22,7 @@ class Databricks(Spark): "DATEADD": parse_date_delta(exp.DateAdd), "DATE_ADD": parse_date_delta(exp.DateAdd), "DATEDIFF": parse_date_delta(exp.DateDiff), + "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff), } FACTOR = { @@ -48,6 +49,9 @@ class Databricks(Spark): exp.DatetimeDiff: lambda self, e: self.func( "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this ), + exp.TimestampDiff: lambda self, e: self.func( + "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this + ), exp.DatetimeTrunc: timestamptrunc_sql, exp.JSONExtract: lambda self, e: self.binary(e, ":"), exp.Select: transforms.preprocess( diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 6be991b..6e2d190 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import typing as t from enum import Enum, auto from functools import reduce @@ -7,7 +8,8 @@ from functools import reduce from sqlglot import exp from sqlglot.errors import ParseError from sqlglot.generator import Generator -from sqlglot.helper import AutoName, flatten, seq_get +from sqlglot.helper import AutoName, flatten, is_int, seq_get +from sqlglot.jsonpath import parse as parse_json_path from sqlglot.parser import Parser from sqlglot.time import TIMEZONES, format_time from sqlglot.tokens import Token, Tokenizer, TokenType @@ -17,7 +19,11 @@ DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsD DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] if t.TYPE_CHECKING: - from sqlglot._typing import B, E + from sqlglot._typing import B, E, F + + JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] + +logger = logging.getLogger("sqlglot") class Dialects(str, Enum): @@ -256,7 +262,7 @@ class Dialect(metaclass=_Dialect): INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} - # Delimiters for quotes, identifiers and the corresponding escape characters + # Delimiters for string literals and identifiers QUOTE_START = "'" QUOTE_END = "'" IDENTIFIER_START = '"' @@ -373,7 +379,7 @@ class Dialect(metaclass=_Dialect): """ if ( isinstance(expression, exp.Identifier) - and not self.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE + and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE and ( not expression.quoted or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE @@ -440,6 +446,19 @@ class Dialect(metaclass=_Dialect): return expression + def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + if isinstance(path, exp.Literal): + path_text = path.name + if path.is_number: + path_text = f"[{path_text}]" + + try: + return parse_json_path(path_text) + except ParseError as e: + logger.warning(f"Invalid JSON path syntax. {str(e)}") + + return path + def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: return self.parser(**opts).parse(self.tokenize(sql), sql) @@ -500,14 +519,12 @@ def if_sql( return _if_sql -def arrow_json_extract_sql(self: Generator, expression: exp.JSONExtract | exp.JSONBExtract) -> str: - return self.binary(expression, "->") - +def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: + this = expression.this + if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: + this.replace(exp.cast(this, "json")) -def arrow_json_extract_scalar_sql( - self: Generator, expression: exp.JSONExtractScalar | exp.JSONBExtractScalar -) -> str: - return self.binary(expression, "->>") + return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") def inline_array_sql(self: Generator, expression: exp.Array) -> str: @@ -552,11 +569,6 @@ def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: return self.cast_sql(expression) -def no_properties_sql(self: Generator, expression: exp.Properties) -> str: - self.unsupported("Properties unsupported") - return "" - - def no_comment_column_constraint_sql( self: Generator, expression: exp.CommentColumnConstraint ) -> str: @@ -965,32 +977,6 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE return _delta_sql -def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath: - from sqlglot.optimizer.simplify import simplify - - # Makes sure the path will be evaluated correctly at runtime to include the path root. - # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`. - path = expression.expression - path = exp.func( - "if", - exp.func("startswith", path, "'['"), - exp.func("concat", "'$'", path), - exp.func("concat", "'$.'", path), - ) - - expression.expression.replace(simplify(path)) - return expression - - -def path_to_jsonpath( - name: str = "JSON_EXTRACT", -) -> t.Callable[[Generator, exp.GetPath], str]: - def _transform(self: Generator, expression: exp.GetPath) -> str: - return rename_func(name)(self, prepend_dollar_to_path(expression)) - - return _transform - - def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: trunc_curr_date = exp.func("date_trunc", "month", expression.this) plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") @@ -1003,9 +989,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: """Remove table refs from columns in when statements.""" alias = expression.this.args.get("alias") - normalize = lambda identifier: ( - self.dialect.normalize_identifier(identifier).name if identifier else None - ) + def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: + return self.dialect.normalize_identifier(identifier).name if identifier else None targets = {normalize(expression.this.this)} @@ -1023,3 +1008,60 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: ) return self.merge_sql(expression) + + +def parse_json_extract_path( + expr_type: t.Type[F], zero_based_indexing: bool = True +) -> t.Callable[[t.List], F]: + def _parse_json_extract_path(args: t.List) -> F: + segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] + for arg in args[1:]: + if not isinstance(arg, exp.Literal): + # We use the fallback parser because we can't really transpile non-literals safely + return expr_type.from_arg_list(args) + + text = arg.name + if is_int(text): + index = int(text) + segments.append( + exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) + ) + else: + segments.append(exp.JSONPathKey(this=text)) + + # This is done to avoid failing in the expression validator due to the arg count + del args[2:] + return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments)) + + return _parse_json_extract_path + + +def json_extract_segments( + name: str, quoted_index: bool = True +) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: + def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: + path = expression.expression + if not isinstance(path, exp.JSONPath): + return rename_func(name)(self, expression) + + segments = [] + for segment in path.expressions: + path = self.sql(segment) + if path: + if isinstance(segment, exp.JSONPathPart) and ( + quoted_index or not isinstance(segment, exp.JSONPathSubscript) + ): + path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" + + segments.append(path) + + return self.func(name, expression.this, *segments) + + return _json_extract_segments + + +def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: + if isinstance(expression.this, exp.JSONPathWildcard): + self.unsupported("Unsupported wildcard in JSONPathKey expression") + + return expression.name diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py index 6e229b3..7a18e8e 100644 --- a/sqlglot/dialects/doris.py +++ b/sqlglot/dialects/doris.py @@ -55,11 +55,14 @@ class Doris(MySQL): exp.Map: rename_func("ARRAY_MAP"), exp.RegexpLike: rename_func("REGEXP"), exp.RegexpSplit: rename_func("SPLIT_BY_STRING"), - exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToUnix: lambda self, + e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Split: rename_func("SPLIT_BY_STRING"), exp.TimeStrToDate: rename_func("TO_DATE"), - exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", - exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level + exp.ToChar: lambda self, + e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TsOrDsAdd: lambda self, + e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this), exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimestampTrunc: lambda self, e: self.func( diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 6bca9e7..be23355 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -99,6 +99,7 @@ class Drill(Dialect): QUERY_HINTS = False NVL2_SUPPORTED = False LAST_DAY_SUPPORTS_DATE_PART = False + SUPPORTS_CREATE_TABLE_LIKE = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -128,10 +129,14 @@ class Drill(Dialect): exp.DateAdd: _date_add_sql("ADD"), exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _date_add_sql("SUB"), - exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)", - exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})", - exp.If: lambda self, e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})", - exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", + exp.DateToDi: lambda self, + e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)", + exp.DiToDate: lambda self, + e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})", + exp.If: lambda self, + e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})", + exp.ILike: lambda self, + e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.RegexpLike: rename_func("REGEXP_MATCHES"), @@ -141,7 +146,8 @@ class Drill(Dialect): exp.Select: transforms.preprocess( [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] ), - exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToTime: lambda self, + e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), @@ -149,8 +155,10 @@ class Drill(Dialect): exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.TryCast: no_trycast_sql, - exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})", - exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", + exp.TsOrDsAdd: lambda self, + e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})", + exp.TsOrDiToDi: lambda self, + e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", } def normalize_func(self, name: str) -> str: diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index f55ad70..d7ba729 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -8,7 +8,6 @@ from sqlglot.dialects.dialect import ( NormalizationStrategy, approx_count_distinct_sql, arg_max_or_min_no_count, - arrow_json_extract_scalar_sql, arrow_json_extract_sql, binary_from_function, bool_xor_sql, @@ -18,11 +17,9 @@ from sqlglot.dialects.dialect import ( format_time_lambda, inline_array_sql, no_comment_column_constraint_sql, - no_properties_sql, no_safe_divide_sql, no_timestamp_sql, pivot_column_names, - prepend_dollar_to_path, regexp_extract_sql, rename_func, str_position_sql, @@ -172,6 +169,18 @@ class DuckDB(Dialect): # https://duckdb.org/docs/sql/introduction.html#creating-a-new-table NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE + def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + if isinstance(path, exp.Literal): + # DuckDB also supports the JSON pointer syntax, where every path starts with a `/`. + # Additionally, it allows accessing the back of lists using the `[#-i]` syntax. + # This check ensures we'll avoid trying to parse these as JSON paths, which can + # either result in a noisy warning or in an invalid representation of the path. + path_text = path.name + if path_text.startswith("/") or "[#" in path_text: + return path + + return super().to_json_path(path) + class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -229,6 +238,8 @@ class DuckDB(Dialect): this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS ), "JSON": exp.ParseJSON.from_arg_list, + "JSON_EXTRACT_PATH": parser.parse_extract_json_with_path(exp.JSONExtract), + "JSON_EXTRACT_STRING": parser.parse_extract_json_with_path(exp.JSONExtractScalar), "LIST_HAS": exp.ArrayContains.from_arg_list, "LIST_REVERSE_SORT": _sort_array_reverse, "LIST_SORT": exp.SortArray.from_arg_list, @@ -319,6 +330,9 @@ class DuckDB(Dialect): TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" LAST_DAY_SUPPORTS_DATE_PART = False JSON_KEY_VALUE_PAIR_SEP = "," + IGNORE_NULLS_IN_FUNC = True + JSON_PATH_BRACKETED_KEY_SUPPORTED = False + SUPPORTS_CREATE_TABLE_LIKE = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -350,18 +364,18 @@ class DuckDB(Dialect): "DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this ), exp.DateStrToDate: datestrtodate_sql, - exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)", + exp.DateToDi: lambda self, + e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)", exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False), - exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)", + exp.DiToDate: lambda self, + e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)", exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False), exp.Explode: rename_func("UNNEST"), exp.IntDiv: lambda self, e: self.binary(e, "//"), exp.IsInf: rename_func("ISINF"), exp.IsNan: rename_func("ISNAN"), - exp.JSONBExtract: arrow_json_extract_sql, - exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.JSONExtract: arrow_json_extract_sql, - exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONFormat: _json_format_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), @@ -377,7 +391,6 @@ class DuckDB(Dialect): # DuckDB doesn't allow qualified columns inside of PIVOT expressions. # See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62 exp.Pivot: transforms.preprocess([transforms.unqualify_columns]), - exp.Properties: no_properties_sql, exp.RegexpExtract: regexp_extract_sql, exp.RegexpReplace: lambda self, e: self.func( "REGEXP_REPLACE", @@ -395,7 +408,8 @@ class DuckDB(Dialect): exp.StrPosition: str_position_sql, exp.StrToDate: lambda self, e: f"CAST({str_to_time_sql(self, e)} AS DATE)", exp.StrToTime: str_to_time_sql, - exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.StrToUnix: lambda self, + e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", exp.Struct: _struct_sql, exp.Timestamp: no_timestamp_sql, exp.TimestampDiff: lambda self, e: self.func( @@ -405,9 +419,11 @@ class DuckDB(Dialect): exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", - exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToStr: lambda self, + e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("EPOCH"), - exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", + exp.TsOrDiToDi: lambda self, + e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsDiff: lambda self, e: self.func( "DATE_DIFF", @@ -415,7 +431,8 @@ class DuckDB(Dialect): exp.cast(e.expression, "TIMESTAMP"), exp.cast(e.this, "TIMESTAMP"), ), - exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", + exp.UnixToStr: lambda self, + e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: _unix_to_time_sql, exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", exp.VariancePop: rename_func("VAR_POP"), @@ -423,6 +440,13 @@ class DuckDB(Dialect): exp.Xor: bool_xor_sql, } + SUPPORTED_JSON_PATH_PARTS = { + exp.JSONPathKey, + exp.JSONPathRoot, + exp.JSONPathSubscript, + exp.JSONPathWildcard, + } + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BINARY: "BLOB", @@ -442,11 +466,18 @@ class DuckDB(Dialect): UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren) + # DuckDB doesn't generally support CREATE TABLE .. properties + # https://duckdb.org/docs/sql/statements/create_table.html PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + prop: exp.Properties.Location.UNSUPPORTED + for prop in generator.Generator.PROPERTIES_LOCATION } + # There are a few exceptions (e.g. temporary tables) which are supported or + # can be transpiled to DuckDB, so we explicitly override them accordingly + PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA + PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE + def timefromparts_sql(self, expression: exp.TimeFromParts) -> str: nano = expression.args.get("nano") if nano is not None: @@ -486,10 +517,6 @@ class DuckDB(Dialect): expression, sep=sep, tablesample_keyword=tablesample_keyword ) - def getpath_sql(self, expression: exp.GetPath) -> str: - expression = prepend_dollar_to_path(expression) - return f"{self.sql(expression, 'this')} -> {self.sql(expression, 'expression')}" - def interval_sql(self, expression: exp.Interval) -> str: multiplier: t.Optional[int] = None unit = expression.text("unit").lower() diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 060f9bd..6337ffd 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -192,6 +192,18 @@ def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str: return f"TO_DATE({this})" +def _parse_ignore_nulls( + exp_class: t.Type[exp.Expression], +) -> t.Callable[[t.List[exp.Expression]], exp.Expression]: + def _parse(args: t.List[exp.Expression]) -> exp.Expression: + this = exp_class(this=seq_get(args, 0)) + if seq_get(args, 1) == exp.true(): + return exp.IgnoreNulls(this=this) + return this + + return _parse + + class Hive(Dialect): ALIAS_POST_TABLESAMPLE = True IDENTIFIERS_CAN_START_WITH_DIGIT = True @@ -298,8 +310,12 @@ class Hive(Dialect): expression=exp.TsOrDsToDate(this=seq_get(args, 1)), ), "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "FIRST": _parse_ignore_nulls(exp.First), + "FIRST_VALUE": _parse_ignore_nulls(exp.FirstValue), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, + "LAST": _parse_ignore_nulls(exp.Last), + "LAST_VALUE": _parse_ignore_nulls(exp.LastValue), "LOCATE": locate_to_strposition, "MAP": parse_var_map, "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), @@ -429,6 +445,7 @@ class Hive(Dialect): EXTRACT_ALLOWS_QUOTES = False NVL2_SUPPORTED = False LAST_DAY_SUPPORTS_DATE_PART = False + JSON_PATH_SINGLE_QUOTE_ESCAPE = True EXPRESSIONS_WITHOUT_NESTED_CTES = { exp.Insert, @@ -437,6 +454,13 @@ class Hive(Dialect): exp.Union, } + SUPPORTED_JSON_PATH_PARTS = { + exp.JSONPathKey, + exp.JSONPathRoot, + exp.JSONPathSubscript, + exp.JSONPathWildcard, + } + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BIT: "BOOLEAN", @@ -471,9 +495,12 @@ class Hive(Dialect): exp.DateDiff: _date_diff_sql, exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _add_date_sql, - exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)", - exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})", - exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", + exp.DateToDi: lambda self, + e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)", + exp.DiToDate: lambda self, + e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.DATEINT_FORMAT})", + exp.FileFormatProperty: lambda self, + e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", exp.FromBase64: rename_func("UNBASE64"), exp.If: if_sql(), exp.ILike: no_ilike_sql, @@ -502,7 +529,8 @@ class Hive(Dialect): exp.SafeDivide: no_safe_divide_sql, exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), - exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", + exp.Split: lambda self, + e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_time_sql, @@ -514,7 +542,8 @@ class Hive(Dialect): exp.TimeToStr: _time_to_str, exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), exp.ToBase64: rename_func("BASE64"), - exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)", + exp.TsOrDiToDi: lambda self, + e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _add_date_sql, exp.TsOrDsDiff: _date_diff_sql, exp.TsOrDsToDate: _to_date_sql, @@ -528,8 +557,10 @@ class Hive(Dialect): exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), exp.National: lambda self, e: self.national_sql(e, prefix=""), - exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})", - exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})", + exp.ClusteredColumnConstraint: lambda self, + e: f"({self.expressions(e, 'this', indent=False)})", + exp.NonClusteredColumnConstraint: lambda self, + e: f"({self.expressions(e, 'this', indent=False)})", exp.NotForReplicationColumnConstraint: lambda self, e: "", exp.OnProperty: lambda self, e: "", exp.PrimaryKeyColumnConstraint: lambda self, e: "PRIMARY KEY", @@ -543,6 +574,13 @@ class Hive(Dialect): exp.WithDataProperty: exp.Properties.Location.UNSUPPORTED, } + def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str: + if isinstance(expression.this, exp.JSONPathWildcard): + self.unsupported("Unsupported wildcard in JSONPathKey expression") + return "" + + return super()._jsonpathkey_sql(expression) + def temporary_storage_provider(self, expression: exp.Create) -> exp.Create: # Hive has no temporary storage provider (there are hive settings though) return expression diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 21a9657..661ef7d 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -6,7 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, NormalizationStrategy, - arrow_json_extract_scalar_sql, + arrow_json_extract_sql, date_add_interval_sql, datestrtodate_sql, format_time_lambda, @@ -19,8 +19,8 @@ from sqlglot.dialects.dialect import ( no_pivot_sql, no_tablesample_sql, no_trycast_sql, + parse_date_delta, parse_date_delta_with_interval, - path_to_jsonpath, rename_func, strposition_to_locate_sql, ) @@ -306,6 +306,7 @@ class MySQL(Dialect): format=exp.Literal.string("%B"), ), "STR_TO_DATE": _str_to_date, + "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff), "TO_DAYS": lambda args: exp.paren( exp.DateDiff( this=exp.TsOrDsToDate(this=seq_get(args, 0)), @@ -357,6 +358,7 @@ class MySQL(Dialect): "CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True), "CREATE VIEW": _show_parser("CREATE VIEW", target=True), "DATABASES": _show_parser("DATABASES"), + "SCHEMAS": _show_parser("DATABASES"), "ENGINE": _show_parser("ENGINE", target=True), "STORAGE ENGINES": _show_parser("ENGINES"), "ENGINES": _show_parser("ENGINES"), @@ -630,6 +632,8 @@ class MySQL(Dialect): VALUES_AS_TABLE = False NVL2_SUPPORTED = False LAST_DAY_SUPPORTS_DATE_PART = False + JSON_TYPE_REQUIRED_FOR_EXTRACTION = True + JSON_PATH_BRACKETED_KEY_SUPPORTED = False JSON_KEY_VALUE_PAIR_SEP = "," TRANSFORMS = { @@ -646,10 +650,10 @@ class MySQL(Dialect): exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")), exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")), exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")), - exp.GetPath: path_to_jsonpath(), - exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", + exp.GroupConcat: lambda self, + e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.ILike: no_ilike_sql, - exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONExtractScalar: arrow_json_extract_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, exp.Month: _remove_ts_or_ds_to_date(), @@ -672,6 +676,9 @@ class MySQL(Dialect): exp.TableSample: no_tablesample_sql, exp.TimeFromParts: rename_func("MAKETIME"), exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"), + exp.TimestampDiff: lambda self, e: self.func( + "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this + ), exp.TimestampSub: date_add_interval_sql("DATE", "SUB"), exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)), diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 4591d59..0c0d750 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -199,7 +199,8 @@ class Oracle(Dialect): transforms.eliminate_qualify, ] ), - exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToTime: lambda self, + e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), exp.Substring: rename_func("SUBSTR"), @@ -208,7 +209,8 @@ class Oracle(Dialect): exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, - exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", + exp.UnixToTime: lambda self, + e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", } PROPERTIES_LOCATION = { diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 87f6b02..0404c78 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -7,11 +7,11 @@ from sqlglot.dialects.dialect import ( DATE_ADD_OR_SUB, Dialect, any_value_to_max_sql, - arrow_json_extract_scalar_sql, - arrow_json_extract_sql, bool_xor_sql, datestrtodate_sql, format_time_lambda, + json_extract_segments, + json_path_key_only_name, max_or_greatest, merge_without_target_sql, min_or_least, @@ -20,6 +20,7 @@ from sqlglot.dialects.dialect import ( no_paren_current_date_sql, no_pivot_sql, no_trycast_sql, + parse_json_extract_path, parse_timestamp_trunc, rename_func, str_position_sql, @@ -292,6 +293,8 @@ class Postgres(Dialect): **parser.Parser.FUNCTIONS, "DATE_TRUNC": parse_timestamp_trunc, "GENERATE_SERIES": _generate_series, + "JSON_EXTRACT_PATH": parse_json_extract_path(exp.JSONExtract), + "JSON_EXTRACT_PATH_TEXT": parse_json_extract_path(exp.JSONExtractScalar), "MAKE_TIME": exp.TimeFromParts.from_arg_list, "MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list, "NOW": exp.CurrentTimestamp.from_arg_list, @@ -375,8 +378,15 @@ class Postgres(Dialect): TABLESAMPLE_SIZE_IS_ROWS = False TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" SUPPORTS_SELECT_INTO = True - # https://www.postgresql.org/docs/current/sql-createtable.html + JSON_TYPE_REQUIRED_FOR_EXTRACTION = True SUPPORTS_UNLOGGED_TABLES = True + LIKE_PROPERTY_INSIDE_SCHEMA = True + + SUPPORTED_JSON_PATH_PARTS = { + exp.JSONPathKey, + exp.JSONPathRoot, + exp.JSONPathSubscript, + } TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -412,11 +422,14 @@ class Postgres(Dialect): exp.DateSub: _date_add_sql("-"), exp.Explode: rename_func("UNNEST"), exp.GroupConcat: _string_agg_sql, - exp.JSONExtract: arrow_json_extract_sql, - exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH"), + exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"), exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"), exp.JSONBContains: lambda self, e: self.binary(e, "?"), + exp.JSONPathKey: json_path_key_only_name, + exp.JSONPathRoot: lambda *_: "", + exp.JSONPathSubscript: lambda self, e: self.json_path_part(e.this), exp.LastDay: no_last_day_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), @@ -443,7 +456,8 @@ class Postgres(Dialect): ] ), exp.StrPosition: str_position_sql, - exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToTime: lambda self, + e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.StructExtract: struct_extract_sql, exp.Substring: _substring_sql, exp.TimeFromParts: rename_func("MAKE_TIME"), diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 6cc6030..8691192 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -18,7 +18,6 @@ from sqlglot.dialects.dialect import ( no_pivot_sql, no_safe_divide_sql, no_timestamp_sql, - path_to_jsonpath, regexp_extract_sql, rename_func, right_to_substring_sql, @@ -150,7 +149,7 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression: return expression -def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> str: +def _first_last_sql(self: Presto.Generator, expression: exp.Func) -> str: """ Trino doesn't support FIRST / LAST as functions, but they're valid in the context of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases @@ -292,6 +291,7 @@ class Presto(Dialect): STRUCT_DELIMITER = ("(", ")") LIMIT_ONLY_LITERALS = True SUPPORTS_SINGLE_ARG_CONCAT = False + LIKE_PROPERTY_INSIDE_SCHEMA = True PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, @@ -324,12 +324,18 @@ class Presto(Dialect): exp.ArrayContains: rename_func("CONTAINS"), exp.ArraySize: rename_func("CARDINALITY"), exp.ArrayUniqueAgg: rename_func("SET_AGG"), - exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.AtTimeZone: rename_func("AT_TIMEZONE"), + exp.BitwiseAnd: lambda self, + e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseLeftShift: lambda self, + e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})", - exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseOr: lambda self, + e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseRightShift: lambda self, + e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseXor: lambda self, + e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]), exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: lambda self, e: self.func( @@ -344,7 +350,8 @@ class Presto(Dialect): "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this ), exp.DateStrToDate: datestrtodate_sql, - exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", + exp.DateToDi: lambda self, + e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", exp.DateSub: lambda self, e: self.func( "DATE_ADD", exp.Literal.string(e.text("unit") or "DAY"), @@ -352,12 +359,14 @@ class Presto(Dialect): e.this, ), exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"), - exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)", + exp.DiToDate: lambda self, + e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)", exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.First: _first_last_sql, - exp.FromTimeZone: lambda self, e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'", - exp.GetPath: path_to_jsonpath(), + exp.FirstValue: _first_last_sql, + exp.FromTimeZone: lambda self, + e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'", exp.Group: transforms.preprocess([transforms.unalias_group]), exp.GroupConcat: lambda self, e: self.func( "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator") @@ -368,6 +377,7 @@ class Presto(Dialect): exp.Initcap: _initcap_sql, exp.ParseJSON: rename_func("JSON_PARSE"), exp.Last: _first_last_sql, + exp.LastValue: _first_last_sql, exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this), exp.Lateral: _explode_to_unnest_sql, exp.Left: left_to_substring_sql, @@ -394,26 +404,33 @@ class Presto(Dialect): exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", exp.StrToMap: rename_func("SPLIT_TO_MAP"), exp.StrToTime: _str_to_time_sql, - exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.StrToUnix: lambda self, + e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", exp.StructExtract: struct_extract_sql, exp.Table: transforms.preprocess([_unnest_sequence]), exp.Timestamp: no_timestamp_sql, exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: timestrtotime_sql, exp.TimeStrToTime: timestrtotime_sql, - exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))", - exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeStrToUnix: lambda self, + e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.TIME_FORMAT}))", + exp.TimeToStr: lambda self, + e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("TO_UNIXTIME"), - exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", + exp.ToChar: lambda self, + e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]), - exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", + exp.TsOrDiToDi: lambda self, + e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsDiff: _ts_or_ds_diff_sql, exp.TsOrDsToDate: _ts_or_ds_to_date_sql, exp.Unhex: rename_func("FROM_HEX"), - exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", + exp.UnixToStr: lambda self, + e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: _unix_to_time_sql, - exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", + exp.UnixToTimeStr: lambda self, + e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", exp.VariancePop: rename_func("VAR_POP"), exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]), exp.WithinGroup: transforms.preprocess( diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 7194d81..a64c1d4 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import ( concat_ws_to_dpipe_sql, date_delta_sql, generatedasidentitycolumnconstraint_sql, + json_extract_segments, no_tablesample_sql, rename_func, ) @@ -20,10 +21,6 @@ if t.TYPE_CHECKING: from sqlglot._typing import E -def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str: - return f'{self.sql(expression, "this")}."{expression.expression.name}"' - - def _parse_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]: def _parse_delta(args: t.List) -> E: expr = expr_type(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) @@ -62,6 +59,7 @@ class Redshift(Postgres): "DATE_ADD": _parse_date_delta(exp.TsOrDsAdd), "DATEDIFF": _parse_date_delta(exp.TsOrDsDiff), "DATE_DIFF": _parse_date_delta(exp.TsOrDsDiff), + "GETDATE": exp.CurrentTimestamp.from_arg_list, "LISTAGG": exp.GroupConcat.from_arg_list, "STRTOL": exp.FromBase.from_arg_list, } @@ -69,6 +67,7 @@ class Redshift(Postgres): NO_PAREN_FUNCTION_PARSERS = { **Postgres.Parser.NO_PAREN_FUNCTION_PARSERS, "APPROXIMATE": lambda self: self._parse_approximate_count(), + "SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, transaction=True), } def _parse_table( @@ -77,6 +76,7 @@ class Redshift(Postgres): joins: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None, parse_bracket: bool = False, + is_db_reference: bool = False, ) -> t.Optional[exp.Expression]: # Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr` unpivot = self._match(TokenType.UNPIVOT) @@ -85,6 +85,7 @@ class Redshift(Postgres): joins=joins, alias_tokens=alias_tokens, parse_bracket=parse_bracket, + is_db_reference=is_db_reference, ) return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table @@ -153,7 +154,6 @@ class Redshift(Postgres): **Postgres.Tokenizer.KEYWORDS, "HLLSKETCH": TokenType.HLLSKETCH, "SUPER": TokenType.SUPER, - "SYSDATE": TokenType.CURRENT_TIMESTAMP, "TOP": TokenType.TOP, "UNLOAD": TokenType.COMMAND, "VARBYTE": TokenType.VARBINARY, @@ -180,31 +180,29 @@ class Redshift(Postgres): exp.DataType.Type.VARBINARY: "VARBYTE", } - PROPERTIES_LOCATION = { - **Postgres.Generator.PROPERTIES_LOCATION, - exp.LikeProperty: exp.Properties.Location.POST_WITH, - } - TRANSFORMS = { **Postgres.Generator.TRANSFORMS, exp.Concat: concat_to_dpipe_sql, exp.ConcatWs: concat_ws_to_dpipe_sql, - exp.ApproxDistinct: lambda self, e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})", - exp.CurrentTimestamp: lambda self, e: "SYSDATE", + exp.ApproxDistinct: lambda self, + e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})", + exp.CurrentTimestamp: lambda self, e: ( + "SYSDATE" if e.args.get("transaction") else "GETDATE()" + ), exp.DateAdd: date_delta_sql("DATEADD"), exp.DateDiff: date_delta_sql("DATEDIFF"), exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), exp.FromBase: rename_func("STRTOL"), exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, - exp.JSONExtract: _json_sql, - exp.JSONExtractScalar: _json_sql, + exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH_TEXT"), exp.GroupConcat: rename_func("LISTAGG"), exp.ParseJSON: rename_func("JSON_PARSE"), exp.Select: transforms.preprocess( [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] ), - exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", + exp.SortKeyProperty: lambda self, + e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", exp.TableSample: no_tablesample_sql, exp.TsOrDsAdd: date_delta_sql("DATEADD"), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), @@ -228,6 +226,13 @@ class Redshift(Postgres): """Redshift doesn't have `WITH` as part of their with_properties so we remove it""" return self.properties(properties, prefix=" ", suffix="") + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: + if expression.is_type(exp.DataType.Type.JSON): + # Redshift doesn't support a JSON type, so casting to it is treated as a noop + return self.sql(expression, "this") + + return super().cast_sql(expression, safe_prefix=safe_prefix) + def datatype_sql(self, expression: exp.DataType) -> str: """ Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 281167d..37f9761 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -21,19 +21,13 @@ from sqlglot.dialects.dialect import ( var_map_sql, ) from sqlglot.expressions import Literal -from sqlglot.helper import seq_get +from sqlglot.helper import is_int, seq_get from sqlglot.tokens import TokenType if t.TYPE_CHECKING: from sqlglot._typing import E -def _check_int(s: str) -> bool: - if s[0] in ("-", "+"): - return s[1:].isdigit() - return s.isdigit() - - # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]: if len(args) == 2: @@ -53,7 +47,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, return exp.TimeStrToTime.from_arg_list(args) if first_arg.is_string: - if _check_int(first_arg.this): + if is_int(first_arg.this): # case: <integer> return exp.UnixToTime.from_arg_list(args) @@ -241,7 +235,6 @@ DATE_PART_MAPPING = { "NSECOND": "NANOSECOND", "NSECONDS": "NANOSECOND", "NANOSECS": "NANOSECOND", - "NSECONDS": "NANOSECOND", "EPOCH": "EPOCH_SECOND", "EPOCH_SECONDS": "EPOCH_SECOND", "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", @@ -291,7 +284,9 @@ def _parse_colon_get_path( path = exp.Literal.string(path.sql(dialect="snowflake")) # The extraction operator : is left-associative - this = self.expression(exp.GetPath, this=this, expression=path) + this = self.expression( + exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path) + ) if target_type: this = exp.cast(this, target_type) @@ -411,6 +406,9 @@ class Snowflake(Dialect): "DATEDIFF": _parse_datediff, "DIV0": _div0_to_if, "FLATTEN": exp.Explode.from_arg_list, + "GET_PATH": lambda args, dialect: exp.JSONExtract( + this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) + ), "IFF": exp.If.from_arg_list, "LAST_DAY": lambda args: exp.LastDay( this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1)) @@ -474,6 +472,8 @@ class Snowflake(Dialect): "TERSE SCHEMAS": _show_parser("SCHEMAS"), "OBJECTS": _show_parser("OBJECTS"), "TERSE OBJECTS": _show_parser("OBJECTS"), + "TABLES": _show_parser("TABLES"), + "TERSE TABLES": _show_parser("TABLES"), "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "COLUMNS": _show_parser("COLUMNS"), @@ -534,7 +534,9 @@ class Snowflake(Dialect): return table - def _parse_table_parts(self, schema: bool = False) -> exp.Table: + def _parse_table_parts( + self, schema: bool = False, is_db_reference: bool = False + ) -> exp.Table: # https://docs.snowflake.com/en/user-guide/querying-stage if self._match(TokenType.STRING, advance=False): table = self._parse_string() @@ -550,7 +552,9 @@ class Snowflake(Dialect): self._match(TokenType.L_PAREN) while self._curr and not self._match(TokenType.R_PAREN): if self._match_text_seq("FILE_FORMAT", "=>"): - file_format = self._parse_string() or super()._parse_table_parts() + file_format = self._parse_string() or super()._parse_table_parts( + is_db_reference=is_db_reference + ) elif self._match_text_seq("PATTERN", "=>"): pattern = self._parse_string() else: @@ -560,7 +564,7 @@ class Snowflake(Dialect): table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern) else: - table = super()._parse_table_parts(schema=schema) + table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference) return self._parse_at_before(table) @@ -587,6 +591,8 @@ class Snowflake(Dialect): # which is syntactically valid but has no effect on the output terse = self._tokens[self._index - 2].text.upper() == "TERSE" + history = self._match_text_seq("HISTORY") + like = self._parse_string() if self._match(TokenType.LIKE) else None if self._match(TokenType.IN): @@ -597,7 +603,7 @@ class Snowflake(Dialect): if self._curr: scope = self._parse_table_parts() elif self._curr: - scope_kind = "SCHEMA" if this == "OBJECTS" else "TABLE" + scope_kind = "SCHEMA" if this in ("OBJECTS", "TABLES") else "TABLE" scope = self._parse_table_parts() return self.expression( @@ -605,6 +611,7 @@ class Snowflake(Dialect): **{ "terse": terse, "this": this, + "history": history, "like": like, "scope": scope, "scope_kind": scope_kind, @@ -715,8 +722,10 @@ class Snowflake(Dialect): ), exp.GroupConcat: rename_func("LISTAGG"), exp.If: if_sql(name="IFF", false_value="NULL"), - exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]", + exp.JSONExtract: rename_func("GET_PATH"), + exp.JSONExtractScalar: rename_func("JSON_EXTRACT_PATH_TEXT"), exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions), + exp.JSONPathRoot: lambda *_: "", exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), @@ -745,7 +754,8 @@ class Snowflake(Dialect): exp.StrPosition: lambda self, e: self.func( "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), - exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToTime: lambda self, + e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Struct: lambda self, e: self.func( "OBJECT_CONSTRUCT", *(arg for expression in e.expressions for arg in expression.flatten()), @@ -771,6 +781,12 @@ class Snowflake(Dialect): exp.Xor: rename_func("BOOLXOR"), } + SUPPORTED_JSON_PATH_PARTS = { + exp.JSONPathKey, + exp.JSONPathRoot, + exp.JSONPathSubscript, + } + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", @@ -841,6 +857,7 @@ class Snowflake(Dialect): def show_sql(self, expression: exp.Show) -> str: terse = "TERSE " if expression.args.get("terse") else "" + history = " HISTORY" if expression.args.get("history") else "" like = self.sql(expression, "like") like = f" LIKE {like}" if like else "" @@ -861,9 +878,7 @@ class Snowflake(Dialect): if from_: from_ = f" FROM {from_}" - return ( - f"SHOW {terse}{expression.name}{like}{scope_kind}{scope}{starts_with}{limit}{from_}" - ) + return f"SHOW {terse}{expression.name}{history}{like}{scope_kind}{scope}{starts_with}{limit}{from_}" def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: # Other dialects don't support all of the following parameters, so we need to diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 624f76e..4c5c131 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -4,6 +4,7 @@ import typing as t from sqlglot import exp from sqlglot.dialects.dialect import rename_func +from sqlglot.dialects.hive import _parse_ignore_nulls from sqlglot.dialects.spark2 import Spark2 from sqlglot.helper import seq_get @@ -45,9 +46,7 @@ class Spark(Spark2): class Parser(Spark2.Parser): FUNCTIONS = { **Spark2.Parser.FUNCTIONS, - "ANY_VALUE": lambda args: exp.AnyValue( - this=seq_get(args, 0), ignore_nulls=seq_get(args, 1) - ), + "ANY_VALUE": _parse_ignore_nulls(exp.AnyValue), "DATEDIFF": _parse_datediff, } diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index e4bb30e..9378d99 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -187,8 +187,10 @@ class Spark2(Hive): TRANSFORMS = { **Hive.Generator.TRANSFORMS, exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), - exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", - exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", + exp.ArraySum: lambda self, + e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + exp.AtTimeZone: lambda self, + e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), exp.DateFromParts: rename_func("MAKE_DATE"), @@ -198,7 +200,8 @@ class Spark2(Hive): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.From: transforms.preprocess([_unalias_pivot]), - exp.FromTimeZone: lambda self, e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", + exp.FromTimeZone: lambda self, + e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, @@ -212,7 +215,8 @@ class Spark2(Hive): e.args.get("position"), ), exp.StrToDate: _str_to_date, - exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToTime: lambda self, + e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimestampTrunc: lambda self, e: self.func( "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this ), diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 244a96e..b292c81 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -7,7 +7,6 @@ from sqlglot.dialects.dialect import ( Dialect, NormalizationStrategy, any_value_to_max_sql, - arrow_json_extract_scalar_sql, arrow_json_extract_sql, concat_to_dpipe_sql, count_if_to_sum, @@ -28,6 +27,12 @@ def _date_add_sql(self: SQLite.Generator, expression: exp.DateAdd) -> str: return self.func("DATE", expression.this, modifier) +def _json_extract_sql(self: SQLite.Generator, expression: exp.JSONExtract) -> str: + if expression.expressions: + return self.function_fallback_sql(expression) + return arrow_json_extract_sql(self, expression) + + def _transform_create(expression: exp.Expression) -> exp.Expression: """Move primary key to a column and enforce auto_increment on primary keys.""" schema = expression.this @@ -85,6 +90,14 @@ class SQLite(Dialect): TABLE_HINTS = False QUERY_HINTS = False NVL2_SUPPORTED = False + JSON_PATH_BRACKETED_KEY_SUPPORTED = False + SUPPORTS_CREATE_TABLE_LIKE = False + + SUPPORTED_JSON_PATH_PARTS = { + exp.JSONPathKey, + exp.JSONPathRoot, + exp.JSONPathSubscript, + } TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -120,10 +133,8 @@ class SQLite(Dialect): exp.DateAdd: _date_add_sql, exp.DateStrToDate: lambda self, e: self.sql(e, "this"), exp.ILike: no_ilike_sql, - exp.JSONExtract: arrow_json_extract_sql, - exp.JSONExtractScalar: arrow_json_extract_scalar_sql, - exp.JSONBExtract: arrow_json_extract_sql, - exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONExtract: _json_extract_sql, + exp.JSONExtractScalar: arrow_json_extract_sql, exp.Levenshtein: rename_func("EDITDIST3"), exp.LogicalOr: rename_func("MAX"), exp.LogicalAnd: rename_func("MIN"), @@ -141,11 +152,18 @@ class SQLite(Dialect): exp.TryCast: no_trycast_sql, } + # SQLite doesn't generally support CREATE TABLE .. properties + # https://www.sqlite.org/lang_createtable.html PROPERTIES_LOCATION = { - k: exp.Properties.Location.UNSUPPORTED - for k, v in generator.Generator.PROPERTIES_LOCATION.items() + prop: exp.Properties.Location.UNSUPPORTED + for prop in generator.Generator.PROPERTIES_LOCATION } + # There are a few exceptions (e.g. temporary tables) which are supported or + # can be transpiled to SQLite, so we explicitly override them accordingly + PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA + PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE + LIMIT_FETCH = "LIMIT" def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 2dba1c1..8838f34 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -44,12 +44,14 @@ class StarRocks(MySQL): exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, exp.RegexpLike: rename_func("REGEXP"), - exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToUnix: lambda self, + e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimestampTrunc: lambda self, e: self.func( "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this ), exp.TimeStrToDate: rename_func("TO_DATE"), - exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToStr: lambda self, + e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), } diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 6dbad15..7f9a11a 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -200,7 +200,8 @@ class Teradata(Dialect): exp.Select: transforms.preprocess( [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] ), - exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", + exp.StrToDate: lambda self, + e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}", } diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index eddb70a..1bbed67 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -11,9 +11,16 @@ class Trino(Presto): class Generator(Presto.Generator): TRANSFORMS = { **Presto.Generator.TRANSFORMS, - exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + exp.ArraySum: lambda self, + e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.Merge: merge_without_target_sql, } + SUPPORTED_JSON_PATH_PARTS = { + exp.JSONPathKey, + exp.JSONPathRoot, + exp.JSONPathSubscript, + } + class Tokenizer(Presto.Tokenizer): HEX_STRINGS = [("X'", "'")] diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a5e04da..70ea97e 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -14,7 +14,6 @@ from sqlglot.dialects.dialect import ( max_or_greatest, min_or_least, parse_date_delta, - path_to_jsonpath, rename_func, timestrtotime_sql, trim_sql, @@ -266,13 +265,32 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts: ) -def _parse_len(args: t.List) -> exp.Length: - this = seq_get(args, 0) +def _parse_as_text( + klass: t.Type[exp.Expression], +) -> t.Callable[[t.List[exp.Expression]], exp.Expression]: + def _parse(args: t.List[exp.Expression]) -> exp.Expression: + this = seq_get(args, 0) + + if this and not this.is_string: + this = exp.cast(this, exp.DataType.Type.TEXT) + + expression = seq_get(args, 1) + kwargs = {"this": this} + + if expression: + kwargs["expression"] = expression - if this and not this.is_string: - this = exp.cast(this, exp.DataType.Type.TEXT) + return klass(**kwargs) - return exp.Length(this=this) + return _parse + + +def _json_extract_sql( + self: TSQL.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar +) -> str: + json_query = rename_func("JSON_QUERY")(self, expression) + json_value = rename_func("JSON_VALUE")(self, expression) + return self.func("ISNULL", json_query, json_value) class TSQL(Dialect): @@ -441,8 +459,11 @@ class TSQL(Dialect): "HASHBYTES": _parse_hashbytes, "IIF": exp.If.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, - "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, - "LEN": _parse_len, + "JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract), + "JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar), + "LEN": _parse_as_text(exp.Length), + "LEFT": _parse_as_text(exp.Left), + "RIGHT": _parse_as_text(exp.Right), "REPLICATE": exp.Repeat.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, @@ -677,6 +698,7 @@ class TSQL(Dialect): SUPPORTS_SINGLE_ARG_CONCAT = False TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" SUPPORTS_SELECT_INTO = True + JSON_PATH_BRACKETED_KEY_SUPPORTED = False EXPRESSIONS_WITHOUT_NESTED_CTES = { exp.Delete, @@ -688,6 +710,12 @@ class TSQL(Dialect): exp.Update, } + SUPPORTED_JSON_PATH_PARTS = { + exp.JSONPathKey, + exp.JSONPathRoot, + exp.JSONPathSubscript, + } + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BOOLEAN: "BIT", @@ -712,9 +740,10 @@ class TSQL(Dialect): exp.CurrentTimestamp: rename_func("GETDATE"), exp.Extract: rename_func("DATEPART"), exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, - exp.GetPath: path_to_jsonpath("JSON_VALUE"), exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), + exp.JSONExtract: _json_extract_sql, + exp.JSONExtractScalar: _json_extract_sql, exp.LastDay: lambda self, e: self.func("EOMONTH", e.this), exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), @@ -831,15 +860,21 @@ class TSQL(Dialect): exists = expression.args.pop("exists", None) sql = super().create_sql(expression) + like_property = expression.find(exp.LikeProperty) + if like_property: + ctas_expression = like_property.this + else: + ctas_expression = expression.expression + table = expression.find(exp.Table) # Convert CTAS statement to SELECT .. INTO .. - if kind == "TABLE" and expression.expression: - ctas_with = expression.expression.args.get("with") + if kind == "TABLE" and ctas_expression: + ctas_with = ctas_expression.args.get("with") if ctas_with: ctas_with = ctas_with.pop() - subquery = expression.expression + subquery = ctas_expression if isinstance(subquery, exp.Subqueryable): subquery = subquery.subquery() @@ -847,6 +882,9 @@ class TSQL(Dialect): select_into.set("into", exp.Into(this=table)) select_into.set("with", ctas_with) + if like_property: + select_into.limit(0, copy=False) + sql = self.sql(select_into) if exists: @@ -937,9 +975,19 @@ class TSQL(Dialect): return f"CONSTRAINT {this} {expressions}" def length_sql(self, expression: exp.Length) -> str: + return self._uncast_text(expression, "LEN") + + def right_sql(self, expression: exp.Right) -> str: + return self._uncast_text(expression, "RIGHT") + + def left_sql(self, expression: exp.Left) -> str: + return self._uncast_text(expression, "LEFT") + + def _uncast_text(self, expression: exp.Expression, name: str) -> str: this = expression.this if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.TEXT): this_sql = self.sql(this, "this") else: this_sql = self.sql(this) - return self.func("LEN", this_sql) + expression_sql = self.sql(expression, "expression") + return self.func(name, this_sql, expression_sql if expression_sql else None) diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index 304981b..c8f9148 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -10,7 +10,6 @@ import logging import time import typing as t -from sqlglot import maybe_parse from sqlglot.errors import ExecuteError from sqlglot.executor.python import PythonExecutor from sqlglot.executor.table import Table, ensure_tables @@ -23,7 +22,6 @@ logger = logging.getLogger("sqlglot") if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType - from sqlglot.executor.table import Tables from sqlglot.expressions import Expression from sqlglot.schema import Schema diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index d7952c1..e4c4040 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -44,9 +44,9 @@ class Context: for other in self.tables.values(): if self._table.columns != other.columns: - raise Exception(f"Columns are different.") + raise Exception("Columns are different.") if len(self._table.rows) != len(other.rows): - raise Exception(f"Rows are different.") + raise Exception("Rows are different.") return self._table diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 6c01edc..218a8e0 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -6,7 +6,7 @@ from functools import wraps from sqlglot import exp from sqlglot.generator import Generator -from sqlglot.helper import PYTHON_VERSION +from sqlglot.helper import PYTHON_VERSION, is_int, seq_get class reverse_key: @@ -143,6 +143,22 @@ def arrayjoin(this, expression, null=None): return expression.join(x for x in (x if x is not None else null for x in this) if x is not None) +@null_if_any("this", "expression") +def jsonextract(this, expression): + for path_segment in expression: + if isinstance(this, dict): + this = this.get(path_segment) + elif isinstance(this, list) and is_int(path_segment): + this = seq_get(this, int(path_segment)) + else: + raise NotImplementedError(f"Unable to extract value for {this} at {path_segment}.") + + if this is None: + break + + return this + + ENV = { "exp": exp, # aggs @@ -175,12 +191,12 @@ ENV = { "DOT": null_if_any(lambda e, this: e[this]), "EQ": null_if_any(lambda this, e: this == e), "EXTRACT": null_if_any(lambda this, e: getattr(e, this)), - "GETPATH": null_if_any(lambda this, e: this.get(e)), "GT": null_if_any(lambda this, e: this > e), "GTE": null_if_any(lambda this, e: this >= e), "IF": lambda predicate, true, false: true if predicate else false, "INTDIV": null_if_any(lambda e, this: e // this), "INTERVAL": interval, + "JSONEXTRACT": jsonextract, "LEFT": null_if_any(lambda this, e: this[:e]), "LIKE": null_if_any( lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this)) diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 7ff9608..c0becbe 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -9,7 +9,7 @@ from sqlglot.errors import ExecuteError from sqlglot.executor.context import Context from sqlglot.executor.env import ENV from sqlglot.executor.table import RowReader, Table -from sqlglot.helper import csv_reader, subclasses +from sqlglot.helper import csv_reader, ensure_list, subclasses class PythonExecutor: @@ -368,7 +368,7 @@ def _rename(self, e): if isinstance(e, exp.Func) and e.is_var_len_args: *head, tail = values - return self.func(e.key, *head, *tail) + return self.func(e.key, *head, *ensure_list(tail)) return self.func(e.key, *values) except Exception as ex: @@ -429,18 +429,24 @@ class Python(Dialect): exp.Between: _rename, exp.Boolean: lambda self, e: "True" if e.this else "False", exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})", - exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", + exp.Column: lambda self, + e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", exp.Concat: lambda self, e: self.func( "SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions ), exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})", exp.Div: _div_sql, - exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", - exp.In: lambda self, e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}", + exp.Extract: lambda self, + e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", + exp.In: lambda self, + e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}", exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')", exp.Is: lambda self, e: ( self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is") ), + exp.JSONPath: lambda self, e: f"[{','.join(self.sql(p) for p in e.expressions[1:])}]", + exp.JSONPathKey: lambda self, e: f"'{self.sql(e.this)}'", + exp.JSONPathSubscript: lambda self, e: f"'{e.this}'", exp.Lambda: _lambda_sql, exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Null: lambda *_: "None", diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index a95a73e..3234c99 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -29,6 +29,7 @@ from sqlglot.helper import ( camel_to_snake_case, ensure_collection, ensure_list, + is_int, seq_get, subclasses, ) @@ -175,13 +176,7 @@ class Expression(metaclass=_Expression): """ Checks whether a Literal expression is an integer. """ - if self.is_number: - try: - int(self.name) - return True - except ValueError: - pass - return False + return self.is_number and is_int(self.name) @property def is_star(self) -> bool: @@ -493,8 +488,8 @@ class Expression(metaclass=_Expression): A AND B AND C -> [A, B, C] """ - for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__): - if not type(node) is self.__class__: + for node, _, _ in self.dfs(prune=lambda n, p, *_: p and type(n) is not self.__class__): + if type(node) is not self.__class__: yield node.unnest() if unnest and not isinstance(node, Subquery) else node def __str__(self) -> str: @@ -553,10 +548,12 @@ class Expression(metaclass=_Expression): return new_node @t.overload - def replace(self, expression: E) -> E: ... + def replace(self, expression: E) -> E: + ... @t.overload - def replace(self, expression: None) -> None: ... + def replace(self, expression: None) -> None: + ... def replace(self, expression): """ @@ -610,7 +607,8 @@ class Expression(metaclass=_Expression): >>> sqlglot.parse_one("SELECT x from y").assert_is(Select).select("z").sql() 'SELECT x, z FROM y' """ - assert isinstance(self, type_) + if not isinstance(self, type_): + raise AssertionError(f"{self} is not {type_}.") return self def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]: @@ -1133,6 +1131,7 @@ class SetItem(Expression): class Show(Expression): arg_types = { "this": True, + "history": False, "terse": False, "target": False, "offset": False, @@ -1676,7 +1675,6 @@ class Index(Expression): "amp": False, # teradata "include": False, "partition_by": False, # teradata - "where": False, # postgres partial indexes } @@ -2573,7 +2571,7 @@ class HistoricalData(Expression): class Table(Expression): arg_types = { - "this": True, + "this": False, "alias": False, "db": False, "catalog": False, @@ -3664,6 +3662,7 @@ class DataType(Expression): BINARY = auto() BIT = auto() BOOLEAN = auto() + BPCHAR = auto() CHAR = auto() DATE = auto() DATE32 = auto() @@ -3805,6 +3804,7 @@ class DataType(Expression): dtype: DATA_TYPE, dialect: DialectType = None, udt: bool = False, + copy: bool = True, **kwargs, ) -> DataType: """ @@ -3815,7 +3815,8 @@ class DataType(Expression): dialect: the dialect to use for parsing `dtype`, in case it's a string. udt: when set to True, `dtype` will be used as-is if it can't be parsed into a DataType, thus creating a user-defined type. - kawrgs: additional arguments to pass in the constructor of DataType. + copy: whether or not to copy the data type. + kwargs: additional arguments to pass in the constructor of DataType. Returns: The constructed DataType object. @@ -3837,7 +3838,7 @@ class DataType(Expression): elif isinstance(dtype, DataType.Type): data_type_exp = DataType(this=dtype) elif isinstance(dtype, DataType): - return dtype + return maybe_copy(dtype, copy) else: raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type") @@ -3855,7 +3856,7 @@ class DataType(Expression): True, if and only if there is a type in `dtypes` which is equal to this DataType. """ for dtype in dtypes: - other = DataType.build(dtype, udt=True) + other = DataType.build(dtype, copy=False, udt=True) if ( other.expressions @@ -4001,7 +4002,7 @@ class Dot(Binary): def build(self, expressions: t.Sequence[Expression]) -> Dot: """Build a Dot object with a sequence of expressions.""" if len(expressions) < 2: - raise ValueError(f"Dot requires >= 2 expressions.") + raise ValueError("Dot requires >= 2 expressions.") return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions)) @@ -4128,10 +4129,6 @@ class Sub(Binary): pass -class ArrayOverlaps(Binary): - pass - - # Unary Expressions # (NOT a) class Unary(Condition): @@ -4469,6 +4466,10 @@ class ArrayJoin(Func): arg_types = {"this": True, "expression": True, "null": False} +class ArrayOverlaps(Binary, Func): + pass + + class ArraySize(Func): arg_types = {"this": True, "expression": False} @@ -4490,15 +4491,37 @@ class Avg(AggFunc): class AnyValue(AggFunc): - arg_types = {"this": True, "having": False, "max": False, "ignore_nulls": False} + arg_types = {"this": True, "having": False, "max": False} + + +class Lag(AggFunc): + arg_types = {"this": True, "offset": False, "default": False} + + +class Lead(AggFunc): + arg_types = {"this": True, "offset": False, "default": False} + + +# some dialects have a distinction between first and first_value, usually first is an aggregate func +# and first_value is a window func +class First(AggFunc): + pass + + +class Last(AggFunc): + pass + + +class FirstValue(AggFunc): + pass -class First(Func): - arg_types = {"this": True, "ignore_nulls": False} +class LastValue(AggFunc): + pass -class Last(Func): - arg_types = {"this": True, "ignore_nulls": False} +class NthValue(AggFunc): + arg_types = {"this": True, "offset": True} class Case(Func): @@ -4611,7 +4634,7 @@ class CurrentTime(Func): class CurrentTimestamp(Func): - arg_types = {"this": False} + arg_types = {"this": False, "transaction": False} class CurrentUser(Func): @@ -4712,6 +4735,7 @@ class TimestampSub(Func, TimeUnit): class TimestampDiff(Func, TimeUnit): + _sql_names = ["TIMESTAMPDIFF", "TIMESTAMP_DIFF"] arg_types = {"this": True, "expression": True, "unit": False} @@ -4857,6 +4881,59 @@ class IsInf(Func): _sql_names = ["IS_INF", "ISINF"] +class JSONPath(Expression): + arg_types = {"expressions": True} + + @property + def output_name(self) -> str: + last_segment = self.expressions[-1].this + return last_segment if isinstance(last_segment, str) else "" + + +class JSONPathPart(Expression): + arg_types = {} + + +class JSONPathFilter(JSONPathPart): + arg_types = {"this": True} + + +class JSONPathKey(JSONPathPart): + arg_types = {"this": True} + + +class JSONPathRecursive(JSONPathPart): + arg_types = {"this": False} + + +class JSONPathRoot(JSONPathPart): + pass + + +class JSONPathScript(JSONPathPart): + arg_types = {"this": True} + + +class JSONPathSlice(JSONPathPart): + arg_types = {"start": False, "end": False, "step": False} + + +class JSONPathSelector(JSONPathPart): + arg_types = {"this": True} + + +class JSONPathSubscript(JSONPathPart): + arg_types = {"this": True} + + +class JSONPathUnion(JSONPathPart): + arg_types = {"expressions": True} + + +class JSONPathWildcard(JSONPathPart): + pass + + class FormatJson(Expression): pass @@ -4940,18 +5017,30 @@ class JSONBContains(Binary): class JSONExtract(Binary, Func): + arg_types = {"this": True, "expression": True, "expressions": False} _sql_names = ["JSON_EXTRACT"] + is_var_len_args = True + + @property + def output_name(self) -> str: + return self.expression.output_name if not self.expressions else "" -class JSONExtractScalar(JSONExtract): +class JSONExtractScalar(Binary, Func): + arg_types = {"this": True, "expression": True, "expressions": False} _sql_names = ["JSON_EXTRACT_SCALAR"] + is_var_len_args = True + + @property + def output_name(self) -> str: + return self.expression.output_name -class JSONBExtract(JSONExtract): +class JSONBExtract(Binary, Func): _sql_names = ["JSONB_EXTRACT"] -class JSONBExtractScalar(JSONExtract): +class JSONBExtractScalar(Binary, Func): _sql_names = ["JSONB_EXTRACT_SCALAR"] @@ -4972,15 +5061,6 @@ class ParseJSON(Func): is_var_len_args = True -# https://docs.snowflake.com/en/sql-reference/functions/get_path -class GetPath(Func): - arg_types = {"this": True, "expression": True} - - @property - def output_name(self) -> str: - return self.expression.output_name - - class Least(Func): arg_types = {"this": True, "expressions": False} is_var_len_args = True @@ -5476,6 +5556,8 @@ def _norm_arg(arg): ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()} +JSON_PATH_PARTS = subclasses(__name__, JSONPathPart, (JSONPathPart,)) + # Helpers @t.overload @@ -5487,7 +5569,8 @@ def maybe_parse( prefix: t.Optional[str] = None, copy: bool = False, **opts, -) -> E: ... +) -> E: + ... @t.overload @@ -5499,7 +5582,8 @@ def maybe_parse( prefix: t.Optional[str] = None, copy: bool = False, **opts, -) -> E: ... +) -> E: + ... def maybe_parse( @@ -5539,7 +5623,7 @@ def maybe_parse( return sql_or_expression if sql_or_expression is None: - raise ParseError(f"SQL cannot be None") + raise ParseError("SQL cannot be None") import sqlglot @@ -5551,11 +5635,13 @@ def maybe_parse( @t.overload -def maybe_copy(instance: None, copy: bool = True) -> None: ... +def maybe_copy(instance: None, copy: bool = True) -> None: + ... @t.overload -def maybe_copy(instance: E, copy: bool = True) -> E: ... +def maybe_copy(instance: E, copy: bool = True) -> E: + ... def maybe_copy(instance, copy=True): @@ -6174,17 +6260,19 @@ def paren(expression: ExpOrStr, copy: bool = True) -> Paren: return Paren(this=maybe_parse(expression, copy=copy)) -SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$") +SAFE_IDENTIFIER_RE: t.Pattern[str] = re.compile(r"^[_a-zA-Z][\w]*$") @t.overload -def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ... +def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: + ... @t.overload def to_identifier( name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True -) -> Identifier: ... +) -> Identifier: + ... def to_identifier(name, quoted=None, copy=True): @@ -6256,11 +6344,13 @@ def to_interval(interval: str | Literal) -> Interval: @t.overload -def to_table(sql_path: str | Table, **kwargs) -> Table: ... +def to_table(sql_path: str | Table, **kwargs) -> Table: + ... @t.overload -def to_table(sql_path: None, **kwargs) -> None: ... +def to_table(sql_path: None, **kwargs) -> None: + ... def to_table( @@ -6460,7 +6550,7 @@ def column( return this -def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast: +def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast: """Cast an expression to a data type. Example: @@ -6470,12 +6560,13 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast: Args: expression: The expression to cast. to: The datatype to cast to. + copy: Whether or not to copy the supplied expressions. Returns: The new Cast instance. """ - expression = maybe_parse(expression, **opts) - data_type = DataType.build(to, **opts) + expression = maybe_parse(expression, copy=copy, **opts) + data_type = DataType.build(to, copy=copy, **opts) expression = Cast(this=expression, to=data_type) expression.type = data_type return expression diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 8e3ff9b..568dcb4 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -9,6 +9,7 @@ from functools import reduce from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages from sqlglot.helper import apply_index_offset, csv, seq_get +from sqlglot.jsonpath import ALL_JSON_PATH_PARTS, JSON_PATH_PART_TRANSFORMS from sqlglot.time import format_time from sqlglot.tokens import TokenType @@ -21,7 +22,18 @@ logger = logging.getLogger("sqlglot") ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)") -class Generator: +class _Generator(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + + # Remove transforms that correspond to unsupported JSONPathPart expressions + for part in ALL_JSON_PATH_PARTS - klass.SUPPORTED_JSON_PATH_PARTS: + klass.TRANSFORMS.pop(part, None) + + return klass + + +class Generator(metaclass=_Generator): """ Generator converts a given syntax tree to the corresponding SQL string. @@ -58,19 +70,23 @@ class Generator: Default: True """ - TRANSFORMS = { - exp.DateAdd: lambda self, e: self.func( - "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit")) - ), - exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", + TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = { + **JSON_PATH_PART_TRANSFORMS, + exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}", + exp.CaseSpecificColumnConstraint: lambda self, + e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", - exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}", + exp.CharacterSetProperty: lambda self, + e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}", exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})", - exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})", + exp.ClusteredColumnConstraint: lambda self, + e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})", exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}", - exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}", - exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS", exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}", + exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS", + exp.DateAdd: lambda self, e: self.func( + "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit")) + ), exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}", exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}", exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}", @@ -85,29 +101,33 @@ class Generator: exp.LocationProperty: lambda self, e: self.naked_property(e), exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", exp.MaterializedProperty: lambda self, e: "MATERIALIZED", + exp.NonClusteredColumnConstraint: lambda self, + e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})", exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX", - exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})", exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION", - exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS", + exp.OnCommitProperty: lambda self, + e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS", exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}", exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}", exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}", exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}", - exp.RemoteWithConnectionModelProperty: lambda self, e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}", + exp.RemoteWithConnectionModelProperty: lambda self, + e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}", exp.ReturnsProperty: lambda self, e: self.naked_property(e), exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}", - exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET", exp.SetConfigProperty: lambda self, e: self.sql(e, "this"), + exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET", exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", exp.SqlReadWriteProperty: lambda self, e: e.name, - exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", + exp.SqlSecurityProperty: lambda self, + e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", exp.StabilityProperty: lambda self, e: e.name, - exp.TemporaryProperty: lambda self, e: f"TEMPORARY", + exp.TemporaryProperty: lambda self, e: "TEMPORARY", + exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}", - exp.TransientProperty: lambda self, e: "TRANSIENT", exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions), - exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", - exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE", + exp.TransientProperty: lambda self, e: "TRANSIENT", + exp.UppercaseColumnConstraint: lambda self, e: "UPPERCASE", exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]), exp.VolatileProperty: lambda self, e: "VOLATILE", exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", @@ -117,6 +137,10 @@ class Generator: # True: Full Support, None: No support, False: No support in window specifications NULL_ORDERING_SUPPORTED: t.Optional[bool] = True + # Whether or not ignore nulls is inside the agg or outside. + # FIRST(x IGNORE NULLS) OVER vs FIRST (x) IGNORE NULLS OVER + IGNORE_NULLS_IN_FUNC = False + # Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported LOCKING_READS_SUPPORTED = False @@ -266,6 +290,24 @@ class Generator: # Whether or not UNLOGGED tables can be created SUPPORTS_UNLOGGED_TABLES = False + # Whether or not the CREATE TABLE LIKE statement is supported + SUPPORTS_CREATE_TABLE_LIKE = True + + # Whether or not the LikeProperty needs to be specified inside of the schema clause + LIKE_PROPERTY_INSIDE_SCHEMA = False + + # Whether or not the JSON extraction operators expect a value of type JSON + JSON_TYPE_REQUIRED_FOR_EXTRACTION = False + + # Whether or not bracketed keys like ["foo"] are supported in JSON paths + JSON_PATH_BRACKETED_KEY_SUPPORTED = True + + # Whether or not to escape keys using single quotes in JSON paths + JSON_PATH_SINGLE_QUOTE_ESCAPE = False + + # The JSONPathPart expressions supported by this dialect + SUPPORTED_JSON_PATH_PARTS = ALL_JSON_PATH_PARTS.copy() + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -641,8 +683,6 @@ class Generator: if callable(transform): sql = transform(self, expression) - elif transform: - sql = transform elif isinstance(expression, exp.Expression): exp_handler_name = f"{expression.key}_sql" @@ -802,7 +842,7 @@ class Generator: desc = expression.args.get("desc") if desc is not None: return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" - return f"PRIMARY KEY" + return "PRIMARY KEY" def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str: this = self.sql(expression, "this") @@ -1218,9 +1258,21 @@ class Generator: return f"{property_name}={self.sql(expression, 'this')}" def likeproperty_sql(self, expression: exp.LikeProperty) -> str: - options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions) - options = f" {options}" if options else "" - return f"LIKE {self.sql(expression, 'this')}{options}" + if self.SUPPORTS_CREATE_TABLE_LIKE: + options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions) + options = f" {options}" if options else "" + + like = f"LIKE {self.sql(expression, 'this')}{options}" + if self.LIKE_PROPERTY_INSIDE_SCHEMA and not isinstance(expression.parent, exp.Schema): + like = f"({like})" + + return like + + if expression.expressions: + self.unsupported("Transpilation of LIKE property options is unsupported") + + select = exp.select("*").from_(expression.this).limit(0) + return f"AS {self.sql(select)}" def fallbackproperty_sql(self, expression: exp.FallbackProperty) -> str: no = "NO " if expression.args.get("no") else "" @@ -2367,6 +2419,31 @@ class Generator: def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str: return f"{self.sql(expression, 'this')}{self.JSON_KEY_VALUE_PAIR_SEP} {self.sql(expression, 'expression')}" + def jsonpath_sql(self, expression: exp.JSONPath) -> str: + path = self.expressions(expression, sep="", flat=True).lstrip(".") + return f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" + + def json_path_part(self, expression: int | str | exp.JSONPathPart) -> str: + if isinstance(expression, exp.JSONPathPart): + transform = self.TRANSFORMS.get(expression.__class__) + if not callable(transform): + self.unsupported(f"Unsupported JSONPathPart type {expression.__class__.__name__}") + return "" + + return transform(self, expression) + + if isinstance(expression, int): + return str(expression) + + if self.JSON_PATH_SINGLE_QUOTE_ESCAPE: + escaped = expression.replace("'", "\\'") + escaped = f"\\'{expression}\\'" + else: + escaped = expression.replace('"', '\\"') + escaped = f'"{escaped}"' + + return escaped + def formatjson_sql(self, expression: exp.FormatJson) -> str: return f"{self.sql(expression, 'this')} FORMAT JSON" @@ -2620,6 +2697,9 @@ class Generator: zone = self.sql(expression, "this") return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE" + def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str: + return self.func("CURRENT_TIMESTAMP", expression.this) + def collate_sql(self, expression: exp.Collate) -> str: if self.COLLATE_IS_FUNC: return self.function_fallback_sql(expression) @@ -2761,10 +2841,20 @@ class Generator: return f"DISTINCT{this}{on}" def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str: - return f"{self.sql(expression, 'this')} IGNORE NULLS" + return self._embed_ignore_nulls(expression, "IGNORE NULLS") def respectnulls_sql(self, expression: exp.RespectNulls) -> str: - return f"{self.sql(expression, 'this')} RESPECT NULLS" + return self._embed_ignore_nulls(expression, "RESPECT NULLS") + + def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str: + if self.IGNORE_NULLS_IN_FUNC: + this = expression.find(exp.AggFunc) + if this: + sql = self.sql(this) + sql = sql[:-1] + f" {text})" + return sql + + return f"{self.sql(expression, 'this')} {text}" def intdiv_sql(self, expression: exp.IntDiv) -> str: return self.sql( @@ -2935,7 +3025,7 @@ class Generator: def format_args(self, *args: t.Optional[str | exp.Expression]) -> str: arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None) if self.pretty and self.text_width(arg_sqls) > self.max_text_width: - return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True) + return self.indent("\n" + ",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True) return ", ".join(arg_sqls) def text_width(self, args: t.Iterable) -> int: @@ -3279,6 +3369,22 @@ class Generator: return self.func("LAST_DAY", expression.this) + def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str: + this = expression.this + if isinstance(this, exp.JSONPathWildcard): + this = self.json_path_part(this) + return f".{this}" if this else "" + + if exp.SAFE_IDENTIFIER_RE.match(this): + return f".{this}" + + this = self.json_path_part(this) + return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}" + + def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str: + this = self.json_path_part(expression.this) + return f"[{this}]" if this else "" + def _simplify_unless_literal(self, expression: E) -> E: if not isinstance(expression, exp.Literal): from sqlglot.optimizer.simplify import simplify diff --git a/sqlglot/helper.py b/sqlglot/helper.py index de737be..9799fe2 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -53,11 +53,13 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]: @t.overload -def ensure_list(value: t.Collection[T]) -> t.List[T]: ... +def ensure_list(value: t.Collection[T]) -> t.List[T]: + ... @t.overload -def ensure_list(value: T) -> t.List[T]: ... +def ensure_list(value: T) -> t.List[T]: + ... def ensure_list(value): @@ -79,11 +81,13 @@ def ensure_list(value): @t.overload -def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ... +def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: + ... @t.overload -def ensure_collection(value: T) -> t.Collection[T]: ... +def ensure_collection(value: T) -> t.Collection[T]: + ... def ensure_collection(value): @@ -232,7 +236,7 @@ def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]: for node, deps in tuple(dag.items()): for dep in deps: - if not dep in dag: + if dep not in dag: dag[dep] = set() while dag: @@ -316,6 +320,14 @@ def find_new_name(taken: t.Collection[str], base: str) -> str: return new +def is_int(text: str) -> bool: + try: + int(text) + return True + except ValueError: + return False + + def name_sequence(prefix: str) -> t.Callable[[], str]: """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").""" sequence = count() diff --git a/sqlglot/jsonpath.py b/sqlglot/jsonpath.py index c410d11..129a4e6 100644 --- a/sqlglot/jsonpath.py +++ b/sqlglot/jsonpath.py @@ -2,8 +2,8 @@ from __future__ import annotations import typing as t +import sqlglot.expressions as exp from sqlglot.errors import ParseError -from sqlglot.expressions import SAFE_IDENTIFIER_RE from sqlglot.tokens import Token, Tokenizer, TokenType if t.TYPE_CHECKING: @@ -36,20 +36,8 @@ class JSONPathTokenizer(Tokenizer): STRING_ESCAPES = ["\\"] -JSONPathNode = t.Dict[str, t.Any] - - -def _node(kind: str, value: t.Any = None, **kwargs: t.Any) -> JSONPathNode: - node = {"kind": kind, **kwargs} - - if value is not None: - node["value"] = value - - return node - - -def parse(path: str) -> t.List[JSONPathNode]: - """Takes in a JSONPath string and converts into a list of nodes.""" +def parse(path: str) -> exp.JSONPath: + """Takes in a JSON path string and parses it into a JSONPath expression.""" tokens = JSONPathTokenizer().tokenize(path) size = len(tokens) @@ -89,7 +77,7 @@ def parse(path: str) -> t.List[JSONPathNode]: if token: return token.text if _match(TokenType.STAR): - return _node("wildcard") + return exp.JSONPathWildcard() if _match(TokenType.PLACEHOLDER) or _match(TokenType.L_PAREN): script = _prev().text == "(" start = i @@ -100,9 +88,9 @@ def parse(path: str) -> t.List[JSONPathNode]: if _curr() in (TokenType.R_BRACKET, None): break _advance() - return _node( - "script" if script else "filter", path[tokens[start].start : tokens[i].end] - ) + + expr_type = exp.JSONPathScript if script else exp.JSONPathFilter + return expr_type(this=path[tokens[start].start : tokens[i].end]) number = "-" if _match(TokenType.DASH) else "" @@ -112,6 +100,7 @@ def parse(path: str) -> t.List[JSONPathNode]: if number: return int(number) + return False def _parse_slice() -> t.Any: @@ -121,9 +110,10 @@ def parse(path: str) -> t.List[JSONPathNode]: if end is None and step is None: return start - return _node("slice", start=start, end=end, step=step) - def _parse_bracket() -> JSONPathNode: + return exp.JSONPathSlice(start=start, end=end, step=step) + + def _parse_bracket() -> exp.JSONPathPart: literal = _parse_slice() if isinstance(literal, str) or literal is not False: @@ -136,13 +126,15 @@ def parse(path: str) -> t.List[JSONPathNode]: if len(indexes) == 1: if isinstance(literal, str): - node = _node("key", indexes[0]) - elif isinstance(literal, dict) and literal["kind"] in ("script", "filter"): - node = _node("selector", indexes[0]) + node: exp.JSONPathPart = exp.JSONPathKey(this=indexes[0]) + elif isinstance(literal, exp.JSONPathPart) and isinstance( + literal, (exp.JSONPathScript, exp.JSONPathFilter) + ): + node = exp.JSONPathSelector(this=indexes[0]) else: - node = _node("subscript", indexes[0]) + node = exp.JSONPathSubscript(this=indexes[0]) else: - node = _node("union", indexes) + node = exp.JSONPathUnion(expressions=indexes) else: raise ParseError(_error("Cannot have empty segment")) @@ -150,66 +142,56 @@ def parse(path: str) -> t.List[JSONPathNode]: return node - nodes = [] + # We canonicalize the JSON path AST so that it always starts with a + # "root" element, so paths like "field" will be generated as "$.field" + _match(TokenType.DOLLAR) + expressions: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] while _curr(): - if _match(TokenType.DOLLAR): - nodes.append(_node("root")) - elif _match(TokenType.DOT): + if _match(TokenType.DOT) or _match(TokenType.COLON): recursive = _prev().text == ".." - value = _match(TokenType.VAR) or _match(TokenType.STAR) - nodes.append( - _node("recursive" if recursive else "child", value=value.text if value else None) - ) + + if _match(TokenType.VAR) or _match(TokenType.IDENTIFIER): + value: t.Optional[str | exp.JSONPathWildcard] = _prev().text + elif _match(TokenType.STAR): + value = exp.JSONPathWildcard() + else: + value = None + + if recursive: + expressions.append(exp.JSONPathRecursive(this=value)) + elif value: + expressions.append(exp.JSONPathKey(this=value)) + else: + raise ParseError(_error("Expected key name or * after DOT")) elif _match(TokenType.L_BRACKET): - nodes.append(_parse_bracket()) - elif _match(TokenType.VAR): - nodes.append(_node("key", _prev().text)) + expressions.append(_parse_bracket()) + elif _match(TokenType.VAR) or _match(TokenType.IDENTIFIER): + expressions.append(exp.JSONPathKey(this=_prev().text)) elif _match(TokenType.STAR): - nodes.append(_node("wildcard")) - elif _match(TokenType.PARAMETER): - nodes.append(_node("current")) + expressions.append(exp.JSONPathWildcard()) else: raise ParseError(_error(f"Unexpected {tokens[i].token_type}")) - return nodes + return exp.JSONPath(expressions=expressions) -MAPPING = { - "child": lambda n: f".{n['value']}" if n.get("value") is not None else "", - "filter": lambda n: f"?{n['value']}", - "key": lambda n: ( - f".{n['value']}" if SAFE_IDENTIFIER_RE.match(n["value"]) else f'[{generate([n["value"]])}]' - ), - "recursive": lambda n: f"..{n['value']}" if n.get("value") is not None else "..", - "root": lambda _: "$", - "script": lambda n: f"({n['value']}", - "slice": lambda n: ":".join( - "" if p is False else generate([p]) - for p in [n["start"], n["end"], n["step"]] +JSON_PATH_PART_TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = { + exp.JSONPathFilter: lambda _, e: f"?{e.this}", + exp.JSONPathKey: lambda self, e: self._jsonpathkey_sql(e), + exp.JSONPathRecursive: lambda _, e: f"..{e.this or ''}", + exp.JSONPathRoot: lambda *_: "$", + exp.JSONPathScript: lambda _, e: f"({e.this}", + exp.JSONPathSelector: lambda self, e: f"[{self.json_path_part(e.this)}]", + exp.JSONPathSlice: lambda self, e: ":".join( + "" if p is False else self.json_path_part(p) + for p in [e.args.get("start"), e.args.get("end"), e.args.get("step")] if p is not None ), - "selector": lambda n: f"[{generate([n['value']])}]", - "subscript": lambda n: f"[{generate([n['value']])}]", - "union": lambda n: f"[{','.join(generate([p]) for p in n['value'])}]", - "wildcard": lambda _: "*", + exp.JSONPathSubscript: lambda self, e: self._jsonpathsubscript_sql(e), + exp.JSONPathUnion: lambda self, + e: f"[{','.join(self.json_path_part(p) for p in e.expressions)}]", + exp.JSONPathWildcard: lambda *_: "*", } - -def generate( - nodes: t.List[JSONPathNode], - mapping: t.Optional[t.Dict[str, t.Callable[[JSONPathNode], str]]] = None, -) -> str: - mapping = MAPPING if mapping is None else mapping - path = [] - - for node in nodes: - if isinstance(node, dict): - path.append(mapping[node["kind"]](node)) - elif isinstance(node, str): - escaped = node.replace('"', '\\"') - path.append(f'"{escaped}"') - else: - path.append(str(node)) - - return "".join(path) +ALL_JSON_PATH_PARTS = set(JSON_PATH_PART_TRANSFORMS) diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py index ee48006..34ea6cb 100644 --- a/sqlglot/optimizer/__init__.py +++ b/sqlglot/optimizer/__init__.py @@ -1,3 +1,5 @@ +# ruff: noqa: F401 + from sqlglot.optimizer.optimizer import RULES, optimize from sqlglot.optimizer.scope import ( Scope, diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index f2a0990..d22a998 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -10,11 +10,13 @@ if t.TYPE_CHECKING: @t.overload -def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ... +def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: + ... @t.overload -def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ... +def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: + ... def normalize_identifiers(expression, dialect=None): diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index e3aaebc..53490bf 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -8,10 +8,10 @@ from sqlglot.schema import ensure_schema # Sentinel value that means an outer query selecting ALL columns SELECT_ALL = object() + # Selection to use if selection list is empty -DEFAULT_SELECTION = lambda is_agg: alias( - exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_" -) +def default_selection(is_agg: bool) -> exp.Alias: + return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_") def pushdown_projections(expression, schema=None, remove_unused_selections=True): @@ -129,7 +129,7 @@ def _remove_unused_selections(scope, parent_selections, schema, alias_count): # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(DEFAULT_SELECTION(is_agg)) + new_selections.append(default_selection(is_agg)) scope.expression.select(*new_selections, append=False, copy=False) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index d5b9119..90357dd 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -104,7 +104,6 @@ def simplify( if root: expression.replace(node) - return node expression = while_changing(expression, _simplify) @@ -174,16 +173,20 @@ def simplify_not(expression): if isinstance(this, exp.Paren): condition = this.unnest() if isinstance(condition, exp.And): - return exp.or_( - exp.not_(condition.left, copy=False), - exp.not_(condition.right, copy=False), - copy=False, + return exp.paren( + exp.or_( + exp.not_(condition.left, copy=False), + exp.not_(condition.right, copy=False), + copy=False, + ) ) if isinstance(condition, exp.Or): - return exp.and_( - exp.not_(condition.left, copy=False), - exp.not_(condition.right, copy=False), - copy=False, + return exp.paren( + exp.and_( + exp.not_(condition.left, copy=False), + exp.not_(condition.right, copy=False), + copy=False, + ) ) if is_null(condition): return exp.null() @@ -490,7 +493,7 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression: if isinstance(expression, COMPARISONS): l, r = expression.left, expression.right - if not l.__class__ in INVERSE_OPS: + if l.__class__ not in INVERSE_OPS: return expression if r.is_number: @@ -714,8 +717,7 @@ def simplify_concat(expression): """Reduces all groups that contain string literals by concatenating them.""" if not isinstance(expression, CONCATS) or ( # We can't reduce a CONCAT_WS call if we don't statically know the separator - isinstance(expression, exp.ConcatWs) - and not expression.expressions[0].is_string + isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string ): return expression diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c091605..a89e4fa 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -60,6 +60,19 @@ def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func: return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this) +def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: + def _parser(args: t.List, dialect: Dialect) -> E: + expression = expr_type( + this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) + ) + if len(args) > 2 and expr_type is exp.JSONExtract: + expression.set("expressions", args[2:]) + + return expression + + return _parser + + class _Parser(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) @@ -102,6 +115,9 @@ class Parser(metaclass=_Parser): to=exp.DataType(this=exp.DataType.Type.TEXT), ), "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)), + "JSON_EXTRACT": parse_extract_json_with_path(exp.JSONExtract), + "JSON_EXTRACT_SCALAR": parse_extract_json_with_path(exp.JSONExtractScalar), + "JSON_EXTRACT_PATH_TEXT": parse_extract_json_with_path(exp.JSONExtractScalar), "LIKE": parse_like, "LOG": parse_logarithm, "TIME_TO_TIME_STR": lambda args: exp.Cast( @@ -175,6 +191,7 @@ class Parser(metaclass=_Parser): TokenType.NCHAR, TokenType.VARCHAR, TokenType.NVARCHAR, + TokenType.BPCHAR, TokenType.TEXT, TokenType.MEDIUMTEXT, TokenType.LONGTEXT, @@ -295,6 +312,7 @@ class Parser(metaclass=_Parser): TokenType.ASC, TokenType.AUTO_INCREMENT, TokenType.BEGIN, + TokenType.BPCHAR, TokenType.CACHE, TokenType.CASE, TokenType.COLLATE, @@ -531,12 +549,12 @@ class Parser(metaclass=_Parser): TokenType.ARROW: lambda self, this, path: self.expression( exp.JSONExtract, this=this, - expression=path, + expression=self.dialect.to_json_path(path), ), TokenType.DARROW: lambda self, this, path: self.expression( exp.JSONExtractScalar, this=this, - expression=path, + expression=self.dialect.to_json_path(path), ), TokenType.HASH_ARROW: lambda self, this, path: self.expression( exp.JSONBExtract, @@ -1334,7 +1352,9 @@ class Parser(metaclass=_Parser): exp.Drop, comments=start.comments, exists=exists or self._parse_exists(), - this=self._parse_table(schema=True), + this=self._parse_table( + schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA + ), kind=kind, temporary=temporary, materialized=materialized, @@ -1422,7 +1442,9 @@ class Parser(metaclass=_Parser): elif create_token.token_type == TokenType.INDEX: this = self._parse_index(index=self._parse_id_var()) elif create_token.token_type in self.DB_CREATABLES: - table_parts = self._parse_table_parts(schema=True) + table_parts = self._parse_table_parts( + schema=True, is_db_reference=create_token.token_type == TokenType.SCHEMA + ) # exp.Properties.Location.POST_NAME self._match(TokenType.COMMA) @@ -2499,11 +2521,11 @@ class Parser(metaclass=_Parser): elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"): text = "ALL ROWS PER MATCH" if self._match_text_seq("SHOW", "EMPTY", "MATCHES"): - text += f" SHOW EMPTY MATCHES" + text += " SHOW EMPTY MATCHES" elif self._match_text_seq("OMIT", "EMPTY", "MATCHES"): - text += f" OMIT EMPTY MATCHES" + text += " OMIT EMPTY MATCHES" elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"): - text += f" WITH UNMATCHED ROWS" + text += " WITH UNMATCHED ROWS" rows = exp.var(text) else: rows = None @@ -2511,9 +2533,9 @@ class Parser(metaclass=_Parser): if self._match_text_seq("AFTER", "MATCH", "SKIP"): text = "AFTER MATCH SKIP" if self._match_text_seq("PAST", "LAST", "ROW"): - text += f" PAST LAST ROW" + text += " PAST LAST ROW" elif self._match_text_seq("TO", "NEXT", "ROW"): - text += f" TO NEXT ROW" + text += " TO NEXT ROW" elif self._match_text_seq("TO", "FIRST"): text += f" TO FIRST {self._advance_any().text}" # type: ignore elif self._match_text_seq("TO", "LAST"): @@ -2772,7 +2794,7 @@ class Parser(metaclass=_Parser): or self._parse_placeholder() ) - def _parse_table_parts(self, schema: bool = False) -> exp.Table: + def _parse_table_parts(self, schema: bool = False, is_db_reference: bool = False) -> exp.Table: catalog = None db = None table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema) @@ -2788,8 +2810,15 @@ class Parser(metaclass=_Parser): db = table table = self._parse_table_part(schema=schema) or "" - if not table: + if is_db_reference: + catalog = db + db = table + table = None + + if not table and not is_db_reference: self.raise_error(f"Expected table name but got {self._curr}") + if not db and is_db_reference: + self.raise_error(f"Expected database name but got {self._curr}") return self.expression( exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots() @@ -2801,6 +2830,7 @@ class Parser(metaclass=_Parser): joins: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None, parse_bracket: bool = False, + is_db_reference: bool = False, ) -> t.Optional[exp.Expression]: lateral = self._parse_lateral() if lateral: @@ -2823,7 +2853,11 @@ class Parser(metaclass=_Parser): bracket = parse_bracket and self._parse_bracket(None) bracket = self.expression(exp.Table, this=bracket) if bracket else None this = t.cast( - exp.Expression, bracket or self._parse_bracket(self._parse_table_parts(schema=schema)) + exp.Expression, + bracket + or self._parse_bracket( + self._parse_table_parts(schema=schema, is_db_reference=is_db_reference) + ), ) if schema: @@ -3650,7 +3684,6 @@ class Parser(metaclass=_Parser): identifier = allow_identifiers and self._parse_id_var( any_token=False, tokens=(TokenType.VAR,) ) - if identifier: tokens = self.dialect.tokenize(identifier.name) @@ -3818,12 +3851,14 @@ class Parser(metaclass=_Parser): return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) def _parse_column(self) -> t.Optional[exp.Expression]: + this = self._parse_column_reference() + return self._parse_column_ops(this) if this else self._parse_bracket(this) + + def _parse_column_reference(self) -> t.Optional[exp.Expression]: this = self._parse_field() if isinstance(this, exp.Identifier): this = self.expression(exp.Column, this=this) - elif not this: - return self._parse_bracket(this) - return self._parse_column_ops(this) + return this def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: this = self._parse_bracket(this) @@ -3837,13 +3872,7 @@ class Parser(metaclass=_Parser): if not field: self.raise_error("Expected type") elif op and self._curr: - self._advance() - value = self._prev.text - field = ( - exp.Literal.number(value) - if self._prev.token_type == TokenType.NUMBER - else exp.Literal.string(value) - ) + field = self._parse_column_reference() else: field = self._parse_field(anonymous_func=True, any_token=True) @@ -4375,7 +4404,10 @@ class Parser(metaclass=_Parser): options[kind] = action return self.expression( - exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore + exp.ForeignKey, + expressions=expressions, + reference=reference, + **options, # type: ignore ) def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: @@ -4692,10 +4724,12 @@ class Parser(metaclass=_Parser): return None @t.overload - def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... + def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: + ... @t.overload - def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ... + def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: + ... def _parse_json_object(self, agg=False): star = self._parse_star() @@ -4937,6 +4971,13 @@ class Parser(metaclass=_Parser): # (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html) # and Snowflake chose to do the same for familiarity # https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes + if isinstance(this, exp.AggFunc): + ignore_respect = this.find(exp.IgnoreNulls, exp.RespectNulls) + + if ignore_respect and ignore_respect is not this: + ignore_respect.replace(ignore_respect.this) + this = self.expression(ignore_respect.__class__, this=this) + this = self._parse_respect_or_ignore_nulls(this) # bigquery select from window x AS (partition by ...) @@ -5732,12 +5773,14 @@ class Parser(metaclass=_Parser): return True @t.overload - def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ... + def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: + ... @t.overload def _replace_columns_with_dots( self, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: ... + ) -> t.Optional[exp.Expression]: + ... def _replace_columns_with_dots(self, this): if isinstance(this, exp.Dot): diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 8a363d2..87a4924 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -125,6 +125,7 @@ class TokenType(AutoName): NCHAR = auto() VARCHAR = auto() NVARCHAR = auto() + BPCHAR = auto() TEXT = auto() MEDIUMTEXT = auto() LONGTEXT = auto() @@ -801,6 +802,7 @@ class Tokenizer(metaclass=_Tokenizer): "VARCHAR2": TokenType.VARCHAR, "NVARCHAR": TokenType.NVARCHAR, "NVARCHAR2": TokenType.NVARCHAR, + "BPCHAR": TokenType.BPCHAR, "STR": TokenType.TEXT, "STRING": TokenType.TEXT, "TEXT": TokenType.TEXT, @@ -1141,7 +1143,7 @@ class Tokenizer(metaclass=_Tokenizer): self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) self._advance(comment_end_size - 1) else: - while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK: + while not self._end and self.WHITE_SPACE.get(self._peek) is not TokenType.BREAK: self._advance(alnum=True) self._comments.append(self._text[comment_start_size:]) @@ -1259,7 +1261,7 @@ class Tokenizer(metaclass=_Tokenizer): if base: try: int(text, base) - except: + except Exception: raise TokenError( f"Numeric string contains invalid characters from {self._line}:{self._start}" ) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 0da65b5..f13569f 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -485,8 +485,8 @@ def preprocess( expression_type = type(expression) expression = transforms[0](expression) - for t in transforms[1:]: - expression = t(expression) + for transform in transforms[1:]: + expression = transform(expression) _sql_handler = getattr(self, expression.key + "_sql", None) if _sql_handler: |