author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-23 05:06:14 +0000
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-01-23 05:06:14 +0000
commit | 38e6461a8afbd7cb83709ddb998f03d40ba87755 (patch)
tree | 64b68a893a3b946111b9cab69503f83ca233c335 /sqlglot
parent | Releasing debian version 20.4.0-1. (diff)
download | sqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.tar.xz, sqlglot-38e6461a8afbd7cb83709ddb998f03d40ba87755.zip
Merging upstream version 20.9.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
36 files changed, 1843 insertions, 559 deletions
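One user-visible effect of these changes, sketched as a minimal example (assumes this upstream version is installed; the table and column names are illustrative):

    import sqlglot

    # LAST_DAY is now parsed into a dedicated exp.LastDay expression rather
    # than an anonymous function, so each target dialect can rewrite it. Per
    # the transforms added in this diff, Presto should render it as
    # LAST_DAY_OF_MONTH and Postgres (no_last_day_sql) as date arithmetic.
    print(sqlglot.transpile("SELECT LAST_DAY(d) FROM t", read="spark", write="presto"))
    print(sqlglot.transpile("SELECT LAST_DAY(d) FROM t", read="spark", write="postgres"))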
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 6658287..141a302 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -574,13 +574,13 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: return Column.invoke_expression_over_column( - col, expression.DateAdd, expression=days, unit=expression.Var(this="day") + col, expression.DateAdd, expression=days, unit=expression.Var(this="DAY") ) def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: return Column.invoke_expression_over_column( - col, expression.DateSub, expression=days, unit=expression.Var(this="day") + col, expression.DateSub, expression=days, unit=expression.Var(this="DAY") ) @@ -635,7 +635,7 @@ def next_day(col: ColumnOrName, dayOfWeek: str) -> Column: def last_day(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "LAST_DAY") + return Column.invoke_expression_over_column(col, expression.LastDay) def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 7a573e7..0151e6c 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -16,20 +16,22 @@ from sqlglot.dialects.dialect import ( format_time_lambda, if_sql, inline_array_sql, - json_keyvalue_comma_sql, max_or_greatest, min_or_least, no_ilike_sql, parse_date_delta_with_interval, + path_to_jsonpath, regexp_replace_sql, rename_func, timestrtotime_sql, ts_or_ds_add_cast, - ts_or_ds_to_date_sql, ) from sqlglot.helper import seq_get, split_num_words from sqlglot.tokens import TokenType +if t.TYPE_CHECKING: + from typing_extensions import Literal + logger = logging.getLogger("sqlglot") @@ -206,12 +208,17 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s return f"TIMESTAMP_MILLIS({timestamp})" if scale == exp.UnixToTime.MICROS: return f"TIMESTAMP_MICROS({timestamp})" - if scale == exp.UnixToTime.NANOS: - # We need to cast to INT64 because that's what BQ expects - return f"TIMESTAMP_MICROS(CAST({timestamp} / 1000 AS INT64))" - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return f"TIMESTAMP_SECONDS(CAST({timestamp} / POW(10, {scale}) AS INT64))" + + +def _parse_time(args: t.List) -> exp.Func: + if len(args) == 1: + return exp.TsOrDsToTime(this=args[0]) + if len(args) == 3: + return exp.TimeFromParts.from_arg_list(args) + + return exp.Anonymous(this="TIME", expressions=args) class BigQuery(Dialect): @@ -329,7 +336,13 @@ class BigQuery(Dialect): "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd), "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub), "DIV": binary_from_function(exp.IntDiv), + "FORMAT_DATE": lambda args: exp.TimeToStr( + this=exp.TsOrDsToDate(this=seq_get(args, 1)), format=seq_get(args, 0) + ), "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list, + "JSON_EXTRACT_SCALAR": lambda args: exp.JSONExtractScalar( + this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$") + ), "MD5": exp.MD5Digest.from_arg_list, "TO_HEX": _parse_to_hex, "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")( @@ -351,6 +364,7 @@ class BigQuery(Dialect): this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string(","), ), + "TIME": _parse_time, "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd), "TIME_SUB": 
parse_date_delta_with_interval(exp.TimeSub), "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd), @@ -361,9 +375,7 @@ class BigQuery(Dialect): "TIMESTAMP_MILLIS": lambda args: exp.UnixToTime( this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS ), - "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime( - this=seq_get(args, 0), scale=exp.UnixToTime.SECONDS - ), + "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(this=seq_get(args, 0)), "TO_JSON_STRING": exp.JSONFormat.from_arg_list, } @@ -460,7 +472,15 @@ class BigQuery(Dialect): return table - def _parse_json_object(self) -> exp.JSONObject: + @t.overload + def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject: + ... + + @t.overload + def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg: + ... + + def _parse_json_object(self, agg=False): json_object = super()._parse_json_object() array_kv_pair = seq_get(json_object.expressions, 0) @@ -513,6 +533,10 @@ class BigQuery(Dialect): UNNEST_WITH_ORDINALITY = False COLLATE_IS_FUNC = True LIMIT_ONLY_LITERALS = True + SUPPORTS_TABLE_ALIAS_COLUMNS = False + UNPIVOT_ALIASES_ARE_IDENTIFIERS = False + JSON_KEY_VALUE_PAIR_SEP = "," + NULL_ORDERING_SUPPORTED = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -525,6 +549,7 @@ class BigQuery(Dialect): exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}" if e.args.get("default") else f"COLLATE {self.sql(e, 'this')}", + exp.CountIf: rename_func("COUNTIF"), exp.Create: _create_sql, exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), exp.DateAdd: date_add_interval_sql("DATE", "ADD"), @@ -536,13 +561,13 @@ class BigQuery(Dialect): exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"), exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")), exp.GenerateSeries: rename_func("GENERATE_ARRAY"), + exp.GetPath: path_to_jsonpath(), exp.GroupConcat: rename_func("STRING_AGG"), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql(false_value="NULL"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), exp.JSONFormat: rename_func("TO_JSON_STRING"), - exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)), exp.MD5Digest: rename_func("MD5"), @@ -578,16 +603,17 @@ class BigQuery(Dialect): "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone") ), exp.TimeAdd: date_add_interval_sql("TIME", "ADD"), + exp.TimeFromParts: rename_func("TIME"), exp.TimeSub: date_add_interval_sql("TIME", "SUB"), exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"), exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, - exp.TimeToStr: lambda self, e: f"FORMAT_DATE({self.format_time(e)}, {self.sql(e, 'this')})", exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression), exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsDiff: _ts_or_ds_diff_sql, - exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"), + exp.TsOrDsToTime: rename_func("TIME"), exp.Unhex: rename_func("FROM_HEX"), + exp.UnixDate: rename_func("UNIX_DATE"), exp.UnixToTime: _unix_to_time_sql, exp.Values: _derived_table_values_to_unnest, exp.VariancePop: rename_func("VAR_POP"), @@ -724,6 +750,26 @@ class BigQuery(Dialect): "within", } + def timetostr_sql(self, expression: exp.TimeToStr) -> str: + if isinstance(expression.this, exp.TsOrDsToDate): + this: exp.Expression = expression.this + else: + this = expression + + return f"FORMAT_DATE({self.format_time(expression)}, {self.sql(this, 
'this')})" + + def struct_sql(self, expression: exp.Struct) -> str: + args = [] + for expr in expression.expressions: + if isinstance(expr, self.KEY_VALUE_DEFINITIONS): + arg = f"{self.sql(expr, 'expression')} AS {expr.this.name}" + else: + arg = self.sql(expr) + + args.append(arg) + + return self.func("STRUCT", *args) + def eq_sql(self, expression: exp.EQ) -> str: # Operands of = cannot be NULL in BigQuery if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null): @@ -760,7 +806,20 @@ class BigQuery(Dialect): return inline_array_sql(self, expression) def bracket_sql(self, expression: exp.Bracket) -> str: + this = self.sql(expression, "this") expressions = expression.expressions + + if len(expressions) == 1: + arg = expressions[0] + if arg.type is None: + from sqlglot.optimizer.annotate_types import annotate_types + + arg = annotate_types(arg) + + if arg.type and arg.type.this in exp.DataType.TEXT_TYPES: + # BQ doesn't support bracket syntax with string values + return f"{this}.{arg.name}" + expressions_sql = ", ".join(self.sql(e) for e in expressions) offset = expression.args.get("offset") @@ -768,13 +827,13 @@ class BigQuery(Dialect): expressions_sql = f"OFFSET({expressions_sql})" elif offset == 1: expressions_sql = f"ORDINAL({expressions_sql})" - else: + elif offset is not None: self.unsupported(f"Unsupported array offset: {offset}") if expression.args.get("safe"): expressions_sql = f"SAFE_{expressions_sql}" - return f"{self.sql(expression, 'this')}[{expressions_sql}]" + return f"{this}[{expressions_sql}]" def transaction_sql(self, *_) -> str: return "BEGIN TRANSACTION" diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 870f402..f2e4fe1 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, arg_max_or_min_no_count, + date_delta_sql, inline_array_sql, no_pivot_sql, rename_func, @@ -22,16 +23,25 @@ def _lower_func(sql: str) -> str: return sql[:index].lower() + sql[index:] -def _quantile_sql(self, e): +def _quantile_sql(self: ClickHouse.Generator, e: exp.Quantile) -> str: quantile = e.args["quantile"] args = f"({self.sql(e, 'this')})" + if isinstance(quantile, exp.Array): func = self.func("quantiles", *quantile) else: func = self.func("quantile", quantile) + return func + args +def _parse_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc: + if len(args) == 1: + return exp.CountIf(this=seq_get(args, 0)) + + return exp.CombinedAggFunc(this="countIf", expressions=args, parts=("count", "If")) + + class ClickHouse(Dialect): NORMALIZE_FUNCTIONS: bool | str = False NULL_ORDERING = "nulls_are_last" @@ -53,6 +63,7 @@ class ClickHouse(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ATTACH": TokenType.COMMAND, + "DATE32": TokenType.DATE32, "DATETIME64": TokenType.DATETIME64, "DICTIONARY": TokenType.DICTIONARY, "ENUM": TokenType.ENUM, @@ -75,6 +86,8 @@ class ClickHouse(Dialect): "UINT32": TokenType.UINT, "UINT64": TokenType.UBIGINT, "UINT8": TokenType.UTINYINT, + "IPV4": TokenType.IPV4, + "IPV6": TokenType.IPV6, } SINGLE_TOKENS = { @@ -91,6 +104,8 @@ class ClickHouse(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "ANY": exp.AnyValue.from_arg_list, + "ARRAYSUM": exp.ArraySum.from_arg_list, + "COUNTIF": _parse_count_if, "DATE_ADD": lambda args: exp.DateAdd( this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), @@ -110,6 +125,138 @@ class 
ClickHouse(Dialect): "XOR": lambda args: exp.Xor(expressions=args), } + AGG_FUNCTIONS = { + "count", + "min", + "max", + "sum", + "avg", + "any", + "stddevPop", + "stddevSamp", + "varPop", + "varSamp", + "corr", + "covarPop", + "covarSamp", + "entropy", + "exponentialMovingAverage", + "intervalLengthSum", + "kolmogorovSmirnovTest", + "mannWhitneyUTest", + "median", + "rankCorr", + "sumKahan", + "studentTTest", + "welchTTest", + "anyHeavy", + "anyLast", + "boundingRatio", + "first_value", + "last_value", + "argMin", + "argMax", + "avgWeighted", + "topK", + "topKWeighted", + "deltaSum", + "deltaSumTimestamp", + "groupArray", + "groupArrayLast", + "groupUniqArray", + "groupArrayInsertAt", + "groupArrayMovingAvg", + "groupArrayMovingSum", + "groupArraySample", + "groupBitAnd", + "groupBitOr", + "groupBitXor", + "groupBitmap", + "groupBitmapAnd", + "groupBitmapOr", + "groupBitmapXor", + "sumWithOverflow", + "sumMap", + "minMap", + "maxMap", + "skewSamp", + "skewPop", + "kurtSamp", + "kurtPop", + "uniq", + "uniqExact", + "uniqCombined", + "uniqCombined64", + "uniqHLL12", + "uniqTheta", + "quantile", + "quantiles", + "quantileExact", + "quantilesExact", + "quantileExactLow", + "quantilesExactLow", + "quantileExactHigh", + "quantilesExactHigh", + "quantileExactWeighted", + "quantilesExactWeighted", + "quantileTiming", + "quantilesTiming", + "quantileTimingWeighted", + "quantilesTimingWeighted", + "quantileDeterministic", + "quantilesDeterministic", + "quantileTDigest", + "quantilesTDigest", + "quantileTDigestWeighted", + "quantilesTDigestWeighted", + "quantileBFloat16", + "quantilesBFloat16", + "quantileBFloat16Weighted", + "quantilesBFloat16Weighted", + "simpleLinearRegression", + "stochasticLinearRegression", + "stochasticLogisticRegression", + "categoricalInformationValue", + "contingency", + "cramersV", + "cramersVBiasCorrected", + "theilsU", + "maxIntersections", + "maxIntersectionsPosition", + "meanZTest", + "quantileInterpolatedWeighted", + "quantilesInterpolatedWeighted", + "quantileGK", + "quantilesGK", + "sparkBar", + "sumCount", + "largestTriangleThreeBuckets", + } + + AGG_FUNCTIONS_SUFFIXES = [ + "If", + "Array", + "ArrayIf", + "Map", + "SimpleState", + "State", + "Merge", + "MergeState", + "ForEach", + "Distinct", + "OrDefault", + "OrNull", + "Resample", + "ArgMin", + "ArgMax", + ] + + AGG_FUNC_MAPPING = ( + lambda functions, suffixes: { + f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions + } + )(AGG_FUNCTIONS, AGG_FUNCTIONS_SUFFIXES) + FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"} FUNCTION_PARSERS = { @@ -272,9 +419,18 @@ class ClickHouse(Dialect): ) if isinstance(func, exp.Anonymous): + parts = self.AGG_FUNC_MAPPING.get(func.this) params = self._parse_func_params(func) if params: + if parts and parts[1]: + return self.expression( + exp.CombinedParameterizedAgg, + this=func.this, + expressions=func.expressions, + params=params, + parts=parts, + ) return self.expression( exp.ParameterizedAgg, this=func.this, @@ -282,6 +438,20 @@ class ClickHouse(Dialect): params=params, ) + if parts: + if parts[1]: + return self.expression( + exp.CombinedAggFunc, + this=func.this, + expressions=func.expressions, + parts=parts, + ) + return self.expression( + exp.AnonymousAggFunc, + this=func.this, + expressions=func.expressions, + ) + return func def _parse_func_params( @@ -329,6 +499,9 @@ class ClickHouse(Dialect): STRUCT_DELIMITER = ("(", ")") NVL2_SUPPORTED = False TABLESAMPLE_REQUIRES_PARENS = False + TABLESAMPLE_SIZE_IS_ROWS = False + 
TABLESAMPLE_KEYWORDS = "SAMPLE" + LAST_DAY_SUPPORTS_DATE_PART = False STRING_TYPE_MAPPING = { exp.DataType.Type.CHAR: "String", @@ -348,6 +521,7 @@ class ClickHouse(Dialect): **STRING_TYPE_MAPPING, exp.DataType.Type.ARRAY: "Array", exp.DataType.Type.BIGINT: "Int64", + exp.DataType.Type.DATE32: "Date32", exp.DataType.Type.DATETIME64: "DateTime64", exp.DataType.Type.DOUBLE: "Float64", exp.DataType.Type.ENUM: "Enum", @@ -372,24 +546,23 @@ class ClickHouse(Dialect): exp.DataType.Type.UINT256: "UInt256", exp.DataType.Type.USMALLINT: "UInt16", exp.DataType.Type.UTINYINT: "UInt8", + exp.DataType.Type.IPV4: "IPv4", + exp.DataType.Type.IPV6: "IPv6", } TRANSFORMS = { **generator.Generator.TRANSFORMS, - exp.Select: transforms.preprocess([transforms.eliminate_qualify]), exp.AnyValue: rename_func("any"), exp.ApproxDistinct: rename_func("uniq"), + exp.ArraySum: rename_func("arraySum"), exp.ArgMax: arg_max_or_min_no_count("argMax"), exp.ArgMin: arg_max_or_min_no_count("argMin"), exp.Array: inline_array_sql, exp.CastToStrType: rename_func("CAST"), + exp.CountIf: rename_func("countIf"), exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"), - exp.DateAdd: lambda self, e: self.func( - "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this - ), - exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this - ), + exp.DateAdd: date_delta_sql("DATE_ADD"), + exp.DateDiff: date_delta_sql("DATE_DIFF"), exp.Explode: rename_func("arrayJoin"), exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.IsNan: rename_func("isNaN"), @@ -400,6 +573,7 @@ class ClickHouse(Dialect): exp.Quantile: _quantile_sql, exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", exp.Rand: rename_func("randCanonical"), + exp.Select: transforms.preprocess([transforms.eliminate_qualify]), exp.StartsWith: rename_func("startsWith"), exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), @@ -485,10 +659,19 @@ class ClickHouse(Dialect): else "", ] - def parameterizedagg_sql(self, expression: exp.Anonymous) -> str: + def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str: params = self.expressions(expression, key="params", flat=True) return self.func(expression.name, *expression.expressions) + f"({params})" + def anonymousaggfunc_sql(self, expression: exp.AnonymousAggFunc) -> str: + return self.func(expression.name, *expression.expressions) + + def combinedaggfunc_sql(self, expression: exp.CombinedAggFunc) -> str: + return self.anonymousaggfunc_sql(expression) + + def combinedparameterizedagg_sql(self, expression: exp.CombinedParameterizedAgg) -> str: + return self.parameterizedagg_sql(expression) + def placeholder_sql(self, expression: exp.Placeholder) -> str: return f"{{{expression.name}: {self.sql(expression, 'kind')}}}" diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 1c10a8b..8e55b6a 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -30,6 +30,8 @@ class Databricks(Spark): } class Generator(Spark.Generator): + TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" + TRANSFORMS = { **Spark.Generator.TRANSFORMS, exp.DateAdd: date_delta_sql("DATEADD"), diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index b7eef45..7664c40 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -153,6 
+153,9 @@ class Dialect(metaclass=_Dialect): ALIAS_POST_TABLESAMPLE = False """Determines whether or not the table alias comes after tablesample.""" + TABLESAMPLE_SIZE_IS_PERCENT = False + """Determines whether or not a size in the table sample clause represents percentage.""" + NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE """Specifies the strategy according to which identifiers should be normalized.""" @@ -220,6 +223,24 @@ class Dialect(metaclass=_Dialect): For example, such columns may be excluded from `SELECT *` queries. """ + PREFER_CTE_ALIAS_COLUMN = False + """ + Some dialects, such as Snowflake, allow you to reference a CTE column alias in the + HAVING clause of the CTE. This flag will cause the CTE alias columns to override + any projection aliases in the subquery. + + For example, + WITH y(c) AS ( + SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 + ) SELECT c FROM y; + + will be rewritten as + + WITH y(c) AS ( + SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 + ) SELECT c FROM y; + """ + # --- Autofilled --- tokenizer_class = Tokenizer @@ -287,7 +308,13 @@ class Dialect(metaclass=_Dialect): result = cls.get(dialect_name.strip()) if not result: - raise ValueError(f"Unknown dialect '{dialect_name}'.") + from difflib import get_close_matches + + similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" + if similar: + similar = f" Did you mean {similar}?" + + raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") return result(**kwargs) @@ -506,7 +533,7 @@ def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: n = self.sql(expression, "this") d = self.sql(expression, "expression") - return f"IF({d} <> 0, {n} / {d}, NULL)" + return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: @@ -695,7 +722,7 @@ def date_add_interval_sql( def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: return self.func( - "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this + "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this ) @@ -801,22 +828,6 @@ def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: return self.func("STRPTIME", expression.this, self.format_time(expression)) -def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: - def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: - _dialect = Dialect.get_or_raise(dialect) - time_format = self.format_time(expression) - if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): - return self.sql( - exp.cast( - exp.StrToTime(this=expression.this, format=expression.args["format"]), - "date", - ) - ) - return self.sql(exp.cast(expression.this, "date")) - - return _ts_or_ds_to_date_sql - - def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) @@ -894,11 +905,6 @@ def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" -# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon -def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str: - return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}" - - def is_parse_json(expression: exp.Expression) -> 
bool: return isinstance(expression, exp.ParseJSON) or ( isinstance(expression, exp.Cast) and expression.is_type("json") @@ -946,7 +952,70 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE expression = ts_or_ds_add_cast(expression) return self.func( - name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this + name, + exp.var(expression.text("unit").upper() or "DAY"), + expression.expression, + expression.this, ) return _delta_sql + + +def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath: + from sqlglot.optimizer.simplify import simplify + + # Makes sure the path will be evaluated correctly at runtime to include the path root. + # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`. + path = expression.expression + path = exp.func( + "if", + exp.func("startswith", path, "'['"), + exp.func("concat", "'$'", path), + exp.func("concat", "'$.'", path), + ) + + expression.expression.replace(simplify(path)) + return expression + + +def path_to_jsonpath( + name: str = "JSON_EXTRACT", +) -> t.Callable[[Generator, exp.GetPath], str]: + def _transform(self: Generator, expression: exp.GetPath) -> str: + return rename_func(name)(self, prepend_dollar_to_path(expression)) + + return _transform + + +def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: + trunc_curr_date = exp.func("date_trunc", "month", expression.this) + plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") + minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") + + return self.sql(exp.cast(minus_one_day, "date")) + + +def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: + """Remove table refs from columns in when statements.""" + alias = expression.this.args.get("alias") + + normalize = ( + lambda identifier: self.dialect.normalize_identifier(identifier).name + if identifier + else None + ) + + targets = {normalize(expression.this.this)} + + if alias: + targets.add(normalize(alias.this)) + + for when in expression.expressions: + when.transform( + lambda node: exp.column(node.this) + if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets + else node, + copy=False, + ) + + return self.merge_sql(expression) diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py index 11af17b..6e229b3 100644 --- a/sqlglot/dialects/doris.py +++ b/sqlglot/dialects/doris.py @@ -22,6 +22,7 @@ class Doris(MySQL): "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list, "DATE_TRUNC": parse_timestamp_trunc, "REGEXP": exp.RegexpLike.from_arg_list, + "TO_DATE": exp.TsOrDsToDate.from_arg_list, } class Generator(MySQL.Generator): @@ -34,21 +35,26 @@ class Doris(MySQL): exp.DataType.Type.TIMESTAMPTZ: "DATETIME", } + LAST_DAY_SUPPORTS_DATE_PART = False + TIMESTAMP_FUNC_TYPES = set() TRANSFORMS = { **MySQL.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, + exp.ArgMax: rename_func("MAX_BY"), + exp.ArgMin: rename_func("MIN_BY"), exp.ArrayAgg: rename_func("COLLECT_LIST"), + exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), exp.CurrentTimestamp: lambda *_: "NOW()", exp.DateTrunc: lambda self, e: self.func( "DATE_TRUNC", e.this, "'" + e.text("unit") + "'" ), exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, + exp.Map: rename_func("ARRAY_MAP"), exp.RegexpLike: rename_func("REGEXP"), exp.RegexpSplit: rename_func("SPLIT_BY_STRING"), - exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), exp.StrToUnix: lambda self, e: 
f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Split: rename_func("SPLIT_BY_STRING"), exp.TimeStrToDate: rename_func("TO_DATE"), @@ -63,5 +69,4 @@ class Doris(MySQL): "FROM_UNIXTIME", e.this, time_format("doris")(self, e) ), exp.UnixToTime: rename_func("FROM_UNIXTIME"), - exp.Map: rename_func("ARRAY_MAP"), } diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index c9b31a0..6bca9e7 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -12,7 +12,6 @@ from sqlglot.dialects.dialect import ( rename_func, str_position_sql, timestrtotime_sql, - ts_or_ds_to_date_sql, ) @@ -99,6 +98,7 @@ class Drill(Dialect): TABLE_HINTS = False QUERY_HINTS = False NVL2_SUPPORTED = False + LAST_DAY_SUPPORTS_DATE_PART = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -150,7 +150,6 @@ class Drill(Dialect): exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.TryCast: no_trycast_sql, exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})", - exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", } diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index cd9d529..2343b35 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -22,15 +22,15 @@ from sqlglot.dialects.dialect import ( no_safe_divide_sql, no_timestamp_sql, pivot_column_names, + prepend_dollar_to_path, regexp_extract_sql, rename_func, str_position_sql, str_to_time_sql, timestamptrunc_sql, timestrtotime_sql, - ts_or_ds_to_date_sql, ) -from sqlglot.helper import seq_get +from sqlglot.helper import flatten, seq_get from sqlglot.tokens import TokenType @@ -141,11 +141,25 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str return f"EPOCH_MS({timestamp})" if scale == exp.UnixToTime.MICROS: return f"MAKE_TIMESTAMP({timestamp})" - if scale == exp.UnixToTime.NANOS: - return f"TO_TIMESTAMP({timestamp} / 1000000000)" - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return f"TO_TIMESTAMP({timestamp} / POW(10, {scale}))" + + +def _rename_unless_within_group( + a: str, b: str +) -> t.Callable[[DuckDB.Generator, exp.Expression], str]: + return ( + lambda self, expression: self.func(a, *flatten(expression.args.values())) + if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup) + else self.func(b, *flatten(expression.args.values())) + ) + + +def _parse_struct_pack(args: t.List) -> exp.Struct: + args_with_columns_as_identifiers = [ + exp.PropertyEQ(this=arg.this.this, expression=arg.expression) for arg in args + ] + return exp.Struct.from_arg_list(args_with_columns_as_identifiers) class DuckDB(Dialect): @@ -183,6 +197,11 @@ class DuckDB(Dialect): "TIMESTAMP_US": TokenType.TIMESTAMP, } + SINGLE_TOKENS = { + **tokens.Tokenizer.SINGLE_TOKENS, + "$": TokenType.PARAMETER, + } + class Parser(parser.Parser): BITWISE = { **parser.Parser.BITWISE, @@ -209,10 +228,12 @@ class DuckDB(Dialect): "EPOCH_MS": lambda args: exp.UnixToTime( this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS ), + "JSON": exp.ParseJSON.from_arg_list, "LIST_HAS": exp.ArrayContains.from_arg_list, "LIST_REVERSE_SORT": _sort_array_reverse, "LIST_SORT": exp.SortArray.from_arg_list, "LIST_VALUE": exp.Array.from_arg_list, + "MAKE_TIME": exp.TimeFromParts.from_arg_list, "MAKE_TIMESTAMP": _parse_make_timestamp, 
"MEDIAN": lambda args: exp.PercentileCont( this=seq_get(args, 0), expression=exp.Literal.number(0.5) @@ -234,7 +255,7 @@ class DuckDB(Dialect): "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "STRING_TO_ARRAY": exp.Split.from_arg_list, "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"), - "STRUCT_PACK": exp.Struct.from_arg_list, + "STRUCT_PACK": _parse_struct_pack, "STR_SPLIT": exp.Split.from_arg_list, "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "TO_TIMESTAMP": exp.UnixToTime.from_arg_list, @@ -250,6 +271,13 @@ class DuckDB(Dialect): TokenType.ANTI, } + PLACEHOLDER_PARSERS = { + **parser.Parser.PLACEHOLDER_PARSERS, + TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text) + if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) + else None, + } + def _parse_types( self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: @@ -268,7 +296,7 @@ class DuckDB(Dialect): return this - def _parse_struct_types(self) -> t.Optional[exp.Expression]: + def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]: return self._parse_field_def() def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: @@ -285,6 +313,10 @@ class DuckDB(Dialect): RENAME_TABLE_WITH_DB = False NVL2_SUPPORTED = False SEMI_ANTI_JOIN_WITH_SIDE = False + TABLESAMPLE_KEYWORDS = "USING SAMPLE" + TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" + LAST_DAY_SUPPORTS_DATE_PART = False + JSON_KEY_VALUE_PAIR_SEP = "," TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -311,7 +343,7 @@ class DuckDB(Dialect): exp.DateFromParts: rename_func("MAKE_DATE"), exp.DateSub: _date_delta_sql, exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", f"'{e.args.get('unit') or 'day'}'", e.expression, e.this + "DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this ), exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)", @@ -322,11 +354,11 @@ class DuckDB(Dialect): exp.IntDiv: lambda self, e: self.binary(e, "//"), exp.IsInf: rename_func("ISINF"), exp.IsNan: rename_func("ISNAN"), + exp.JSONBExtract: arrow_json_extract_sql, + exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONFormat: _json_format_sql, - exp.JSONBExtract: arrow_json_extract_sql, - exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), exp.MonthsBetween: lambda self, e: self.func( @@ -336,8 +368,8 @@ class DuckDB(Dialect): exp.cast(e.this, "timestamp", copy=True), ), exp.ParseJSON: rename_func("JSON"), - exp.PercentileCont: rename_func("QUANTILE_CONT"), - exp.PercentileDisc: rename_func("QUANTILE_DISC"), + exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"), + exp.PercentileDisc: _rename_unless_within_group("PERCENTILE_DISC", "QUANTILE_DISC"), # DuckDB doesn't allow qualified columns inside of PIVOT expressions. 
# See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62 exp.Pivot: transforms.preprocess([transforms.unqualify_columns]), @@ -362,7 +394,9 @@ class DuckDB(Dialect): exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", exp.Struct: _struct_sql, exp.Timestamp: no_timestamp_sql, - exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"), + exp.TimestampDiff: lambda self, e: self.func( + "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this + ), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, @@ -373,11 +407,10 @@ class DuckDB(Dialect): exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsDiff: lambda self, e: self.func( "DATE_DIFF", - f"'{e.args.get('unit') or 'day'}'", + f"'{e.args.get('unit') or 'DAY'}'", exp.cast(e.expression, "TIMESTAMP"), exp.cast(e.this, "TIMESTAMP"), ), - exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"), exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: _unix_to_time_sql, exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", @@ -410,6 +443,49 @@ class DuckDB(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def timefromparts_sql(self, expression: exp.TimeFromParts) -> str: + nano = expression.args.get("nano") + if nano is not None: + expression.set( + "sec", expression.args["sec"] + nano.pop() / exp.Literal.number(1000000000.0) + ) + + return rename_func("MAKE_TIME")(self, expression) + + def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: + sec = expression.args["sec"] + + milli = expression.args.get("milli") + if milli is not None: + sec += milli.pop() / exp.Literal.number(1000.0) + + nano = expression.args.get("nano") + if nano is not None: + sec += nano.pop() / exp.Literal.number(1000000000.0) + + if milli or nano: + expression.set("sec", sec) + + return rename_func("MAKE_TIMESTAMP")(self, expression) + + def tablesample_sql( + self, + expression: exp.TableSample, + sep: str = " AS ", + tablesample_keyword: t.Optional[str] = None, + ) -> str: + if not isinstance(expression.parent, exp.Select): + # This sample clause only applies to a single source, not the entire resulting relation + tablesample_keyword = "TABLESAMPLE" + + return super().tablesample_sql( + expression, sep=sep, tablesample_keyword=tablesample_keyword + ) + + def getpath_sql(self, expression: exp.GetPath) -> str: + expression = prepend_dollar_to_path(expression) + return f"{self.sql(expression, 'this')} -> {self.sql(expression, 'expression')}" + def interval_sql(self, expression: exp.Interval) -> str: multiplier: t.Optional[int] = None unit = expression.text("unit").lower() @@ -420,11 +496,14 @@ class DuckDB(Dialect): multiplier = 90 if multiplier: - return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})" + return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('DAY')))})" return super().interval_sql(expression) - def tablesample_sql( - self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS " - ) -> str: - return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep) + def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: + if isinstance(expression.parent, exp.UserDefinedFunction): 
+ return self.sql(expression, "this") + return super().columndef_sql(expression, sep) + + def placeholder_sql(self, expression: exp.Placeholder) -> str: + return f"${expression.name}" if expression.name else "?" diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 65c85bb..dffa41e 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -418,13 +418,13 @@ class Hive(Dialect): class Generator(generator.Generator): LIMIT_FETCH = "LIMIT" TABLESAMPLE_WITH_METHOD = False - TABLESAMPLE_SIZE_IS_PERCENT = True JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False INDEX_ON = "ON TABLE" EXTRACT_ALLOWS_QUOTES = False NVL2_SUPPORTED = False + LAST_DAY_SUPPORTS_DATE_PART = False EXPRESSIONS_WITHOUT_NESTED_CTES = { exp.Insert, @@ -523,7 +523,6 @@ class Hive(Dialect): exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}", exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), - exp.LastDateOfMonth: rename_func("LAST_DAY"), exp.National: lambda self, e: self.national_sql(e, prefix=""), exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})", exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})", diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 5fe3d82..21a9657 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, isnull_to_is_null, - json_keyvalue_comma_sql, locate_to_strposition, max_or_greatest, min_or_least, @@ -21,6 +20,7 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, no_trycast_sql, parse_date_delta_with_interval, + path_to_jsonpath, rename_func, strposition_to_locate_sql, ) @@ -37,21 +37,21 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: expr = self.sql(expression, "this") - unit = expression.text("unit") + unit = expression.text("unit").upper() - if unit == "day": + if unit == "DAY": return f"DATE({expr})" - if unit == "week": + if unit == "WEEK": concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')" date_format = "%Y %u %w" - elif unit == "month": + elif unit == "MONTH": concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')" date_format = "%Y %c %e" - elif unit == "quarter": + elif unit == "QUARTER": concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')" date_format = "%Y %c %e" - elif unit == "year": + elif unit == "YEAR": concat = f"CONCAT(YEAR({expr}), ' 1 1')" date_format = "%Y %c %e" else: @@ -292,9 +292,15 @@ class MySQL(Dialect): "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), + "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)), "ISNULL": isnull_to_is_null, "LOCATE": locate_to_strposition, + "MAKETIME": exp.TimeFromParts.from_arg_list, + "MONTH": lambda args: 
exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "MONTHNAME": lambda args: exp.TimeToStr( this=exp.TsOrDsToDate(this=seq_get(args, 0)), format=exp.Literal.string("%B"), @@ -308,11 +314,6 @@ class MySQL(Dialect): ) + 1 ), - "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "WEEK": lambda args: exp.Week( this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1) ), @@ -441,6 +442,7 @@ class MySQL(Dialect): } LOG_DEFAULTS_TO_LN = True + STRING_ALIASES = True def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() @@ -620,13 +622,15 @@ class MySQL(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True - NULL_ORDERING_SUPPORTED = False + NULL_ORDERING_SUPPORTED = None JOIN_HINTS = False TABLE_HINTS = True DUPLICATE_KEY_UPDATE_WITH_SET = False QUERY_HINT_SEP = " " VALUES_AS_TABLE = False NVL2_SUPPORTED = False + LAST_DAY_SUPPORTS_DATE_PART = False + JSON_KEY_VALUE_PAIR_SEP = "," TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -642,15 +646,16 @@ class MySQL(Dialect): exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")), exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")), exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")), + exp.GetPath: path_to_jsonpath(), exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.ILike: no_ilike_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, - exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, exp.Month: _remove_ts_or_ds_to_date(), exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}", + exp.ParseJSON: lambda self, e: self.sql(e, "this"), exp.Pivot: no_pivot_sql, exp.Select: transforms.preprocess( [ @@ -665,6 +670,7 @@ class MySQL(Dialect): exp.StrToTime: _str_to_date_sql, exp.Stuff: rename_func("INSERT"), exp.TableSample: no_tablesample_sql, + exp.TimeFromParts: rename_func("MAKETIME"), exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"), exp.TimestampSub: date_add_interval_sql("DATE", "SUB"), exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 51dbd53..6ad3718 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -53,6 +53,7 @@ def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar: class Oracle(Dialect): ALIAS_POST_TABLESAMPLE = True LOCKING_READS_SUPPORTED = True + TABLESAMPLE_SIZE_IS_PERCENT = True # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE @@ -81,6 +82,7 @@ class Oracle(Dialect): "WW": "%W", # Week of year (1-53) "YY": "%y", # 15 "YYYY": "%Y", # 2015 + "FF6": "%f", # only 6 digits are supported in python formats } class Parser(parser.Parser): @@ -91,6 +93,8 @@ class Oracle(Dialect): **parser.Parser.FUNCTIONS, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TO_CHAR": to_char, + "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "oracle"), + 
"TO_DATE": format_time_lambda(exp.StrToDate, "oracle"), } FUNCTION_PARSERS: t.Dict[str, t.Callable] = { @@ -107,6 +111,11 @@ class Oracle(Dialect): "XMLTABLE": _parse_xml_table, } + QUERY_MODIFIER_PARSERS = { + **parser.Parser.QUERY_MODIFIER_PARSERS, + TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()), + } + TYPE_LITERAL_PARSERS = { exp.DataType.Type.DATE: lambda self, this, _: self.expression( exp.DateStrToDate, this=this @@ -153,8 +162,10 @@ class Oracle(Dialect): COLUMN_JOIN_MARKS_SUPPORTED = True DATA_TYPE_SPECIFIERS_ALLOWED = True ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False - LIMIT_FETCH = "FETCH" + TABLESAMPLE_KEYWORDS = "SAMPLE" + LAST_DAY_SUPPORTS_DATE_PART = False + SUPPORTS_SELECT_INTO = True TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -186,6 +197,7 @@ class Oracle(Dialect): ] ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), exp.Substring: rename_func("SUBSTR"), exp.Table: lambda self, e: self.table_sql(e, sep=" "), @@ -201,6 +213,10 @@ class Oracle(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str: + this = expression.this + return self.func("CURRENT_TIMESTAMP", this) if this else "CURRENT_TIMESTAMP" + def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" @@ -233,8 +249,10 @@ class Oracle(Dialect): "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "MINUS": TokenType.EXCEPT, "NVARCHAR2": TokenType.NVARCHAR, + "ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY, "SAMPLE": TokenType.TABLE_SAMPLE, "START": TokenType.BEGIN, + "SYSDATE": TokenType.CURRENT_TIMESTAMP, "TOP": TokenType.TOP, "VARCHAR2": TokenType.VARCHAR, } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index e274877..1ca0a78 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -13,11 +13,12 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, max_or_greatest, + merge_without_target_sql, min_or_least, + no_last_day_sql, no_map_from_entries_sql, no_paren_current_date_sql, no_pivot_sql, - no_tablesample_sql, no_trycast_sql, parse_timestamp_trunc, rename_func, @@ -27,7 +28,6 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, trim_sql, ts_or_ds_add_cast, - ts_or_ds_to_date_sql, ) from sqlglot.helper import seq_get from sqlglot.parser import binary_range_parser @@ -188,36 +188,6 @@ def _to_timestamp(args: t.List) -> exp.Expression: return format_time_lambda(exp.StrToTime, "postgres")(args) -def _merge_sql(self: Postgres.Generator, expression: exp.Merge) -> str: - def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression: - """Remove table refs from columns in when statements.""" - if isinstance(expression, exp.Merge): - alias = expression.this.args.get("alias") - - normalize = ( - lambda identifier: self.dialect.normalize_identifier(identifier).name - if identifier - else None - ) - - targets = {normalize(expression.this.this)} - - if alias: - targets.add(normalize(alias.this)) - - for when in expression.expressions: - when.transform( - lambda node: exp.column(node.this) - if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets - else node, - copy=False, - ) - - return expression - - return 
transforms.preprocess([_remove_target_from_merge])(self, expression) - - class Postgres(Dialect): INDEX_OFFSET = 1 TYPED_DIVISION = True @@ -316,6 +286,8 @@ class Postgres(Dialect): **parser.Parser.FUNCTIONS, "DATE_TRUNC": parse_timestamp_trunc, "GENERATE_SERIES": _generate_series, + "MAKE_TIME": exp.TimeFromParts.from_arg_list, + "MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list, "NOW": exp.CurrentTimestamp.from_arg_list, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), "TO_TIMESTAMP": _to_timestamp, @@ -387,12 +359,18 @@ class Postgres(Dialect): class Generator(generator.Generator): SINGLE_STRING_INTERVAL = True + RENAME_TABLE_WITH_DB = False LOCKING_READS_SUPPORTED = True JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False NVL2_SUPPORTED = False PARAMETER_TOKEN = "$" + TABLESAMPLE_SIZE_IS_ROWS = False + TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" + SUPPORTS_SELECT_INTO = True + # https://www.postgresql.org/docs/current/sql-createtable.html + SUPPORTS_UNLOGGED_TABLES = True TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -430,12 +408,13 @@ class Postgres(Dialect): exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"), exp.JSONBContains: lambda self, e: self.binary(e, "?"), + exp.LastDay: no_last_day_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), exp.Max: max_or_greatest, exp.MapFromEntries: no_map_from_entries_sql, exp.Min: min_or_least, - exp.Merge: _merge_sql, + exp.Merge: merge_without_target_sql, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.PercentileCont: transforms.preprocess( [transforms.add_within_group_for_percentiles] @@ -458,16 +437,16 @@ class Postgres(Dialect): exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.StructExtract: struct_extract_sql, exp.Substring: _substring_sql, + exp.TimeFromParts: rename_func("MAKE_TIME"), + exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", - exp.TableSample: no_tablesample_sql, exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, exp.TryCast: no_trycast_sql, exp.TsOrDsAdd: _date_add_sql("+"), exp.TsOrDsDiff: _date_diff_sql, - exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"), exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.VariancePop: rename_func("VAR_POP"), exp.Variance: rename_func("VAR_SAMP"), diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 360ab65..9b421e7 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import ( no_pivot_sql, no_safe_divide_sql, no_timestamp_sql, + path_to_jsonpath, regexp_extract_sql, rename_func, right_to_substring_sql, @@ -99,14 +100,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: expression = ts_or_ds_add_cast(expression) - unit = exp.Literal.string(expression.text("unit") or "day") + unit = exp.Literal.string(expression.text("unit") or "DAY") return self.func("DATE_ADD", unit, expression.expression, expression.this) def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str: this = exp.cast(expression.this, "TIMESTAMP") expr = 
exp.cast(expression.expression, "TIMESTAMP") - unit = exp.Literal.string(expression.text("unit") or "day") + unit = exp.Literal.string(expression.text("unit") or "DAY") return self.func("DATE_DIFF", unit, expr, this) @@ -138,13 +139,6 @@ def _from_unixtime(args: t.List) -> exp.Expression: return exp.UnixToTime.from_arg_list(args) -def _parse_element_at(args: t.List) -> exp.Bracket: - this = seq_get(args, 0) - index = seq_get(args, 1) - assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression) - return exp.Bracket(this=this, expressions=[index], offset=1, safe=True) - - def _unnest_sequence(expression: exp.Expression) -> exp.Expression: if isinstance(expression, exp.Table): if isinstance(expression.this, exp.GenerateSeries): @@ -175,15 +169,8 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str timestamp = self.sql(expression, "this") if scale in (None, exp.UnixToTime.SECONDS): return rename_func("FROM_UNIXTIME")(self, expression) - if scale == exp.UnixToTime.MILLIS: - return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)" - if scale == exp.UnixToTime.MICROS: - return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)" - if scale == exp.UnixToTime.NANOS: - return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)" - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))" def _to_int(expression: exp.Expression) -> exp.Expression: @@ -215,6 +202,7 @@ class Presto(Dialect): STRICT_STRING_CONCAT = True SUPPORTS_SEMI_ANTI_JOIN = False TYPED_DIVISION = True + TABLESAMPLE_SIZE_IS_PERCENT = True # https://github.com/trinodb/trino/issues/17 # https://github.com/trinodb/trino/issues/12289 @@ -258,7 +246,9 @@ class Presto(Dialect): "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), "DATE_TRUNC": date_trunc_to_time, - "ELEMENT_AT": _parse_element_at, + "ELEMENT_AT": lambda args: exp.Bracket( + this=seq_get(args, 0), expressions=[seq_get(args, 1)], offset=1, safe=True + ), "FROM_HEX": exp.Unhex.from_arg_list, "FROM_UNIXTIME": _from_unixtime, "FROM_UTF8": lambda args: exp.Decode( @@ -344,20 +334,20 @@ class Presto(Dialect): exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: lambda self, e: self.func( "DATE_ADD", - exp.Literal.string(e.text("unit") or "day"), + exp.Literal.string(e.text("unit") or "DAY"), _to_int( e.expression, ), e.this, ), exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this ), exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", exp.DateSub: lambda self, e: self.func( "DATE_ADD", - exp.Literal.string(e.text("unit") or "day"), + exp.Literal.string(e.text("unit") or "DAY"), _to_int(e.expression * -1), e.this, ), @@ -366,6 +356,7 @@ class Presto(Dialect): exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.First: _first_last_sql, + exp.GetPath: path_to_jsonpath(), exp.Group: transforms.preprocess([transforms.unalias_group]), exp.GroupConcat: lambda self, e: self.func( "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator") @@ -376,6 +367,7 @@ class Presto(Dialect): exp.Initcap: _initcap_sql, exp.ParseJSON: 
rename_func("JSON_PARSE"), exp.Last: _first_last_sql, + exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this), exp.Lateral: _explode_to_unnest_sql, exp.Left: left_to_substring_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), @@ -446,7 +438,7 @@ class Presto(Dialect): return super().bracket_sql(expression) def struct_sql(self, expression: exp.Struct) -> str: - if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions): + if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions): self.unsupported("Struct with key-value definitions is unsupported.") return self.function_fallback_sql(expression) @@ -454,8 +446,8 @@ class Presto(Dialect): def interval_sql(self, expression: exp.Interval) -> str: unit = self.sql(expression, "unit") - if expression.this and unit.lower().startswith("week"): - return f"({expression.this.name} * INTERVAL '7' day)" + if expression.this and unit.startswith("WEEK"): + return f"({expression.this.name} * INTERVAL '7' DAY)" return super().interval_sql(expression) def transaction_sql(self, expression: exp.Transaction) -> str: diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 7382e7c..7194d81 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -9,8 +9,8 @@ from sqlglot.dialects.dialect import ( concat_ws_to_dpipe_sql, date_delta_sql, generatedasidentitycolumnconstraint_sql, + no_tablesample_sql, rename_func, - ts_or_ds_to_date_sql, ) from sqlglot.dialects.postgres import Postgres from sqlglot.helper import seq_get @@ -123,6 +123,27 @@ class Redshift(Postgres): self._retreat(index) return None + def _parse_query_modifiers( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + this = super()._parse_query_modifiers(this) + + if this: + refs = set() + + for i, join in enumerate(this.args.get("joins", [])): + refs.add( + ( + this.args["from"] if i == 0 else this.args["joins"][i - 1] + ).alias_or_name.lower() + ) + table = join.this + + if isinstance(table, exp.Table): + if table.parts[0].name.lower() in refs: + table.replace(table.to_column()) + return this + class Tokenizer(Postgres.Tokenizer): BIT_STRINGS = [] HEX_STRINGS = [] @@ -144,11 +165,11 @@ class Redshift(Postgres): class Generator(Postgres.Generator): LOCKING_READS_SUPPORTED = False - RENAME_TABLE_WITH_DB = False QUERY_HINTS = False VALUES_AS_TABLE = False TZ_TO_WITH_TIME_ZONE = True NVL2_SUPPORTED = True + LAST_DAY_SUPPORTS_DATE_PART = False TYPE_MAPPING = { **Postgres.Generator.TYPE_MAPPING, @@ -184,9 +205,9 @@ class Redshift(Postgres): [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] ), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", + exp.TableSample: no_tablesample_sql, exp.TsOrDsAdd: date_delta_sql("DATEADD"), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), - exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"), } # Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots @@ -198,6 +219,9 @@ class Redshift(Postgres): # Redshift supports ANY_VALUE(..) TRANSFORMS.pop(exp.AnyValue) + # Redshift supports LAST_DAY(..) 
+ TRANSFORMS.pop(exp.LastDay) + RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"} def with_properties(self, properties: exp.Properties) -> str: diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 8925181..a8e4a42 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -19,7 +19,6 @@ from sqlglot.dialects.dialect import ( rename_func, timestamptrunc_sql, timestrtotime_sql, - ts_or_ds_to_date_sql, var_map_sql, ) from sqlglot.expressions import Literal @@ -40,21 +39,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, if second_arg.is_string: # case: <string_expr> [ , <format> ] return format_time_lambda(exp.StrToTime, "snowflake")(args) - - # case: <numeric_expr> [ , <scale> ] - if second_arg.name not in ["0", "3", "9"]: - raise ValueError( - f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9" - ) - - if second_arg.name == "0": - timescale = exp.UnixToTime.SECONDS - elif second_arg.name == "3": - timescale = exp.UnixToTime.MILLIS - elif second_arg.name == "9": - timescale = exp.UnixToTime.NANOS - - return exp.UnixToTime(this=first_arg, scale=timescale) + return exp.UnixToTime(this=first_arg, scale=second_arg) from sqlglot.optimizer.simplify import simplify_literals @@ -91,23 +76,9 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: def _parse_datediff(args: t.List) -> exp.DateDiff: - return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) - - -def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str: - scale = expression.args.get("scale") - timestamp = self.sql(expression, "this") - if scale in (None, exp.UnixToTime.SECONDS): - return f"TO_TIMESTAMP({timestamp})" - if scale == exp.UnixToTime.MILLIS: - return f"TO_TIMESTAMP({timestamp}, 3)" - if scale == exp.UnixToTime.MICROS: - return f"TO_TIMESTAMP({timestamp} / 1000, 3)" - if scale == exp.UnixToTime.NANOS: - return f"TO_TIMESTAMP({timestamp}, 9)" - - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return exp.DateDiff( + this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0)) + ) # https://docs.snowflake.com/en/sql-reference/functions/date_part.html @@ -120,14 +91,15 @@ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]: self._match(TokenType.COMMA) expression = self._parse_bitwise() - + this = _map_date_part(this) name = this.name.upper() + if name.startswith("EPOCH"): - if name.startswith("EPOCH_MILLISECOND"): + if name == "EPOCH_MILLISECOND": scale = 10**3 - elif name.startswith("EPOCH_MICROSECOND"): + elif name == "EPOCH_MICROSECOND": scale = 10**6 - elif name.startswith("EPOCH_NANOSECOND"): + elif name == "EPOCH_NANOSECOND": scale = 10**9 else: scale = None @@ -204,6 +176,159 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser] return _parse +DATE_PART_MAPPING = { + "Y": "YEAR", + "YY": "YEAR", + "YYY": "YEAR", + "YYYY": "YEAR", + "YR": "YEAR", + "YEARS": "YEAR", + "YRS": "YEAR", + "MM": "MONTH", + "MON": "MONTH", + "MONS": "MONTH", + "MONTHS": "MONTH", + "D": "DAY", + "DD": "DAY", + "DAYS": "DAY", + "DAYOFMONTH": "DAY", + "WEEKDAY": "DAYOFWEEK", + "DOW": "DAYOFWEEK", + "DW": "DAYOFWEEK", + "WEEKDAY_ISO": "DAYOFWEEKISO", + "DOW_ISO": "DAYOFWEEKISO", + "DW_ISO": "DAYOFWEEKISO", + "YEARDAY": "DAYOFYEAR", + "DOY": "DAYOFYEAR", + "DY": "DAYOFYEAR", + "W": "WEEK", + "WK": "WEEK", + 
"WEEKOFYEAR": "WEEK", + "WOY": "WEEK", + "WY": "WEEK", + "WEEK_ISO": "WEEKISO", + "WEEKOFYEARISO": "WEEKISO", + "WEEKOFYEAR_ISO": "WEEKISO", + "Q": "QUARTER", + "QTR": "QUARTER", + "QTRS": "QUARTER", + "QUARTERS": "QUARTER", + "H": "HOUR", + "HH": "HOUR", + "HR": "HOUR", + "HOURS": "HOUR", + "HRS": "HOUR", + "M": "MINUTE", + "MI": "MINUTE", + "MIN": "MINUTE", + "MINUTES": "MINUTE", + "MINS": "MINUTE", + "S": "SECOND", + "SEC": "SECOND", + "SECONDS": "SECOND", + "SECS": "SECOND", + "MS": "MILLISECOND", + "MSEC": "MILLISECOND", + "MILLISECONDS": "MILLISECOND", + "US": "MICROSECOND", + "USEC": "MICROSECOND", + "MICROSECONDS": "MICROSECOND", + "NS": "NANOSECOND", + "NSEC": "NANOSECOND", + "NANOSEC": "NANOSECOND", + "NSECOND": "NANOSECOND", + "NSECONDS": "NANOSECOND", + "NANOSECS": "NANOSECOND", + "NSECONDS": "NANOSECOND", + "EPOCH": "EPOCH_SECOND", + "EPOCH_SECONDS": "EPOCH_SECOND", + "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", + "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", + "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", + "TZH": "TIMEZONE_HOUR", + "TZM": "TIMEZONE_MINUTE", +} + + +@t.overload +def _map_date_part(part: exp.Expression) -> exp.Var: + pass + + +@t.overload +def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + pass + + +def _map_date_part(part): + mapped = DATE_PART_MAPPING.get(part.name.upper()) if part else None + return exp.var(mapped) if mapped else part + + +def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: + trunc = date_trunc_to_time(args) + trunc.set("unit", _map_date_part(trunc.args["unit"])) + return trunc + + +def _parse_colon_get_path( + self: parser.Parser, this: t.Optional[exp.Expression] +) -> t.Optional[exp.Expression]: + while True: + path = self._parse_bitwise() + + # The cast :: operator has a lower precedence than the extraction operator :, so + # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH + if isinstance(path, exp.Cast): + target_type = path.to + path = path.this + else: + target_type = None + + if isinstance(path, exp.Expression): + path = exp.Literal.string(path.sql(dialect="snowflake")) + + # The extraction operator : is left-associative + this = self.expression(exp.GetPath, this=this, expression=path) + + if target_type: + this = exp.cast(this, target_type) + + if not self._match(TokenType.COLON): + break + + if self._match_set(self.RANGE_PARSERS): + this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this + + return this + + +def _parse_timestamp_from_parts(args: t.List) -> exp.Func: + if len(args) == 2: + # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, + # so we parse this into Anonymous for now instead of introducing complexity + return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) + + return exp.TimestampFromParts.from_arg_list(args) + + +def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression: + """ + Snowflake doesn't allow columns referenced in UNPIVOT to be qualified, + so we need to unqualify them. 
+ + Example: + >>> from sqlglot import parse_one + >>> expr = parse_one("SELECT * FROM m_sales UNPIVOT(sales FOR month IN (m_sales.jan, feb, mar, april))") + >>> print(_unqualify_unpivot_columns(expr).sql(dialect="snowflake")) + SELECT * FROM m_sales UNPIVOT(sales FOR month IN (jan, feb, mar, april)) + """ + if isinstance(expression, exp.Pivot) and expression.unpivot: + expression = transforms.unqualify_columns(expression) + + return expression + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE @@ -211,6 +336,8 @@ class Snowflake(Dialect): TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" SUPPORTS_USER_DEFINED_TYPES = False SUPPORTS_SEMI_ANTI_JOIN = False + PREFER_CTE_ALIAS_COLUMN = True + TABLESAMPLE_SIZE_IS_PERCENT = True TIME_MAPPING = { "YYYY": "%Y", @@ -276,14 +403,19 @@ class Snowflake(Dialect): "BIT_XOR": binary_from_function(exp.BitwiseXor), "BOOLXOR": binary_from_function(exp.Xor), "CONVERT_TIMEZONE": _parse_convert_timezone, - "DATE_TRUNC": date_trunc_to_time, + "DATE_TRUNC": _date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( - this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=_map_date_part(seq_get(args, 0)), ), "DATEDIFF": _parse_datediff, "DIV0": _div0_to_if, "FLATTEN": exp.Explode.from_arg_list, "IFF": exp.If.from_arg_list, + "LAST_DAY": lambda args: exp.LastDay( + this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1)) + ), "LISTAGG": exp.GroupConcat.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, @@ -293,6 +425,8 @@ class Snowflake(Dialect): "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TIMEDIFF": _parse_datediff, "TIMESTAMPDIFF": _parse_datediff, + "TIMESTAMPFROMPARTS": _parse_timestamp_from_parts, + "TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts, "TO_TIMESTAMP": _parse_to_timestamp, "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": _zeroifnull_to_if, @@ -301,22 +435,17 @@ class Snowflake(Dialect): FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, "DATE_PART": _parse_date_part, + "OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(), } FUNCTION_PARSERS.pop("TRIM") - COLUMN_OPERATORS = { - **parser.Parser.COLUMN_OPERATORS, - TokenType.COLON: lambda self, this, path: self.expression( - exp.Bracket, this=this, expressions=[path] - ), - } - TIMESTAMPS = parser.Parser.TIMESTAMPS - {TokenType.TIME} RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny), TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny), + TokenType.COLON: _parse_colon_get_path, } ALTER_PARSERS = { @@ -344,6 +473,7 @@ class Snowflake(Dialect): SHOW_PARSERS = { "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), + "COLUMNS": _show_parser("COLUMNS"), } STAGED_FILE_SINGLE_TOKENS = { @@ -351,8 +481,18 @@ class Snowflake(Dialect): TokenType.MOD, TokenType.SLASH, } + FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"] + def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]: + if is_map: + # Keys are strings in Snowflake's objects, see also: + # - https://docs.snowflake.com/en/sql-reference/data-types-semistructured + # - https://docs.snowflake.com/en/sql-reference/functions/object_construct + return self._parse_slice(self._parse_string()) + + return 
self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True)) + def _parse_lateral(self) -> t.Optional[exp.Lateral]: lateral = super()._parse_lateral() if not lateral: @@ -440,6 +580,8 @@ class Snowflake(Dialect): scope = None scope_kind = None + like = self._parse_string() if self._match(TokenType.LIKE) else None + if self._match(TokenType.IN): if self._match_text_seq("ACCOUNT"): scope_kind = "ACCOUNT" @@ -451,7 +593,9 @@ class Snowflake(Dialect): scope_kind = "TABLE" scope = self._parse_table() - return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind) + return self.expression( + exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind + ) def _parse_alter_table_swap(self) -> exp.SwapTable: self._match_text_seq("WITH") @@ -489,8 +633,12 @@ class Snowflake(Dialect): "MINUS": TokenType.EXCEPT, "NCHAR VARYING": TokenType.VARCHAR, "PUT": TokenType.COMMAND, + "REMOVE": TokenType.COMMAND, "RENAME": TokenType.REPLACE, + "RM": TokenType.COMMAND, "SAMPLE": TokenType.TABLE_SAMPLE, + "SQL_DOUBLE": TokenType.DOUBLE, + "SQL_VARCHAR": TokenType.VARCHAR, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, @@ -518,6 +666,8 @@ class Snowflake(Dialect): SUPPORTS_TABLE_COPY = False COLLATE_IS_FUNC = True LIMIT_ONLY_LITERALS = True + JSON_KEY_VALUE_PAIR_SEP = "," + INSERT_OVERWRITE = " OVERWRITE INTO" TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -545,6 +695,8 @@ class Snowflake(Dialect): ), exp.GroupConcat: rename_func("LISTAGG"), exp.If: if_sql(name="IFF", false_value="NULL"), + exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]", + exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions), exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), @@ -557,6 +709,7 @@ class Snowflake(Dialect): exp.PercentileDisc: transforms.preprocess( [transforms.add_within_group_for_percentiles] ), + exp.Pivot: transforms.preprocess([_unqualify_unpivot_columns]), exp.RegexpILike: _regexpilike_sql, exp.Rand: rename_func("RANDOM"), exp.Select: transforms.preprocess( @@ -578,6 +731,9 @@ class Snowflake(Dialect): *(arg for expression in e.expressions for arg in expression.flatten()), ), exp.Stuff: rename_func("INSERT"), + exp.TimestampDiff: lambda self, e: self.func( + "TIMESTAMPDIFF", e.unit, e.expression, e.this + ), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: lambda self, e: self.func( @@ -589,8 +745,7 @@ class Snowflake(Dialect): exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), - exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), - exp.UnixToTime: _unix_to_time_sql, + exp.UnixToTime: rename_func("TO_TIMESTAMP"), exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.WeekOfYear: rename_func("WEEKOFYEAR"), exp.Xor: rename_func("BOOLXOR"), @@ -612,6 +767,14 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: + milli = expression.args.get("milli") + if milli is not None: + milli_to_nano = milli.pop() * exp.Literal.number(1000000) + expression.set("nano", milli_to_nano) + + return rename_func("TIMESTAMP_FROM_PARTS")(self, expression) + def 
trycast_sql(self, expression: exp.TryCast) -> str: value = expression.this @@ -657,6 +820,9 @@ class Snowflake(Dialect): return f"{explode}{alias}" def show_sql(self, expression: exp.Show) -> str: + like = self.sql(expression, "like") + like = f" LIKE {like}" if like else "" + scope = self.sql(expression, "scope") scope = f" {scope}" if scope else "" @@ -664,7 +830,7 @@ class Snowflake(Dialect): if scope_kind: scope_kind = f" IN {scope_kind}" - return f"SHOW {expression.name}{scope_kind}{scope}" + return f"SHOW {expression.name}{like}{scope_kind}{scope}" def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: # Other dialects don't support all of the following parameters, so we need to diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index aa09f53..e27ba18 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -48,11 +48,8 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str return f"TIMESTAMP_MILLIS({timestamp})" if scale == exp.UnixToTime.MICROS: return f"TIMESTAMP_MICROS({timestamp})" - if scale == exp.UnixToTime.NANOS: - return f"TIMESTAMP_SECONDS({timestamp} / 1000000000)" - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))" def _unalias_pivot(expression: exp.Expression) -> exp.Expression: @@ -93,12 +90,7 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q1')) """ if isinstance(expression, exp.Pivot): - expression.args["field"].transform( - lambda node: exp.column(node.output_name, quoted=node.this.quoted) - if isinstance(node, exp.Column) - else node, - copy=False, - ) + expression.set("field", transforms.unqualify_columns(expression.args["field"])) return expression @@ -234,7 +226,7 @@ class Spark2(Hive): def struct_sql(self, expression: exp.Struct) -> str: args = [] for arg in expression.expressions: - if isinstance(arg, self.KEY_VALUE_DEFINITONS): + if isinstance(arg, self.KEY_VALUE_DEFINITIONS): if isinstance(arg, exp.Bracket): args.append(exp.alias_(arg.this, arg.expressions[0].name)) else: diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 9bac51c..244a96e 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -78,6 +78,7 @@ class SQLite(Dialect): **parser.Parser.FUNCTIONS, "EDITDIST3": exp.Levenshtein.from_arg_list, } + STRING_ALIASES = True class Generator(generator.Generator): JOIN_HINTS = False diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 0ccc567..6dbad15 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -175,6 +175,8 @@ class Teradata(Dialect): JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False + TABLESAMPLE_KEYWORDS = "SAMPLE" + LAST_DAY_SUPPORTS_DATE_PART = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -214,7 +216,10 @@ class Teradata(Dialect): return self.cast_sql(expression, safe_prefix="TRY") def tablesample_sql( - self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " + self, + expression: exp.TableSample, + sep: str = " AS ", + tablesample_keyword: t.Optional[str] = None, ) -> str: return f"{self.sql(expression, 'this')} SAMPLE {self.expressions(expression)}" diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index 3682ac7..eddb70a 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -1,6 +1,7 @@ from __future__ import 
annotations from sqlglot import exp +from sqlglot.dialects.dialect import merge_without_target_sql from sqlglot.dialects.presto import Presto @@ -11,6 +12,7 @@ class Trino(Presto): TRANSFORMS = { **Presto.Generator.TRANSFORMS, exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + exp.Merge: merge_without_target_sql, } class Tokenizer(Presto.Tokenizer): diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 165a703..b9c347c 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -14,9 +14,10 @@ from sqlglot.dialects.dialect import ( max_or_greatest, min_or_least, parse_date_delta, + path_to_jsonpath, rename_func, timestrtotime_sql, - ts_or_ds_to_date_sql, + trim_sql, ) from sqlglot.expressions import DataType from sqlglot.helper import seq_get @@ -105,18 +106,17 @@ def _parse_format(args: t.List) -> exp.Expression: return exp.TimeToStr(this=this, format=fmt, culture=culture) -def _parse_eomonth(args: t.List) -> exp.Expression: - date = seq_get(args, 0) +def _parse_eomonth(args: t.List) -> exp.LastDay: + date = exp.TsOrDsToDate(this=seq_get(args, 0)) month_lag = seq_get(args, 1) - unit = DATE_DELTA_INTERVAL.get("month") if month_lag is None: - return exp.LastDateOfMonth(this=date) + this: exp.Expression = date + else: + unit = DATE_DELTA_INTERVAL.get("month") + this = exp.DateAdd(this=date, expression=month_lag, unit=unit and exp.var(unit)) - # Remove month lag argument in parser as its compared with the number of arguments of the resulting class - args.remove(month_lag) - - return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit)) + return exp.LastDay(this=this) def _parse_hashbytes(args: t.List) -> exp.Expression: @@ -137,26 +137,27 @@ def _parse_hashbytes(args: t.List) -> exp.Expression: return exp.func("HASHBYTES", *args) -DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"} +DATEPART_ONLY_FORMATS = {"DW", "HOUR", "QUARTER"} def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: - fmt = ( - expression.args["format"] - if isinstance(expression, exp.NumberToStr) - else exp.Literal.string( - format_time( - expression.text("format"), - t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING), - ) - ) - ) + fmt = expression.args["format"] - # There is no format for "quarter" - if fmt.name.lower() in DATEPART_ONLY_FORMATS: - return self.func("DATEPART", fmt.name, expression.this) + if not isinstance(expression, exp.NumberToStr): + if fmt.is_string: + mapped_fmt = format_time(fmt.name, TSQL.INVERSE_TIME_MAPPING) - return self.func("FORMAT", expression.this, fmt, expression.args.get("culture")) + name = (mapped_fmt or "").upper() + if name in DATEPART_ONLY_FORMATS: + return self.func("DATEPART", name, expression.this) + + fmt_sql = self.sql(exp.Literal.string(mapped_fmt)) + else: + fmt_sql = self.format_time(expression) or self.sql(fmt) + else: + fmt_sql = self.sql(fmt) + + return self.func("FORMAT", expression.this, fmt_sql, expression.args.get("culture")) def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: @@ -239,6 +240,30 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: return expression +# https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax +def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts: + return exp.TimestampFromParts( + year=seq_get(args, 0), + month=seq_get(args, 1), + day=seq_get(args, 2), + 
hour=seq_get(args, 3), + min=seq_get(args, 4), + sec=seq_get(args, 5), + milli=seq_get(args, 6), + ) + + +# https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax +def _parse_timefromparts(args: t.List) -> exp.TimeFromParts: + return exp.TimeFromParts( + hour=seq_get(args, 0), + min=seq_get(args, 1), + sec=seq_get(args, 2), + fractions=seq_get(args, 3), + precision=seq_get(args, 4), + ) + + class TSQL(Dialect): NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" @@ -352,7 +377,7 @@ class TSQL(Dialect): } class Tokenizer(tokens.Tokenizer): - IDENTIFIERS = ['"', ("[", "]")] + IDENTIFIERS = [("[", "]"), '"'] QUOTES = ["'", '"'] HEX_STRINGS = [("0x", ""), ("0X", "")] VAR_SINGLE_TOKENS = {"@", "$", "#"} @@ -362,6 +387,7 @@ class TSQL(Dialect): "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, "DECLARE": TokenType.COMMAND, + "EXEC": TokenType.COMMAND, "IMAGE": TokenType.IMAGE, "MONEY": TokenType.MONEY, "NTEXT": TokenType.TEXT, @@ -397,6 +423,7 @@ class TSQL(Dialect): "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATEPART": _format_time_lambda(exp.TimeToStr), + "DATETIMEFROMPARTS": _parse_datetimefromparts, "EOMONTH": _parse_eomonth, "FORMAT": _parse_format, "GETDATE": exp.CurrentTimestamp.from_arg_list, @@ -411,6 +438,7 @@ class TSQL(Dialect): "SUSER_NAME": exp.CurrentUser.from_arg_list, "SUSER_SNAME": exp.CurrentUser.from_arg_list, "SYSTEM_USER": exp.CurrentUser.from_arg_list, + "TIMEFROMPARTS": _parse_timefromparts, } JOIN_HINTS = { @@ -440,6 +468,7 @@ class TSQL(Dialect): LOG_DEFAULTS_TO_LN = True ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False + STRING_ALIASES = True def _parse_projections(self) -> t.List[exp.Expression]: """ @@ -630,8 +659,10 @@ class TSQL(Dialect): COMPUTED_COLUMN_WITH_TYPE = False CTE_RECURSIVE_KEYWORD_REQUIRED = False ENSURE_BOOLS = True - NULL_ORDERING_SUPPORTED = False + NULL_ORDERING_SUPPORTED = None SUPPORTS_SINGLE_ARG_CONCAT = False + TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" + SUPPORTS_SELECT_INTO = True EXPRESSIONS_WITHOUT_NESTED_CTES = { exp.Delete, @@ -667,13 +698,16 @@ class TSQL(Dialect): exp.CurrentTimestamp: rename_func("GETDATE"), exp.Extract: rename_func("DATEPART"), exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, + exp.GetPath: path_to_jsonpath("JSON_VALUE"), exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), + exp.LastDay: lambda self, e: self.func("EOMONTH", e.this), exp.Length: rename_func("LEN"), exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, exp.NumberToStr: _format_sql, + exp.ParseJSON: lambda self, e: self.sql(e, "this"), exp.Select: transforms.preprocess( [ transforms.eliminate_distinct_on, @@ -689,9 +723,9 @@ class TSQL(Dialect): exp.TemporaryProperty: lambda self, e: "", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: _format_sql, + exp.Trim: trim_sql, exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), - exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"), } TRANSFORMS.pop(exp.ReturnsProperty) @@ -701,6 +735,46 @@ class TSQL(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def lateral_op(self, expression: exp.Lateral) -> str: + cross_apply = expression.args.get("cross_apply") + if cross_apply is True: + return "CROSS 
APPLY" + if cross_apply is False: + return "OUTER APPLY" + + # TODO: perhaps we can check if the parent is a Join and transpile it appropriately + self.unsupported("LATERAL clause is not supported.") + return "LATERAL" + + def timefromparts_sql(self, expression: exp.TimeFromParts) -> str: + nano = expression.args.get("nano") + if nano is not None: + nano.pop() + self.unsupported("Specifying nanoseconds is not supported in TIMEFROMPARTS.") + + if expression.args.get("fractions") is None: + expression.set("fractions", exp.Literal.number(0)) + if expression.args.get("precision") is None: + expression.set("precision", exp.Literal.number(0)) + + return rename_func("TIMEFROMPARTS")(self, expression) + + def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: + zone = expression.args.get("zone") + if zone is not None: + zone.pop() + self.unsupported("Time zone is not supported in DATETIMEFROMPARTS.") + + nano = expression.args.get("nano") + if nano is not None: + nano.pop() + self.unsupported("Specifying nanoseconds is not supported in DATETIMEFROMPARTS.") + + if expression.args.get("milli") is None: + expression.set("milli", exp.Literal.number(0)) + + return rename_func("DATETIMEFROMPARTS")(self, expression) + def set_operation(self, expression: exp.Union, op: str) -> str: limit = expression.args.get("limit") if limit: diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index b79a551..6c01edc 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -132,11 +132,10 @@ def ordered(this, desc, nulls_first): @null_if_any def interval(this, unit): - unit = unit.lower() - plural = unit + "s" + plural = unit + "S" if plural in Generator.TIME_PART_SINGULARS: unit = plural - return datetime.timedelta(**{unit: float(this)}) + return datetime.timedelta(**{unit.lower(): float(this)}) @null_if_any("this", "expression") @@ -176,6 +175,7 @@ ENV = { "DOT": null_if_any(lambda e, this: e[this]), "EQ": null_if_any(lambda this, e: this == e), "EXTRACT": null_if_any(lambda this, e: getattr(e, this)), + "GETPATH": null_if_any(lambda this, e: this.get(e)), "GT": null_if_any(lambda this, e: this > e), "GTE": null_if_any(lambda this, e: this >= e), "IF": lambda predicate, true, false: true if predicate else false, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index ea2255d..ddad8f8 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -16,6 +16,7 @@ import datetime import math import numbers import re +import textwrap import typing as t from collections import deque from copy import deepcopy @@ -35,6 +36,8 @@ from sqlglot.helper import ( from sqlglot.tokens import Token if t.TYPE_CHECKING: + from typing_extensions import Literal as Lit + from sqlglot.dialects.dialect import DialectType @@ -242,6 +245,9 @@ class Expression(metaclass=_Expression): def is_type(self, *dtypes) -> bool: return self.type is not None and self.type.is_type(*dtypes) + def is_leaf(self) -> bool: + return not any(isinstance(v, (Expression, list)) for v in self.args.values()) + @property def meta(self) -> t.Dict[str, t.Any]: if self._meta is None: @@ -497,7 +503,14 @@ class Expression(metaclass=_Expression): return self.sql() def __repr__(self) -> str: - return self._to_s() + return _to_s(self) + + def to_s(self) -> str: + """ + Same as __repr__, but includes additional information which can be useful + for debugging, like empty or missing args and the AST nodes' object IDs. 
+ """ + return _to_s(self, verbose=True) def sql(self, dialect: DialectType = None, **opts) -> str: """ @@ -514,30 +527,6 @@ class Expression(metaclass=_Expression): return Dialect.get_or_raise(dialect).generate(self, **opts) - def _to_s(self, hide_missing: bool = True, level: int = 0) -> str: - indent = "" if not level else "\n" - indent += "".join([" "] * level) - left = f"({self.key.upper()} " - - args: t.Dict[str, t.Any] = { - k: ", ".join( - v._to_s(hide_missing=hide_missing, level=level + 1) - if hasattr(v, "_to_s") - else str(v) - for v in ensure_list(vs) - if v is not None - ) - for k, vs in self.args.items() - } - args["comments"] = self.comments - args["type"] = self.type - args = {k: v for k, v in args.items() if v or not hide_missing} - - right = ", ".join(f"{k}: {v}" for k, v in args.items()) - right += ")" - - return indent + left + right - def transform(self, fun, *args, copy=True, **kwargs): """ Recursively visits all tree nodes (excluding already transformed ones) @@ -580,8 +569,9 @@ class Expression(metaclass=_Expression): For example:: >>> tree = Select().select("x").from_("tbl") - >>> tree.find(Column).replace(Column(this="y")) - (COLUMN this: y) + >>> tree.find(Column).replace(column("y")) + Column( + this=Identifier(this=y, quoted=False)) >>> tree.sql() 'SELECT y FROM tbl' @@ -831,6 +821,9 @@ class Expression(metaclass=_Expression): div.args["safe"] = safe return div + def desc(self, nulls_first: bool = False) -> Ordered: + return Ordered(this=self.copy(), desc=True, nulls_first=nulls_first) + def __lt__(self, other: t.Any) -> LT: return self._binop(LT, other) @@ -1109,7 +1102,7 @@ class Clone(Expression): class Describe(Expression): - arg_types = {"this": True, "kind": False, "expressions": False} + arg_types = {"this": True, "extended": False, "kind": False, "expressions": False} class Kill(Expression): @@ -1124,6 +1117,10 @@ class Set(Expression): arg_types = {"expressions": False, "unset": False, "tag": False} +class Heredoc(Expression): + arg_types = {"this": True, "tag": False} + + class SetItem(Expression): arg_types = { "this": False, @@ -1937,7 +1934,13 @@ class Join(Expression): class Lateral(UDTF): - arg_types = {"this": True, "view": False, "outer": False, "alias": False} + arg_types = { + "this": True, + "view": False, + "outer": False, + "alias": False, + "cross_apply": False, # True -> CROSS APPLY, False -> OUTER APPLY + } class MatchRecognize(Expression): @@ -1964,7 +1967,12 @@ class Offset(Expression): class Order(Expression): - arg_types = {"this": False, "expressions": True, "interpolate": False} + arg_types = { + "this": False, + "expressions": True, + "interpolate": False, + "siblings": False, + } # https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier @@ -2002,6 +2010,11 @@ class AutoIncrementProperty(Property): arg_types = {"this": True} +# https://docs.aws.amazon.com/prescriptive-guidance/latest/materialized-views-redshift/refreshing-materialized-views.html +class AutoRefreshProperty(Property): + arg_types = {"this": True} + + class BlockCompressionProperty(Property): arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True} @@ -2259,6 +2272,10 @@ class SortKeyProperty(Property): arg_types = {"this": True, "compound": False} +class SqlReadWriteProperty(Property): + arg_types = {"this": True} + + class SqlSecurityProperty(Property): arg_types = {"definer": True} @@ -2543,7 +2560,6 @@ class Table(Expression): "version": False, "format": False, "pattern": 
False, - "index": False, "ordinality": False, "when": False, } @@ -2585,6 +2601,14 @@ class Table(Expression): return parts + def to_column(self, copy: bool = True) -> Alias | Column | Dot: + parts = self.parts + col = column(*reversed(parts[0:4]), fields=parts[4:], copy=copy) # type: ignore + alias = self.args.get("alias") + if alias: + col = alias_(col, alias.this, copy=copy) + return col + class Union(Subqueryable): arg_types = { @@ -2694,6 +2718,14 @@ class Unnest(UDTF): "offset": False, } + @property + def selects(self) -> t.List[Expression]: + columns = super().selects + offset = self.args.get("offset") + if offset: + columns = columns + [to_identifier("offset") if offset is True else offset] + return columns + class Update(Expression): arg_types = { @@ -3368,7 +3400,7 @@ class Select(Subqueryable): return Create( this=table_expression, - kind="table", + kind="TABLE", expression=instance, properties=properties_expression, ) @@ -3488,7 +3520,6 @@ class TableSample(Expression): "rows": False, "size": False, "seed": False, - "kind": False, } @@ -3517,6 +3548,10 @@ class Pivot(Expression): "include_nulls": False, } + @property + def unpivot(self) -> bool: + return bool(self.args.get("unpivot")) + class Window(Condition): arg_types = { @@ -3604,6 +3639,7 @@ class DataType(Expression): BOOLEAN = auto() CHAR = auto() DATE = auto() + DATE32 = auto() DATEMULTIRANGE = auto() DATERANGE = auto() DATETIME = auto() @@ -3631,6 +3667,8 @@ class DataType(Expression): INTERVAL = auto() IPADDRESS = auto() IPPREFIX = auto() + IPV4 = auto() + IPV6 = auto() JSON = auto() JSONB = auto() LONGBLOB = auto() @@ -3729,6 +3767,7 @@ class DataType(Expression): Type.TIMESTAMP_MS, Type.TIMESTAMP_NS, Type.DATE, + Type.DATE32, Type.DATETIME, Type.DATETIME64, } @@ -4100,6 +4139,12 @@ class Alias(Expression): return self.alias +# BigQuery requires the UNPIVOT column list aliases to be either strings or ints, but +# other dialects require identifiers. This enables us to transpile between them easily. 
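+# (the generator counterpart is pivotalias_sql, driven by the UNPIVOT_ALIASES_ARE_IDENTIFIERS flag)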
+class PivotAlias(Alias): + pass + + class Aliases(Expression): arg_types = {"this": True, "expressions": True} @@ -4108,6 +4153,11 @@ class Aliases(Expression): return self.expressions +# https://docs.aws.amazon.com/redshift/latest/dg/query-super.html +class AtIndex(Expression): + arg_types = {"this": True, "expression": True} + + class AtTimeZone(Expression): arg_types = {"this": True, "zone": True} @@ -4154,16 +4204,16 @@ class TimeUnit(Expression): arg_types = {"unit": False} UNABBREVIATED_UNIT_NAME = { - "d": "day", - "h": "hour", - "m": "minute", - "ms": "millisecond", - "ns": "nanosecond", - "q": "quarter", - "s": "second", - "us": "microsecond", - "w": "week", - "y": "year", + "D": "DAY", + "H": "HOUR", + "M": "MINUTE", + "MS": "MILLISECOND", + "NS": "NANOSECOND", + "Q": "QUARTER", + "S": "SECOND", + "US": "MICROSECOND", + "W": "WEEK", + "Y": "YEAR", } VAR_LIKE = (Column, Literal, Var) @@ -4171,9 +4221,11 @@ class TimeUnit(Expression): def __init__(self, **args): unit = args.get("unit") if isinstance(unit, self.VAR_LIKE): - args["unit"] = Var(this=self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name) + args["unit"] = Var( + this=(self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name).upper() + ) elif isinstance(unit, Week): - unit.set("this", Var(this=unit.this.name)) + unit.set("this", Var(this=unit.this.name.upper())) super().__init__(**args) @@ -4301,6 +4353,20 @@ class Anonymous(Func): is_var_len_args = True +class AnonymousAggFunc(AggFunc): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/combinators +class CombinedAggFunc(AnonymousAggFunc): + arg_types = {"this": True, "expressions": False, "parts": True} + + +class CombinedParameterizedAgg(ParameterizedAgg): + arg_types = {"this": True, "expressions": True, "params": True, "parts": True} + + # https://docs.snowflake.com/en/sql-reference/functions/hll # https://docs.aws.amazon.com/redshift/latest/dg/r_HLL_function.html class Hll(AggFunc): @@ -4381,7 +4447,7 @@ class ArraySort(Func): class ArraySum(Func): - pass + arg_types = {"this": True, "expression": False} class ArrayUnionAgg(AggFunc): @@ -4498,7 +4564,7 @@ class Count(AggFunc): class CountIf(AggFunc): - pass + _sql_names = ["COUNT_IF", "COUNTIF"] class CurrentDate(Func): @@ -4537,6 +4603,17 @@ class DateDiff(Func, TimeUnit): class DateTrunc(Func): arg_types = {"unit": True, "this": True, "zone": False} + def __init__(self, **args): + unit = args.get("unit") + if isinstance(unit, TimeUnit.VAR_LIKE): + args["unit"] = Literal.string( + (TimeUnit.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name).upper() + ) + elif isinstance(unit, Week): + unit.set("this", Literal.string(unit.this.name.upper())) + + super().__init__(**args) + @property def unit(self) -> Expression: return self.args["unit"] @@ -4582,8 +4659,9 @@ class MonthsBetween(Func): arg_types = {"this": True, "expression": True, "roundoff": False} -class LastDateOfMonth(Func): - pass +class LastDay(Func, TimeUnit): + _sql_names = ["LAST_DAY", "LAST_DAY_OF_MONTH"] + arg_types = {"this": True, "unit": False} class Extract(Func): @@ -4627,10 +4705,22 @@ class TimeTrunc(Func, TimeUnit): class DateFromParts(Func): - _sql_names = ["DATEFROMPARTS"] + _sql_names = ["DATE_FROM_PARTS", "DATEFROMPARTS"] arg_types = {"year": True, "month": True, "day": True} +class TimeFromParts(Func): + _sql_names = ["TIME_FROM_PARTS", "TIMEFROMPARTS"] + arg_types = { + "hour": True, + "min": True, + "sec": True, + "nano": False, + 
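# "fractions" and "precision" carry T-SQL's TIMEFROMPARTS arguments (see _parse_timefromparts in tsql.py) +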
"fractions": False, + "precision": False, + } + + class DateStrToDate(Func): pass @@ -4754,6 +4844,16 @@ class JSONObject(Func): } +class JSONObjectAgg(AggFunc): + arg_types = { + "expressions": False, + "null_handling": False, + "unique_keys": False, + "return_type": False, + "encoding": False, + } + + # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_ARRAY.html class JSONArray(Func): arg_types = { @@ -4841,6 +4941,15 @@ class ParseJSON(Func): is_var_len_args = True +# https://docs.snowflake.com/en/sql-reference/functions/get_path +class GetPath(Func): + arg_types = {"this": True, "expression": True} + + @property + def output_name(self) -> str: + return self.expression.output_name + + class Least(Func): arg_types = {"this": True, "expressions": False} is_var_len_args = True @@ -5026,7 +5135,7 @@ class RegexpReplace(Func): arg_types = { "this": True, "expression": True, - "replacement": True, + "replacement": False, "position": False, "occurrence": False, "parameters": False, @@ -5052,8 +5161,10 @@ class Repeat(Func): arg_types = {"this": True, "times": True} +# https://learn.microsoft.com/en-us/sql/t-sql/functions/round-transact-sql?view=sql-server-ver16 +# tsql third argument function == trunctaion if not 0 class Round(Func): - arg_types = {"this": True, "decimals": False} + arg_types = {"this": True, "decimals": False, "truncate": False} class RowNumber(Func): @@ -5228,6 +5339,10 @@ class TsOrDsToDate(Func): arg_types = {"this": True, "format": False} +class TsOrDsToTime(Func): + pass + + class TsOrDiToDi(Func): pass @@ -5236,6 +5351,11 @@ class Unhex(Func): pass +# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#unix_date +class UnixDate(Func): + pass + + class UnixToStr(Func): arg_types = {"this": True, "format": False} @@ -5245,10 +5365,16 @@ class UnixToStr(Func): class UnixToTime(Func): arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False} - SECONDS = Literal.string("seconds") - MILLIS = Literal.string("millis") - MICROS = Literal.string("micros") - NANOS = Literal.string("nanos") + SECONDS = Literal.number(0) + DECIS = Literal.number(1) + CENTIS = Literal.number(2) + MILLIS = Literal.number(3) + DECIMILLIS = Literal.number(4) + CENTIMILLIS = Literal.number(5) + MICROS = Literal.number(6) + DECIMICROS = Literal.number(7) + CENTIMICROS = Literal.number(8) + NANOS = Literal.number(9) class UnixToTimeStr(Func): @@ -5256,8 +5382,7 @@ class UnixToTimeStr(Func): class TimestampFromParts(Func): - """Constructs a timestamp given its constituent parts.""" - + _sql_names = ["TIMESTAMP_FROM_PARTS", "TIMESTAMPFROMPARTS"] arg_types = { "year": True, "month": True, @@ -5265,6 +5390,9 @@ class TimestampFromParts(Func): "hour": True, "min": True, "sec": True, + "nano": False, + "zone": False, + "milli": False, } @@ -5358,9 +5486,9 @@ def maybe_parse( Example: >>> maybe_parse("1") - (LITERAL this: 1, is_string: False) + Literal(this=1, is_string=False) >>> maybe_parse(to_identifier("x")) - (IDENTIFIER this: x, quoted: False) + Identifier(this=x, quoted=False) Args: sql_or_expression: the SQL code string or an expression @@ -5407,6 +5535,39 @@ def maybe_copy(instance, copy=True): return instance.copy() if copy and instance else instance +def _to_s(node: t.Any, verbose: bool = False, level: int = 0) -> str: + """Generate a textual representation of an Expression tree""" + indent = "\n" + (" " * (level + 1)) + delim = f",{indent}" + + if isinstance(node, Expression): + args = {k: v for k, v in 
node.args.items() if (v is not None and v != []) or verbose} + + if (node.type or verbose) and not isinstance(node, DataType): + args["_type"] = node.type + if node.comments or verbose: + args["_comments"] = node.comments + + if verbose: + args["_id"] = id(node) + + # Inline leaves for a more compact representation + if node.is_leaf(): + indent = "" + delim = ", " + + items = delim.join([f"{k}={_to_s(v, verbose, level + 1)}" for k, v in args.items()]) + return f"{node.__class__.__name__}({indent}{items})" + + if isinstance(node, list): + items = delim.join(_to_s(i, verbose, level + 1) for i in node) + items = f"{indent}{items}" if items else "" + return f"[{items}]" + + # Indent multiline strings to match the current level + return indent.join(textwrap.dedent(str(node).strip("\n")).splitlines()) + + def _is_wrong_expression(expression, into): return isinstance(expression, Expression) and not isinstance(expression, into) @@ -5816,7 +5977,7 @@ def delete( def insert( expression: ExpOrStr, into: ExpOrStr, - columns: t.Optional[t.Sequence[ExpOrStr]] = None, + columns: t.Optional[t.Sequence[str | Identifier]] = None, overwrite: t.Optional[bool] = None, returning: t.Optional[ExpOrStr] = None, dialect: DialectType = None, @@ -5847,15 +6008,7 @@ def insert( this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts) if columns: - this = _apply_list_builder( - *columns, - instance=Schema(this=this), - arg="expressions", - into=Identifier, - copy=False, - dialect=dialect, - **opts, - ) + this = Schema(this=this, expressions=[to_identifier(c, copy=copy) for c in columns]) insert = Insert(this=this, expression=expr, overwrite=overwrite) @@ -6073,7 +6226,7 @@ def to_interval(interval: str | Literal) -> Interval: return Interval( this=Literal.string(interval_parts.group(1)), - unit=Var(this=interval_parts.group(2)), + unit=Var(this=interval_parts.group(2).upper()), ) @@ -6219,13 +6372,44 @@ def subquery( return Select().from_(expression, dialect=dialect, **opts) +@t.overload +def column( + col: str | Identifier, + table: t.Optional[str | Identifier] = None, + db: t.Optional[str | Identifier] = None, + catalog: t.Optional[str | Identifier] = None, + *, + fields: t.Collection[t.Union[str, Identifier]], + quoted: t.Optional[bool] = None, + copy: bool = True, +) -> Dot: + pass + + +@t.overload def column( col: str | Identifier, table: t.Optional[str | Identifier] = None, db: t.Optional[str | Identifier] = None, catalog: t.Optional[str | Identifier] = None, + *, + fields: Lit[None] = None, quoted: t.Optional[bool] = None, + copy: bool = True, ) -> Column: + pass + + +def column( + col, + table=None, + db=None, + catalog=None, + *, + fields=None, + quoted=None, + copy=True, +): """ Build a Column. @@ -6234,18 +6418,24 @@ def column( table: Table name. db: Database name. catalog: Catalog name. + fields: Additional fields using dots. quoted: Whether to force quotes on the column's identifiers. + copy: Whether or not to copy identifiers if passed in. Returns: The new Column instance. 
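+ Example: + >>> column("x", table="t", fields=["a", "b"]).sql() + 't.x.a.b'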
""" - return Column( - this=to_identifier(col, quoted=quoted), - table=to_identifier(table, quoted=quoted), - db=to_identifier(db, quoted=quoted), - catalog=to_identifier(catalog, quoted=quoted), + this = Column( + this=to_identifier(col, quoted=quoted, copy=copy), + table=to_identifier(table, quoted=quoted, copy=copy), + db=to_identifier(db, quoted=quoted, copy=copy), + catalog=to_identifier(catalog, quoted=quoted, copy=copy), ) + if fields: + this = Dot.build((this, *(to_identifier(field, copy=copy) for field in fields))) + return this + def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast: """Cast an expression to a data type. @@ -6333,10 +6523,10 @@ def var(name: t.Optional[ExpOrStr]) -> Var: Example: >>> repr(var('x')) - '(VAR this: x)' + 'Var(this=x)' >>> repr(var(column('x', table='y'))) - '(VAR this: x)' + 'Var(this=x)' Args: name: The name of the var or an expression who's name will become the var. diff --git a/sqlglot/generator.py b/sqlglot/generator.py index b0e83d2..977185f 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -68,6 +68,7 @@ class Generator: exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})", exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})", exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}", + exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}", exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS", exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}", exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}", @@ -96,6 +97,7 @@ class Generator: exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}", exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET", exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", + exp.SqlReadWriteProperty: lambda self, e: e.name, exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", exp.StabilityProperty: lambda self, e: e.name, exp.TemporaryProperty: lambda self, e: f"TEMPORARY", @@ -110,7 +112,8 @@ class Generator: } # Whether or not null ordering is supported in order by - NULL_ORDERING_SUPPORTED = True + # True: Full Support, None: No support, False: No support in window specifications + NULL_ORDERING_SUPPORTED: t.Optional[bool] = True # Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported LOCKING_READS_SUPPORTED = False @@ -133,12 +136,6 @@ class Generator: # Whether or not the plural form of date parts like day (i.e. 
"days") is supported in INTERVALs INTERVAL_ALLOWS_PLURAL_FORM = True - # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI - TABLESAMPLE_WITH_METHOD = True - - # Whether or not to treat the number in TABLESAMPLE (50) as a percentage - TABLESAMPLE_SIZE_IS_PERCENT = False - # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") LIMIT_FETCH = "ALL" @@ -219,6 +216,18 @@ class Generator: # Whether or not parentheses are required around the table sample's expression TABLESAMPLE_REQUIRES_PARENS = True + # Whether or not a table sample clause's size needs to be followed by the ROWS keyword + TABLESAMPLE_SIZE_IS_ROWS = True + + # The keyword(s) to use when generating a sample clause + TABLESAMPLE_KEYWORDS = "TABLESAMPLE" + + # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI + TABLESAMPLE_WITH_METHOD = True + + # The keyword to use when specifying the seed of a sample clause + TABLESAMPLE_SEED_KEYWORD = "SEED" + # Whether or not COLLATE is a function instead of a binary operator COLLATE_IS_FUNC = False @@ -234,6 +243,27 @@ class Generator: # Whether or not CONCAT requires >1 arguments SUPPORTS_SINGLE_ARG_CONCAT = True + # Whether or not LAST_DAY function supports a date part argument + LAST_DAY_SUPPORTS_DATE_PART = True + + # Whether or not named columns are allowed in table aliases + SUPPORTS_TABLE_ALIAS_COLUMNS = True + + # Whether or not UNPIVOT aliases are Identifiers (False means they're Literals) + UNPIVOT_ALIASES_ARE_IDENTIFIERS = True + + # What delimiter to use for separating JSON key/value pairs + JSON_KEY_VALUE_PAIR_SEP = ":" + + # INSERT OVERWRITE TABLE x override + INSERT_OVERWRITE = " OVERWRITE TABLE" + + # Whether or not the SELECT .. INTO syntax is used instead of CTAS + SUPPORTS_SELECT_INTO = False + + # Whether or not UNLOGGED tables can be created + SUPPORTS_UNLOGGED_TABLES = False + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -252,15 +282,15 @@ class Generator: } TIME_PART_SINGULARS = { - "microseconds": "microsecond", - "seconds": "second", - "minutes": "minute", - "hours": "hour", - "days": "day", - "weeks": "week", - "months": "month", - "quarters": "quarter", - "years": "year", + "MICROSECONDS": "MICROSECOND", + "SECONDS": "SECOND", + "MINUTES": "MINUTE", + "HOURS": "HOUR", + "DAYS": "DAY", + "WEEKS": "WEEK", + "MONTHS": "MONTH", + "QUARTERS": "QUARTER", + "YEARS": "YEAR", } TOKEN_MAPPING: t.Dict[TokenType, str] = {} @@ -272,6 +302,7 @@ class Generator: PROPERTIES_LOCATION = { exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA, + exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA, exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME, exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA, exp.ChecksumProperty: exp.Properties.Location.POST_NAME, @@ -323,6 +354,7 @@ class Generator: exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA, exp.SetProperty: exp.Properties.Location.POST_CREATE, exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, + exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA, exp.TemporaryProperty: exp.Properties.Location.POST_CREATE, @@ -370,7 +402,7 @@ class Generator: # Expressions that need to have all CTEs under them bubbled up to them EXPRESSIONS_WITHOUT_NESTED_CTES: 
t.Set[t.Type[exp.Expression]] = set() - KEY_VALUE_DEFINITONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice) + KEY_VALUE_DEFINITIONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice) SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" @@ -775,7 +807,7 @@ class Generator: return self.sql(expression, "this") def create_sql(self, expression: exp.Create) -> str: - kind = self.sql(expression, "kind").upper() + kind = self.sql(expression, "kind") properties = expression.args.get("properties") properties_locs = self.locate_properties(properties) if properties else defaultdict() @@ -868,7 +900,12 @@ class Generator: return f"{shallow}{keyword} {this}" def describe_sql(self, expression: exp.Describe) -> str: - return f"DESCRIBE {self.sql(expression, 'this')}" + extended = " EXTENDED" if expression.args.get("extended") else "" + return f"DESCRIBE{extended} {self.sql(expression, 'this')}" + + def heredoc_sql(self, expression: exp.Heredoc) -> str: + tag = self.sql(expression, "tag") + return f"${tag}${self.sql(expression, 'this')}${tag}$" def prepend_ctes(self, expression: exp.Expression, sql: str) -> str: with_ = self.sql(expression, "with") @@ -895,6 +932,10 @@ class Generator: columns = self.expressions(expression, key="columns", flat=True) columns = f"({columns})" if columns else "" + if columns and not self.SUPPORTS_TABLE_ALIAS_COLUMNS: + columns = "" + self.unsupported("Named columns are not supported in table alias.") + if not alias and not self.dialect.UNNEST_COLUMN_ONLY: alias = "_t" @@ -1027,7 +1068,7 @@ class Generator: def fetch_sql(self, expression: exp.Fetch) -> str: direction = expression.args.get("direction") - direction = f" {direction.upper()}" if direction else "" + direction = f" {direction}" if direction else "" count = expression.args.get("count") count = f" {count}" if count else "" if expression.args.get("percent"): @@ -1318,7 +1359,7 @@ class Generator: if isinstance(expression.this, exp.Directory): this = " OVERWRITE" if overwrite else " INTO" else: - this = " OVERWRITE TABLE" if overwrite else " INTO" + this = self.INSERT_OVERWRITE if overwrite else " INTO" alternative = expression.args.get("alternative") alternative = f" OR {alternative}" if alternative else "" @@ -1365,10 +1406,10 @@ class Generator: return f"KILL{kind}{this}" def pseudotype_sql(self, expression: exp.PseudoType) -> str: - return expression.name.upper() + return expression.name def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str: - return expression.name.upper() + return expression.name def onconflict_sql(self, expression: exp.OnConflict) -> str: conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT" @@ -1445,9 +1486,6 @@ class Generator: pattern = f", PATTERN => {pattern}" if pattern else "" file_format = f" (FILE_FORMAT => {file_format}{pattern})" - index = self.sql(expression, "index") - index = f" AT {index}" if index else "" - ordinality = expression.args.get("ordinality") or "" if ordinality: ordinality = f" WITH ORDINALITY{alias}" @@ -1457,10 +1495,13 @@ class Generator: if when: table = f"{table} {when}" - return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}" + return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}" def tablesample_sql( - self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " + self, + expression: exp.TableSample, + sep: str = " AS ", + tablesample_keyword: t.Optional[str] = None, ) -> str: if self.dialect.ALIAS_POST_TABLESAMPLE and 
expression.this and expression.this.alias: table = expression.this.copy() @@ -1472,30 +1513,30 @@ class Generator: alias = "" method = self.sql(expression, "method") - method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else "" + method = f"{method} " if method and self.TABLESAMPLE_WITH_METHOD else "" numerator = self.sql(expression, "bucket_numerator") denominator = self.sql(expression, "bucket_denominator") field = self.sql(expression, "bucket_field") field = f" ON {field}" if field else "" bucket = f"BUCKET {numerator} OUT OF {denominator}{field}" if numerator else "" - percent = self.sql(expression, "percent") - percent = f"{percent} PERCENT" if percent else "" - rows = self.sql(expression, "rows") - rows = f"{rows} ROWS" if rows else "" + seed = self.sql(expression, "seed") + seed = f" {self.TABLESAMPLE_SEED_KEYWORD} ({seed})" if seed else "" size = self.sql(expression, "size") - if size and self.TABLESAMPLE_SIZE_IS_PERCENT: - size = f"{size} PERCENT" + if size and self.TABLESAMPLE_SIZE_IS_ROWS: + size = f"{size} ROWS" - seed = self.sql(expression, "seed") - seed = f" {seed_prefix} ({seed})" if seed else "" - kind = expression.args.get("kind", "TABLESAMPLE") + percent = self.sql(expression, "percent") + if percent and not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT: + percent = f"{percent} PERCENT" - expr = f"{bucket}{percent}{rows}{size}" + expr = f"{bucket}{percent}{size}" if self.TABLESAMPLE_REQUIRES_PARENS: expr = f"({expr})" - return f"{this} {kind} {method}{expr}{seed}{alias}" + return ( + f"{this} {tablesample_keyword or self.TABLESAMPLE_KEYWORDS} {method}{expr}{seed}{alias}" + ) def pivot_sql(self, expression: exp.Pivot) -> str: expressions = self.expressions(expression, flat=True) @@ -1513,8 +1554,7 @@ class Generator: alias = self.sql(expression, "alias") alias = f" AS {alias}" if alias else "" - unpivot = expression.args.get("unpivot") - direction = "UNPIVOT" if unpivot else "PIVOT" + direction = "UNPIVOT" if expression.unpivot else "PIVOT" field = self.sql(expression, "field") include_nulls = expression.args.get("include_nulls") if include_nulls is not None: @@ -1675,7 +1715,8 @@ class Generator: if not on_sql and using: on_sql = csv(*(self.sql(column) for column in using)) - this_sql = self.sql(expression, "this") + this = expression.this + this_sql = self.sql(this) if on_sql: on_sql = self.indent(on_sql, skip_first=True) @@ -1685,6 +1726,9 @@ class Generator: else: on_sql = f"{space}ON {on_sql}" elif not op_sql: + if isinstance(this, exp.Lateral) and this.args.get("cross_apply") is not None: + return f" {this_sql}" + return f", {this_sql}" op_sql = f"{op_sql} JOIN" if op_sql else "JOIN" @@ -1695,6 +1739,19 @@ class Generator: args = f"({args})" if len(args.split(",")) > 1 else args return f"{args} {arrow_sep} {self.sql(expression, 'this')}" + def lateral_op(self, expression: exp.Lateral) -> str: + cross_apply = expression.args.get("cross_apply") + + # https://www.mssqltips.com/sqlservertip/1958/sql-server-cross-apply-and-outer-apply/ + if cross_apply is True: + op = "INNER JOIN " + elif cross_apply is False: + op = "LEFT JOIN " + else: + op = "" + + return f"{op}LATERAL" + def lateral_sql(self, expression: exp.Lateral) -> str: this = self.sql(expression, "this") @@ -1708,7 +1765,7 @@ class Generator: alias = self.sql(expression, "alias") alias = f" AS {alias}" if alias else "" - return f"LATERAL {this}{alias}" + return f"{self.lateral_op(expression)} {this}{alias}" def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: this = 
self.sql(expression, "this") @@ -1805,7 +1862,8 @@ class Generator: def order_sql(self, expression: exp.Order, flat: bool = False) -> str: this = self.sql(expression, "this") this = f"{this} " if this else this - order = self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore + siblings = "SIBLINGS " if expression.args.get("siblings") else "" + order = self.op_expressions(f"{this}ORDER {siblings}BY", expression, flat=this or flat) # type: ignore interpolated_values = [ f"{self.sql(named_expression, 'alias')} AS {self.sql(named_expression, 'this')}" for named_expression in expression.args.get("interpolate") or [] @@ -1860,9 +1918,21 @@ class Generator: # If the NULLS FIRST/LAST clause is unsupported, we add another sort key to simulate it if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED: - null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else "" - this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}" - nulls_sort_change = "" + window = expression.find_ancestor(exp.Window, exp.Select) + if isinstance(window, exp.Window) and window.args.get("spec"): + self.unsupported( + f"'{nulls_sort_change.strip()}' translation not supported in window functions" + ) + nulls_sort_change = "" + elif self.NULL_ORDERING_SUPPORTED is None: + if expression.this.is_int: + self.unsupported( + f"'{nulls_sort_change.strip()}' translation not supported with positional ordering" + ) + else: + null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else "" + this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}" + nulls_sort_change = "" with_fill = self.sql(expression, "with_fill") with_fill = f" {with_fill}" if with_fill else "" @@ -1961,10 +2031,14 @@ class Generator: return [locks, self.sql(expression, "sample")] def select_sql(self, expression: exp.Select) -> str: + into = expression.args.get("into") + if not self.SUPPORTS_SELECT_INTO and into: + into.pop() + hint = self.sql(expression, "hint") distinct = self.sql(expression, "distinct") distinct = f" {distinct}" if distinct else "" - kind = self.sql(expression, "kind").upper() + kind = self.sql(expression, "kind") limit = expression.args.get("limit") top = ( self.limit_sql(limit, top=True) @@ -2005,7 +2079,19 @@ class Generator: self.sql(expression, "into", comment=False), self.sql(expression, "from", comment=False), ) - return self.prepend_ctes(expression, sql) + + sql = self.prepend_ctes(expression, sql) + + if not self.SUPPORTS_SELECT_INTO and into: + if into.args.get("temporary"): + table_kind = " TEMPORARY" + elif self.SUPPORTS_UNLOGGED_TABLES and into.args.get("unlogged"): + table_kind = " UNLOGGED" + else: + table_kind = "" + sql = f"CREATE{table_kind} TABLE {self.sql(into.this)} AS {sql}" + + return sql def schema_sql(self, expression: exp.Schema) -> str: this = self.sql(expression, "this") @@ -2266,29 +2352,35 @@ class Generator: return f"{self.func('MATCH', *expression.expressions)} AGAINST({self.sql(expression, 'this')}{modifier})" def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str: - return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}" + return f"{self.sql(expression, 'this')}{self.JSON_KEY_VALUE_PAIR_SEP} {self.sql(expression, 'expression')}" def formatjson_sql(self, expression: exp.FormatJson) -> str: return f"{self.sql(expression, 'this')} FORMAT JSON" - def jsonobject_sql(self, expression: exp.JSONObject) -> str: + def jsonobject_sql(self, expression: exp.JSONObject | exp.JSONObjectAgg) -> str: 
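+ # Shared by JSON_OBJECT and JSON_OBJECTAGG; the emitted function name is chosen from the expression type below.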
null_handling = expression.args.get("null_handling") null_handling = f" {null_handling}" if null_handling else "" + unique_keys = expression.args.get("unique_keys") if unique_keys is not None: unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS" else: unique_keys = "" + return_type = self.sql(expression, "return_type") return_type = f" RETURNING {return_type}" if return_type else "" encoding = self.sql(expression, "encoding") encoding = f" ENCODING {encoding}" if encoding else "" + return self.func( - "JSON_OBJECT", + "JSON_OBJECT" if isinstance(expression, exp.JSONObject) else "JSON_OBJECTAGG", *expression.expressions, suffix=f"{null_handling}{unique_keys}{return_type}{encoding})", ) + def jsonobjectagg_sql(self, expression: exp.JSONObjectAgg) -> str: + return self.jsonobject_sql(expression) + def jsonarray_sql(self, expression: exp.JSONArray) -> str: null_handling = expression.args.get("null_handling") null_handling = f" {null_handling}" if null_handling else "" @@ -2385,7 +2477,7 @@ class Generator: def interval_sql(self, expression: exp.Interval) -> str: unit = self.sql(expression, "unit") if not self.INTERVAL_ALLOWS_PLURAL_FORM: - unit = self.TIME_PART_SINGULARS.get(unit.lower(), unit) + unit = self.TIME_PART_SINGULARS.get(unit, unit) unit = f" {unit}" if unit else "" if self.SINGLE_STRING_INTERVAL: @@ -2436,9 +2528,25 @@ class Generator: alias = f" AS {alias}" if alias else "" return f"{self.sql(expression, 'this')}{alias}" + def pivotalias_sql(self, expression: exp.PivotAlias) -> str: + alias = expression.args["alias"] + identifier_alias = isinstance(alias, exp.Identifier) + + if identifier_alias and not self.UNPIVOT_ALIASES_ARE_IDENTIFIERS: + alias.replace(exp.Literal.string(alias.output_name)) + elif not identifier_alias and self.UNPIVOT_ALIASES_ARE_IDENTIFIERS: + alias.replace(exp.to_identifier(alias.output_name)) + + return self.alias_sql(expression) + def aliases_sql(self, expression: exp.Aliases) -> str: return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})" + def atindex_sql(self, expression: exp.AtIndex) -> str: + this = self.sql(expression, "this") + index = self.sql(expression, "expression") + return f"{this} AT {index}" + def attimezone_sql(self, expression: exp.AtTimeZone) -> str: this = self.sql(expression, "this") zone = self.sql(expression, "zone") @@ -2500,7 +2608,7 @@ class Generator: return self.binary(expression, "COLLATE") def command_sql(self, expression: exp.Command) -> str: - return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" + return f"{self.sql(expression, 'this')} {expression.text('expression').strip()}" def comment_sql(self, expression: exp.Comment) -> str: this = self.sql(expression, "this") @@ -3102,6 +3210,47 @@ class Generator: cond_for_null = arg.is_(exp.null()) return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg]))) + def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str: + this = expression.this + if isinstance(this, exp.TsOrDsToTime) or this.is_type(exp.DataType.Type.TIME): + return self.sql(this) + + return self.sql(exp.cast(this, "time")) + + def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str: + this = expression.this + time_format = self.format_time(expression) + + if time_format and time_format not in (self.dialect.TIME_FORMAT, self.dialect.DATE_FORMAT): + return self.sql( + exp.cast(exp.StrToTime(this=this, format=expression.args["format"]), "date") + ) + + if isinstance(this, exp.TsOrDsToDate) or 
this.is_type(exp.DataType.Type.DATE): + return self.sql(this) + + return self.sql(exp.cast(this, "date")) + + def unixdate_sql(self, expression: exp.UnixDate) -> str: + return self.sql( + exp.func( + "DATEDIFF", + expression.this, + exp.cast(exp.Literal.string("1970-01-01"), "date"), + "day", + ) + ) + + def lastday_sql(self, expression: exp.LastDay) -> str: + if self.LAST_DAY_SUPPORTS_DATE_PART: + return self.function_fallback_sql(expression) + + unit = expression.text("unit") + if unit and unit != "MONTH": + self.unsupported("Date parts are not supported in LAST_DAY.") + + return self.func("LAST_DAY", expression.this) + def _simplify_unless_literal(self, expression: E) -> E: if not isinstance(expression, exp.Literal): from sqlglot.optimizer.simplify import simplify diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index abcc10f..09bf201 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -129,13 +129,10 @@ def lineage( if isinstance(column, int) else next( (select for select in scope.expression.selects if select.alias_or_name == column), - exp.Star() if scope.expression.is_star else None, + exp.Star() if scope.expression.is_star else scope.expression, ) ) - if not select: - raise ValueError(f"Could not find {column} in {scope.expression}") - if isinstance(scope.expression, exp.Union): upstream = upstream or Node(name="UNION", source=scope.expression, expression=select) @@ -194,6 +191,8 @@ def lineage( # if the select is a star add all scope sources as downstreams if select.is_star: for source in scope.sources.values(): + if isinstance(source, Scope): + source = source.expression node.downstream.append(Node(name=select.sql(), source=source, expression=source)) # Find all columns that went into creating this one to list their lineage nodes. 
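Note: the order_sql/NULLS ordering changes above make the CASE-based fallback conditional — it is skipped inside framed window functions and, when NULL_ORDERING_SUPPORTED is None, for positional sort keys. A minimal sketch of the surviving fallback path, assuming a target dialect such as MySQL without native NULLS FIRST support (exact output may vary by version):

import sqlglot

# Dialects lacking native NULLS FIRST/LAST get an extra CASE sort key.
print(sqlglot.transpile("SELECT a FROM t ORDER BY a NULLS FIRST", write="mysql")[0])
# Roughly: SELECT a FROM t ORDER BY CASE WHEN a IS NULL THEN 1 ELSE 0 END DESC, a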
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 7b990f1..d0168d5 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -195,6 +195,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.StrPosition, exp.TsOrDiToDi, }, + exp.DataType.Type.JSON: { + exp.ParseJSON, + }, exp.DataType.Type.TIMESTAMP: { exp.CurrentTime, exp.CurrentTimestamp, @@ -275,6 +278,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), + exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True), } NESTED_TYPES = { @@ -477,7 +481,12 @@ class TypeAnnotator(metaclass=_TypeAnnotator): @t.no_type_check def _annotate_by_args( - self, expression: E, *args: str, promote: bool = False, array: bool = False + self, + expression: E, + *args: str, + promote: bool = False, + array: bool = False, + struct: bool = False, ) -> E: self._annotate_args(expression) @@ -506,6 +515,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator): ), ) + if struct: + expressions = [ + expr.type + if not expr.args.get("alias") + else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type) + for expr in expressions + ] + + self._set_type( + expression, + exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True), + ) + return expression def _annotate_timeunit( diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 10ff13a..12c3b89 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -30,13 +30,18 @@ def pushdown_predicates(expression, dialect=None): where = select.args.get("where") if where: selected_sources = scope.selected_sources + join_index = { + join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or []) + } + # a right join can only push down to itself and not the source FROM table for k, (node, source) in selected_sources.items(): parent = node.find_ancestor(exp.Join, exp.From) if isinstance(parent, exp.Join) and parent.side == "RIGHT": selected_sources = {k: (node, source)} break - pushdown(where.this, selected_sources, scope_ref_count, dialect) + + pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself @@ -53,7 +58,7 @@ def pushdown_predicates(expression, dialect=None): return expression -def pushdown(condition, sources, scope_ref_count, dialect): +def pushdown(condition, sources, scope_ref_count, dialect, join_index=None): if not condition: return @@ -67,21 +72,28 @@ def pushdown(condition, sources, scope_ref_count, dialect): ) if cnf_like: - pushdown_cnf(predicates, sources, scope_ref_count) + pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index) else: pushdown_dnf(predicates, sources, scope_ref_count) -def pushdown_cnf(predicates, scope, scope_ref_count): +def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None): """ If the predicates are in CNF like form, we can simply replace each block in the parent. 
""" + join_index = join_index or {} for predicate in predicates: for node in nodes_for_predicate(predicate, scope, scope_ref_count).values(): if isinstance(node, exp.Join): - predicate.replace(exp.true()) - node.on(predicate, copy=False) - break + name = node.alias_or_name + predicate_tables = exp.column_table_names(predicate, name) + + # Don't push the predicate if it references tables that appear in later joins + this_index = join_index[name] + if all(join_index.get(table, -1) < this_index for table in predicate_tables): + predicate.replace(exp.true()) + node.on(predicate, copy=False) + break if isinstance(node, exp.Select): predicate.replace(exp.true()) inner_predicate = replace_aliases(node, predicate) @@ -112,9 +124,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count): conditions = {} - # for every pushdown table, find all related conditions in all predicates - # combine them with ORS - # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z) + # pushdown all predicates to their respective nodes for table in sorted(pushdown_tables): for predicate in predicates: nodes = nodes_for_predicate(predicate, scope, scope_ref_count) @@ -122,23 +132,9 @@ def pushdown_dnf(predicates, scope, scope_ref_count): if table not in nodes: continue - predicate_condition = None - - for column in predicate.find_all(exp.Column): - if column.table == table: - condition = column.find_ancestor(exp.Condition) - predicate_condition = ( - exp.and_(predicate_condition, condition) - if predicate_condition - else condition - ) - - if predicate_condition: - conditions[table] = ( - exp.or_(conditions[table], predicate_condition) - if table in conditions - else predicate_condition - ) + conditions[table] = ( + exp.or_(conditions[table], predicate) if table in conditions else predicate + ) for name, node in nodes.items(): if name not in conditions: diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 4bc3bd2..e3aaebc 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -43,9 +43,8 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) parent_selections = referenced_columns.get(scope, {SELECT_ALL}) alias_count = source_column_alias_count.get(scope, 0) - if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots): - # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if - # we select from a pivoted source in the parent scope. + # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. 
+ if scope.expression.args.get("distinct"): parent_selections = {SELECT_ALL} if isinstance(scope.expression, exp.Union): @@ -78,7 +77,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) # Push the selected columns down to the next scope for name, (node, source) in scope.selected_sources.items(): if isinstance(source, Scope): - columns = selects.get(name) or set() + columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set() referenced_columns[source].update(columns) column_aliases = node.alias_column_names diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py index 5fdbde8..8d83b47 100644 --- a/sqlglot/optimizer/qualify.py +++ b/sqlglot/optimizer/qualify.py @@ -3,10 +3,11 @@ from __future__ import annotations import typing as t from sqlglot import exp -from sqlglot.dialects.dialect import DialectType +from sqlglot.dialects.dialect import Dialect, DialectType from sqlglot.optimizer.isolate_table_selects import isolate_table_selects from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import ( + pushdown_cte_alias_columns as pushdown_cte_alias_columns_func, qualify_columns as qualify_columns_func, quote_identifiers as quote_identifiers_func, validate_qualify_columns as validate_qualify_columns_func, @@ -22,6 +23,7 @@ def qualify( catalog: t.Optional[str] = None, schema: t.Optional[dict | Schema] = None, expand_alias_refs: bool = True, + expand_stars: bool = True, infer_schema: t.Optional[bool] = None, isolate_tables: bool = False, qualify_columns: bool = True, @@ -47,6 +49,9 @@ def qualify( catalog: Default catalog name for tables. schema: Schema to infer column names and types. expand_alias_refs: Whether or not to expand references to aliases. + expand_stars: Whether or not to expand star queries. This is a necessary step + for most of the optimizer's rules to work; do not set to False unless you + know what you're doing! infer_schema: Whether or not to infer the schema if missing. isolate_tables: Whether or not to isolate table selects. qualify_columns: Whether or not to qualify columns. @@ -66,9 +71,16 @@ def qualify( if isolate_tables: expression = isolate_table_selects(expression, schema=schema) + if Dialect.get_or_raise(dialect).PREFER_CTE_ALIAS_COLUMN: + expression = pushdown_cte_alias_columns_func(expression) + if qualify_columns: expression = qualify_columns_func( - expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema + expression, + schema, + expand_alias_refs=expand_alias_refs, + expand_stars=expand_stars, + infer_schema=infer_schema, ) if quote_identifiers: diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 742cdf5..a6397ae 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -17,6 +17,7 @@ def qualify_columns( expression: exp.Expression, schema: t.Dict | Schema, expand_alias_refs: bool = True, + expand_stars: bool = True, infer_schema: t.Optional[bool] = None, ) -> exp.Expression: """ @@ -33,10 +34,16 @@ def qualify_columns( expression: Expression to qualify. schema: Database schema. expand_alias_refs: Whether or not to expand references to aliases. + expand_stars: Whether or not to expand star queries. This is a necessary step + for most of the optimizer's rules to work; do not set to False unless you + know what you're doing! infer_schema: Whether or not to infer the schema if missing. Returns: The qualified expression. 
+ + Notes: + - Currently only handles a single PIVOT or UNPIVOT operator """ schema = ensure_schema(schema) infer_schema = schema.empty if infer_schema is None else infer_schema @@ -57,7 +64,8 @@ def qualify_columns( _expand_alias_refs(scope, resolver) if not isinstance(scope.expression, exp.UDTF): - _expand_stars(scope, resolver, using_column_tables, pseudocolumns) + if expand_stars: + _expand_stars(scope, resolver, using_column_tables, pseudocolumns) qualify_outputs(scope) _expand_group_by(scope) @@ -68,21 +76,41 @@ def qualify_columns( def validate_qualify_columns(expression: E) -> E: """Raise an `OptimizeError` if any columns aren't qualified""" - unqualified_columns = [] + all_unqualified_columns = [] for scope in traverse_scope(expression): if isinstance(scope.expression, exp.Select): - unqualified_columns.extend(scope.unqualified_columns) + unqualified_columns = scope.unqualified_columns + if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: column = scope.external_columns[0] - raise OptimizeError( - f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}""" - ) + for_table = f" for table: '{column.table}'" if column.table else "" + raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") + + if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: + # New columns produced by the UNPIVOT can't be qualified, but there may be columns + # under the UNPIVOT's IN clause that can and should be qualified. We recompute + # this list here to ensure those in the former category will be excluded. + unpivot_columns = set(_unpivot_columns(scope.pivots[0])) + unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] + + all_unqualified_columns.extend(unqualified_columns) + + if all_unqualified_columns: + raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") - if unqualified_columns: - raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") return expression +def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]: + name_column = [] + field = unpivot.args.get("field") + if isinstance(field, exp.In) and isinstance(field.this, exp.Column): + name_column.append(field.this) + + value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column)) + return itertools.chain(name_column, value_columns) + + def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: """ Remove table column aliases. 
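Note: the new expand_stars flag threads from qualify() into qualify_columns() so star projections can be left untouched. A rough usage sketch (the schema and printed outputs are illustrative, not authoritative):

import sqlglot
from sqlglot.optimizer.qualify import qualify

schema = {"t": {"a": "INT", "b": "INT"}}

# Default behavior: the star is expanded using the schema.
print(qualify(sqlglot.parse_one("SELECT * FROM t"), schema=schema, quote_identifiers=False).sql())
# e.g. SELECT t.a AS a, t.b AS b FROM t AS t

# With expand_stars=False the projection is preserved.
print(qualify(sqlglot.parse_one("SELECT * FROM t"), schema=schema, expand_stars=False, quote_identifiers=False).sql())
# e.g. SELECT * FROM t AS t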
@@ -216,6 +244,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: replace_columns(expression.args.get("group"), literal_index=True) replace_columns(expression.args.get("having"), resolve_table=True) replace_columns(expression.args.get("qualify"), resolve_table=True) + scope.clear_cache() @@ -353,18 +382,25 @@ def _expand_stars( replace_columns: t.Dict[int, t.Dict[str, str]] = {} coalesced_columns = set() - # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future - pivot_columns = None pivot_output_columns = None - pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) + pivot_exclude_columns = None - has_pivoted_source = pivot and not pivot.args.get("unpivot") - if pivot and has_pivoted_source: - pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column)) + pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) + if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names: + if pivot.unpivot: + pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)] + + field = pivot.args.get("field") + if isinstance(field, exp.In): + pivot_exclude_columns = { + c.output_name for e in field.expressions for c in e.find_all(exp.Column) + } + else: + pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column)) - pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] - if not pivot_output_columns: - pivot_output_columns = [col.alias_or_name for col in pivot.expressions] + pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])] + if not pivot_output_columns: + pivot_output_columns = [c.alias_or_name for c in pivot.expressions] for expression in scope.expression.selects: if isinstance(expression, exp.Star): @@ -384,47 +420,54 @@ def _expand_stars( raise OptimizeError(f"Unknown table: {table}") columns = resolver.get_source_columns(table, only_visible=True) + columns = columns or scope.outer_column_list if pseudocolumns: columns = [name for name in columns if name.upper() not in pseudocolumns] - if columns and "*" not in columns: - table_id = id(table) - columns_to_exclude = except_columns.get(table_id) or set() + if not columns or "*" in columns: + return + + table_id = id(table) + columns_to_exclude = except_columns.get(table_id) or set() - if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: - implicit_columns = [col for col in columns if col not in pivot_columns] + if pivot: + if pivot_output_columns and pivot_exclude_columns: + pivot_columns = [c for c in columns if c not in pivot_exclude_columns] + pivot_columns.extend(pivot_output_columns) + else: + pivot_columns = pivot.alias_column_names + + if pivot_columns: new_selections.extend( exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) - for name in implicit_columns + pivot_output_columns + for name in pivot_columns if name not in columns_to_exclude ) continue - for name in columns: - if name in using_column_tables and table in using_column_tables[name]: - if name in coalesced_columns: - continue - - coalesced_columns.add(name) - tables = using_column_tables[name] - coalesce = [exp.column(name, table=table) for table in tables] - - new_selections.append( - alias( - exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), - alias=name, - copy=False, - ) - ) - elif name not in columns_to_exclude: - alias_ = replace_columns.get(table_id, {}).get(name, name) - column = exp.column(name, table=table) - new_selections.append( - alias(column, alias_, copy=False) 
if alias_ != name else column + for name in columns: + if name in using_column_tables and table in using_column_tables[name]: + if name in coalesced_columns: + continue + + coalesced_columns.add(name) + tables = using_column_tables[name] + coalesce = [exp.column(name, table=table) for table in tables] + + new_selections.append( + alias( + exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), + alias=name, + copy=False, ) - else: - return + ) + elif name not in columns_to_exclude: + alias_ = replace_columns.get(table_id, {}).get(name, name) + column = exp.column(name, table=table) + new_selections.append( + alias(column, alias_, copy=False) if alias_ != name else column + ) # Ensures we don't overwrite the initial selections with an empty list if new_selections: @@ -472,6 +515,9 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: for i, (selection, aliased_column) in enumerate( itertools.zip_longest(scope.expression.selects, scope.outer_column_list) ): + if selection is None: + break + if isinstance(selection, exp.Subquery): if not selection.output_name: selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) @@ -495,6 +541,38 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool ) +def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: + """ + Pushes down the CTE alias columns into the projection, + + This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") + >>> pushdown_cte_alias_columns(expression).sql() + 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' + + Args: + expression: Expression to pushdown. + + Returns: + The expression with the CTE aliases pushed down into the projection. + """ + for cte in expression.find_all(exp.CTE): + if cte.alias_column_names: + new_expressions = [] + for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): + if isinstance(projection, exp.Alias): + projection.set("alias", _alias) + else: + projection = alias(projection, alias=_alias) + new_expressions.append(projection) + cte.this.set("expressions", new_expressions) + + return expression + + class Resolver: """ Helper for resolving columns. 
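Note: the reworked _expand_stars loop above still coalesces columns shared through a USING join while now also covering PIVOT/UNPIVOT sources. A small sketch of the USING path, assuming an illustrative two-table schema (output shape is approximate):

import sqlglot
from sqlglot.optimizer.qualify import qualify

schema = {"x": {"id": "INT", "a": "INT"}, "y": {"id": "INT", "b": "INT"}}
expr = sqlglot.parse_one("SELECT * FROM x JOIN y USING (id)")
print(qualify(expr, schema=schema, quote_identifiers=False).sql())
# Roughly: SELECT COALESCE(x.id, y.id) AS id, x.a AS a, y.b AS b
#          FROM x AS x JOIN y AS y USING (id)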
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 57ecabe..e0fe641 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -72,11 +72,15 @@ def qualify_tables( if not source.args.get("catalog") and source.args.get("db"): source.set("catalog", catalog) + pivots = source.args.get("pivots") if not source.alias: + # Don't add the pivot's alias to the pivoted table, use the table's name instead + if pivots and pivots[0].alias == name: + name = source.name + # Mutates the source by attaching an alias to it alias(source, name or source.name or next_alias_name(), copy=False, table=True) - pivots = source.args.get("pivots") if pivots and not pivots[0].alias: pivots[0].set( "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index d34857d..a3f08d5 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -539,11 +539,23 @@ def _traverse_union(scope): # The last scope to be yield should be the top most scope left = None - for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)): + for left in _traverse_scope( + scope.branch( + scope.expression.left, + outer_column_list=scope.outer_column_list, + scope_type=ScopeType.UNION, + ) + ): yield left right = None - for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): + for right in _traverse_scope( + scope.branch( + scope.expression.right, + outer_column_list=scope.outer_column_list, + scope_type=ScopeType.UNION, + ) + ): yield right scope.union_scopes = [left, right] diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index f53023c..25d4e75 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -100,6 +100,7 @@ def simplify( node = simplify_parens(node) node = simplify_datetrunc(node, dialect) node = sort_comparison(node) + node = simplify_startswith(node) if root: expression.replace(node) @@ -776,6 +777,26 @@ def simplify_conditionals(expression): return expression +def simplify_startswith(expression: exp.Expression) -> exp.Expression: + """ + Reduces a prefix check to either TRUE or FALSE if both the string and the + prefix are statically known.
+ + Example: + >>> from sqlglot import parse_one + >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() + 'TRUE' + """ + if ( + isinstance(expression, exp.StartsWith) + and expression.this.is_string + and expression.expression.is_string + ): + return exp.convert(expression.name.startswith(expression.expression.name)) + + return expression + + DateRange = t.Tuple[datetime.date, datetime.date] @@ -1160,7 +1181,7 @@ def gen(expression: t.Any) -> str: GEN_MAP = { exp.Add: lambda e: _binary(e, "+"), exp.And: lambda e: _binary(e, "AND"), - exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}", + exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}", exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}", exp.Boolean: lambda e: "TRUE" if e.this else "FALSE", exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]", diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 311c43d..790ee0d 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -12,6 +12,8 @@ from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import TrieResult, in_trie, new_trie if t.TYPE_CHECKING: + from typing_extensions import Literal + from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType @@ -193,6 +195,7 @@ class Parser(metaclass=_Parser): TokenType.DATETIME, TokenType.DATETIME64, TokenType.DATE, + TokenType.DATE32, TokenType.INT4RANGE, TokenType.INT4MULTIRANGE, TokenType.INT8RANGE, @@ -232,6 +235,8 @@ class Parser(metaclass=_Parser): TokenType.INET, TokenType.IPADDRESS, TokenType.IPPREFIX, + TokenType.IPV4, + TokenType.IPV6, TokenType.UNKNOWN, TokenType.NULL, *ENUM_TYPE_TOKENS, @@ -669,6 +674,7 @@ class Parser(metaclass=_Parser): PROPERTY_PARSERS: t.Dict[str, t.Callable] = { "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty), + "AUTO": lambda self: self._parse_auto_property(), "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty), "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), "CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs), @@ -680,6 +686,7 @@ class Parser(metaclass=_Parser): exp.CollateProperty, **kwargs ), "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), + "CONTAINS": lambda self: self._parse_contains_property(), "COPY": lambda self: self._parse_copy_property(), "DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs), "DEFINER": lambda self: self._parse_definer(), @@ -710,6 +717,7 @@ class Parser(metaclass=_Parser): "LOG": lambda self, **kwargs: self._parse_log(**kwargs), "MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty), "MERGEBLOCKRATIO": lambda self, **kwargs: self._parse_mergeblockratio(**kwargs), + "MODIFIES": lambda self: self._parse_modifies_property(), "MULTISET": lambda self: self.expression(exp.SetProperty, multi=True), "NO": lambda self: self._parse_no_property(), "ON": lambda self: self._parse_on_property(), @@ -721,6 +729,7 @@ class Parser(metaclass=_Parser): "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), "PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True), "RANGE": lambda self: self._parse_dict_range(this="RANGE"), + "READS": lambda self: self._parse_reads_property(), "REMOTE": lambda self: self._parse_remote_with_connection(), "RETURNS": lambda self: self._parse_returns(), "ROW": lambda self: self._parse_row(), @@ -841,6 
+850,7 @@ class Parser(metaclass=_Parser): "DECODE": lambda self: self._parse_decode(), "EXTRACT": lambda self: self._parse_extract(), "JSON_OBJECT": lambda self: self._parse_json_object(), + "JSON_OBJECTAGG": lambda self: self._parse_json_object(agg=True), "JSON_TABLE": lambda self: self._parse_json_table(), "MATCH": lambda self: self._parse_match_against(), "OPENJSON": lambda self: self._parse_open_json(), @@ -925,6 +935,8 @@ class Parser(metaclass=_Parser): WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER} WINDOW_SIDES = {"FOLLOWING", "PRECEDING"} + JSON_KEY_VALUE_SEPARATOR_TOKENS = {TokenType.COLON, TokenType.COMMA, TokenType.IS} + FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT} ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY} @@ -954,6 +966,9 @@ class Parser(metaclass=_Parser): # Whether the TRIM function expects the characters to trim as its first argument TRIM_PATTERN_FIRST = False + # Whether or not string aliases are supported `SELECT COUNT(*) 'count'` + STRING_ALIASES = False + # Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand) MODIFIERS_ATTACHED_TO_UNION = True UNION_MODIFIERS = {"order", "limit", "offset"} @@ -1193,7 +1208,9 @@ class Parser(metaclass=_Parser): self._advance(index - self._index) def _parse_command(self) -> exp.Command: - return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string()) + return self.expression( + exp.Command, this=self._prev.text.upper(), expression=self._parse_string() + ) def _parse_comment(self, allow_exists: bool = True) -> exp.Expression: start = self._prev @@ -1353,26 +1370,27 @@ class Parser(metaclass=_Parser): # exp.Properties.Location.POST_SCHEMA ("schema" here is the UDF's type signature) extend_props(self._parse_properties()) - self._match(TokenType.ALIAS) - - if self._match(TokenType.COMMAND): - expression = self._parse_as_command(self._prev) - else: - begin = self._match(TokenType.BEGIN) - return_ = self._match_text_seq("RETURN") + expression = self._match(TokenType.ALIAS) and self._parse_heredoc() - if self._match(TokenType.STRING, advance=False): - # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property - # # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement - expression = self._parse_string() - extend_props(self._parse_properties()) + if not expression: + if self._match(TokenType.COMMAND): + expression = self._parse_as_command(self._prev) else: - expression = self._parse_statement() + begin = self._match(TokenType.BEGIN) + return_ = self._match_text_seq("RETURN") + + if self._match(TokenType.STRING, advance=False): + # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property + # # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement + expression = self._parse_string() + extend_props(self._parse_properties()) + else: + expression = self._parse_statement() - end = self._match_text_seq("END") + end = self._match_text_seq("END") - if return_: - expression = self.expression(exp.Return, this=expression) + if return_: + expression = self.expression(exp.Return, this=expression) elif create_token.token_type == TokenType.INDEX: this = self._parse_index(index=self._parse_id_var()) elif create_token.token_type in self.DB_CREATABLES: @@ -1426,7 +1444,7 @@ class Parser(metaclass=_Parser): exp.Create, comments=comments, 
this=this, - kind=create_token.text, + kind=create_token.text.upper(), replace=replace, unique=unique, expression=expression, @@ -1849,9 +1867,21 @@ class Parser(metaclass=_Parser): return self.expression(exp.WithDataProperty, no=no, statistics=statistics) - def _parse_no_property(self) -> t.Optional[exp.NoPrimaryIndexProperty]: + def _parse_contains_property(self) -> t.Optional[exp.SqlReadWriteProperty]: + if self._match_text_seq("SQL"): + return self.expression(exp.SqlReadWriteProperty, this="CONTAINS SQL") + return None + + def _parse_modifies_property(self) -> t.Optional[exp.SqlReadWriteProperty]: + if self._match_text_seq("SQL", "DATA"): + return self.expression(exp.SqlReadWriteProperty, this="MODIFIES SQL DATA") + return None + + def _parse_no_property(self) -> t.Optional[exp.Expression]: if self._match_text_seq("PRIMARY", "INDEX"): return exp.NoPrimaryIndexProperty() + if self._match_text_seq("SQL"): + return self.expression(exp.SqlReadWriteProperty, this="NO SQL") return None def _parse_on_property(self) -> t.Optional[exp.Expression]: @@ -1861,6 +1891,11 @@ class Parser(metaclass=_Parser): return exp.OnCommitProperty(delete=True) return self.expression(exp.OnProperty, this=self._parse_schema(self._parse_id_var())) + def _parse_reads_property(self) -> t.Optional[exp.SqlReadWriteProperty]: + if self._match_text_seq("SQL", "DATA"): + return self.expression(exp.SqlReadWriteProperty, this="READS SQL DATA") + return None + def _parse_distkey(self) -> exp.DistKeyProperty: return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) @@ -1920,10 +1955,13 @@ class Parser(metaclass=_Parser): def _parse_describe(self) -> exp.Describe: kind = self._match_set(self.CREATABLES) and self._prev.text + extended = self._match_text_seq("EXTENDED") this = self._parse_table(schema=True) properties = self._parse_properties() expressions = properties.expressions if properties else None - return self.expression(exp.Describe, this=this, kind=kind, expressions=expressions) + return self.expression( + exp.Describe, this=this, extended=extended, kind=kind, expressions=expressions + ) def _parse_insert(self) -> exp.Insert: comments = ensure_list(self._prev_comments) @@ -2164,13 +2202,13 @@ class Parser(metaclass=_Parser): def _parse_value(self) -> exp.Tuple: if self._match(TokenType.L_PAREN): - expressions = self._parse_csv(self._parse_conjunction) + expressions = self._parse_csv(self._parse_expression) self._match_r_paren() return self.expression(exp.Tuple, expressions=expressions) # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows. 
# https://prestodb.io/docs/current/sql/values.html - return self.expression(exp.Tuple, expressions=[self._parse_conjunction()]) + return self.expression(exp.Tuple, expressions=[self._parse_expression()]) def _parse_projections(self) -> t.List[exp.Expression]: return self._parse_expressions() @@ -2212,7 +2250,7 @@ class Parser(metaclass=_Parser): kind = ( self._match(TokenType.ALIAS) and self._match_texts(("STRUCT", "VALUE")) - and self._prev.text + and self._prev.text.upper() ) if distinct: @@ -2261,7 +2299,7 @@ class Parser(metaclass=_Parser): if table else self._parse_select(nested=True, parse_set_operation=False) ) - this = self._parse_set_operations(self._parse_query_modifiers(this)) + this = self._parse_query_modifiers(self._parse_set_operations(this)) self._match_r_paren() @@ -2304,7 +2342,7 @@ class Parser(metaclass=_Parser): ) def _parse_cte(self) -> exp.CTE: - alias = self._parse_table_alias() + alias = self._parse_table_alias(self.ID_VAR_TOKENS) if not alias or not alias.this: self.raise_error("Expected CTE to have alias") @@ -2490,13 +2528,14 @@ class Parser(metaclass=_Parser): ) def _parse_lateral(self) -> t.Optional[exp.Lateral]: - outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) + if not cross_apply and self._match_pair(TokenType.OUTER, TokenType.APPLY): + cross_apply = False - if outer_apply or cross_apply: + if cross_apply is not None: this = self._parse_select(table=True) view = None - outer = not cross_apply + outer = None elif self._match(TokenType.LATERAL): this = self._parse_select(table=True) view = self._match(TokenType.VIEW) @@ -2529,7 +2568,14 @@ class Parser(metaclass=_Parser): else: table_alias = self._parse_table_alias() - return self.expression(exp.Lateral, this=this, view=view, outer=outer, alias=table_alias) + return self.expression( + exp.Lateral, + this=this, + view=view, + outer=outer, + alias=table_alias, + cross_apply=cross_apply, + ) def _parse_join_parts( self, @@ -2563,9 +2609,6 @@ class Parser(metaclass=_Parser): if not skip_join_token and not join and not outer_apply and not cross_apply: return None - if outer_apply: - side = Token(TokenType.LEFT, "LEFT") - kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)} if method: @@ -2755,8 +2798,10 @@ class Parser(metaclass=_Parser): if alias: this.set("alias", alias) - if self._match_text_seq("AT"): - this.set("index", self._parse_id_var()) + if isinstance(this, exp.Table) and self._match_text_seq("AT"): + return self.expression( + exp.AtIndex, this=this.to_column(copy=False), expression=self._parse_id_var() + ) this.set("hints", self._parse_table_hints()) @@ -2865,15 +2910,10 @@ class Parser(metaclass=_Parser): bucket_denominator = None bucket_field = None percent = None - rows = None size = None seed = None - kind = ( - self._prev.text if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE" - ) - method = self._parse_var(tokens=(TokenType.ROW,)) - + method = self._parse_var(tokens=(TokenType.ROW,), upper=True) matched_l_paren = self._match(TokenType.L_PAREN) if self.TABLESAMPLE_CSV: @@ -2895,16 +2935,16 @@ class Parser(metaclass=_Parser): bucket_field = self._parse_field() elif self._match_set((TokenType.PERCENT, TokenType.MOD)): percent = num - elif self._match(TokenType.ROWS): - rows = num - elif num: + elif self._match(TokenType.ROWS) or not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT: size = num + else: + percent = num if matched_l_paren: self._match_r_paren() if 
self._match(TokenType.L_PAREN): - method = self._parse_var() + method = self._parse_var(upper=True) seed = self._match(TokenType.COMMA) and self._parse_number() self._match_r_paren() elif self._match_texts(("SEED", "REPEATABLE")): @@ -2918,10 +2958,8 @@ class Parser(metaclass=_Parser): bucket_denominator=bucket_denominator, bucket_field=bucket_field, percent=percent, - rows=rows, size=size, seed=seed, - kind=kind, ) def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]: @@ -2946,6 +2984,27 @@ class Parser(metaclass=_Parser): exp.Pivot, this=this, expressions=expressions, using=using, group=group ) + def _parse_pivot_in(self) -> exp.In: + def _parse_aliased_expression() -> t.Optional[exp.Expression]: + this = self._parse_conjunction() + + self._match(TokenType.ALIAS) + alias = self._parse_field() + if alias: + return self.expression(exp.PivotAlias, this=this, alias=alias) + + return this + + value = self._parse_column() + + if not self._match_pair(TokenType.IN, TokenType.L_PAREN): + self.raise_error("Expecting IN (") + + aliased_expressions = self._parse_csv(_parse_aliased_expression) + + self._match_r_paren() + return self.expression(exp.In, this=value, expressions=aliased_expressions) + def _parse_pivot(self) -> t.Optional[exp.Pivot]: index = self._index include_nulls = None @@ -2964,7 +3023,6 @@ class Parser(metaclass=_Parser): return None expressions = [] - field = None if not self._match(TokenType.L_PAREN): self._retreat(index) @@ -2981,12 +3039,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.FOR): self.raise_error("Expecting FOR") - value = self._parse_column() - - if not self._match(TokenType.IN): - self.raise_error("Expecting IN") - - field = self._parse_in(value, alias=True) + field = self._parse_pivot_in() self._match_r_paren() @@ -3132,14 +3185,19 @@ class Parser(metaclass=_Parser): def _parse_order( self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False ) -> t.Optional[exp.Expression]: + siblings = None if not skip_order_token and not self._match(TokenType.ORDER_BY): - return this + if not self._match(TokenType.ORDER_SIBLINGS_BY): + return this + + siblings = True return self.expression( exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered), interpolate=self._parse_interpolate(), + siblings=siblings, ) def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]: @@ -3213,7 +3271,7 @@ class Parser(metaclass=_Parser): if self._match(TokenType.FETCH): direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) - direction = self._prev.text if direction else "FIRST" + direction = self._prev.text.upper() if direction else "FIRST" count = self._parse_field(tokens=self.FETCH_TOKENS) percent = self._match(TokenType.PERCENT) @@ -3398,10 +3456,10 @@ class Parser(metaclass=_Parser): return this return self.expression(exp.Escape, this=this, expression=self._parse_string()) - def _parse_interval(self) -> t.Optional[exp.Interval]: + def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Interval]: index = self._index - if not self._match(TokenType.INTERVAL): + if not self._match(TokenType.INTERVAL) and match_interval: return None if self._match(TokenType.STRING, advance=False): @@ -3409,11 +3467,19 @@ class Parser(metaclass=_Parser): else: this = self._parse_term() - if not this: + if not this or ( + isinstance(this, exp.Column) + and not this.table + and not this.this.quoted + and this.name.upper() == "IS" + ): self._retreat(index) return None - unit = self._parse_function() or 
self._parse_var(any_token=True) + unit = self._parse_function() or ( + not self._match(TokenType.ALIAS, advance=False) + and self._parse_var(any_token=True, upper=True) + ) # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse # each INTERVAL expression into this canonical form so it's easy to transpile @@ -3429,7 +3495,7 @@ class Parser(metaclass=_Parser): self._retreat(self._index - 1) this = exp.Literal.string(parts[0]) - unit = self.expression(exp.Var, this=parts[1]) + unit = self.expression(exp.Var, this=parts[1].upper()) return self.expression(exp.Interval, this=this, unit=unit) @@ -3489,6 +3555,12 @@ class Parser(metaclass=_Parser): def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]: interval = parse_interval and self._parse_interval() if interval: + # Convert INTERVAL 'val_1' unit_1 ... 'val_n' unit_n into a sum of intervals + while self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False): + interval = self.expression( # type: ignore + exp.Add, this=interval, expression=self._parse_interval(match_interval=False) + ) + return interval index = self._index @@ -3552,10 +3624,10 @@ class Parser(metaclass=_Parser): type_token = self._prev.token_type if type_token == TokenType.PSEUDO_TYPE: - return self.expression(exp.PseudoType, this=self._prev.text) + return self.expression(exp.PseudoType, this=self._prev.text.upper()) if type_token == TokenType.OBJECT_IDENTIFIER: - return self.expression(exp.ObjectIdentifier, this=self._prev.text) + return self.expression(exp.ObjectIdentifier, this=self._prev.text.upper()) nested = type_token in self.NESTED_TYPE_TOKENS is_struct = type_token in self.STRUCT_TYPE_TOKENS @@ -3587,7 +3659,7 @@ class Parser(metaclass=_Parser): if nested and self._match(TokenType.LT): if is_struct: - expressions = self._parse_csv(self._parse_struct_types) + expressions = self._parse_csv(lambda: self._parse_struct_types(type_required=True)) else: expressions = self._parse_csv( lambda: self._parse_types( @@ -3662,10 +3734,19 @@ class Parser(metaclass=_Parser): return this - def _parse_struct_types(self) -> t.Optional[exp.Expression]: + def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]: + index = self._index this = self._parse_type(parse_interval=False) or self._parse_id_var() self._match(TokenType.COLON) - return self._parse_column_def(this) + column_def = self._parse_column_def(this) + + if type_required and ( + (isinstance(this, exp.Column) and this.this is column_def) or this is column_def + ): + self._retreat(index) + return self._parse_types() + + return column_def def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_text_seq("AT", "TIME", "ZONE"): @@ -4025,6 +4106,12 @@ class Parser(metaclass=_Parser): return exp.AutoIncrementColumnConstraint() + def _parse_auto_property(self) -> t.Optional[exp.AutoRefreshProperty]: + if not self._match_text_seq("REFRESH"): + self._retreat(self._index - 1) + return None + return self.expression(exp.AutoRefreshProperty, this=self._parse_var(upper=True)) + def _parse_compress(self) -> exp.CompressColumnConstraint: if self._match(TokenType.L_PAREN, advance=False): return self.expression( @@ -4230,8 +4317,10 @@ class Parser(metaclass=_Parser): def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: return self._parse_field() - def _parse_period_for_system_time(self) -> exp.PeriodForSystemTimeConstraint: - self._match(TokenType.TIMESTAMP_SNAPSHOT) + def 
_parse_period_for_system_time(self) -> t.Optional[exp.PeriodForSystemTimeConstraint]: + if not self._match(TokenType.TIMESTAMP_SNAPSHOT): + self._retreat(self._index - 1) + return None id_vars = self._parse_wrapped_id_vars() return self.expression( @@ -4257,22 +4346,17 @@ class Parser(metaclass=_Parser): options = self._parse_key_constraint_options() return self.expression(exp.PrimaryKey, expressions=expressions, options=options) + def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]: + return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True)) + def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)): return this bracket_kind = self._prev.token_type - - if self._match(TokenType.COLON): - expressions: t.List[exp.Expression] = [ - self.expression(exp.Slice, expression=self._parse_conjunction()) - ] - else: - expressions = self._parse_csv( - lambda: self._parse_slice( - self._parse_alias(self._parse_conjunction(), explicit=True) - ) - ) + expressions = self._parse_csv( + lambda: self._parse_bracket_key_value(is_map=bracket_kind == TokenType.L_BRACE) + ) if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET: self.raise_error("Expected ]") @@ -4313,7 +4397,10 @@ class Parser(metaclass=_Parser): default = self._parse_conjunction() if not self._match(TokenType.END): - self.raise_error("Expected END after CASE", self._prev) + if isinstance(default, exp.Interval) and default.this.sql().upper() == "END": + default = exp.column("interval") + else: + self.raise_error("Expected END after CASE", self._prev) return self._parse_window( self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default) @@ -4514,7 +4601,7 @@ class Parser(metaclass=_Parser): def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]: self._match_text_seq("KEY") key = self._parse_column() - self._match_set((TokenType.COLON, TokenType.COMMA)) + self._match_set(self.JSON_KEY_VALUE_SEPARATOR_TOKENS) self._match_text_seq("VALUE") value = self._parse_bitwise() @@ -4536,7 +4623,15 @@ class Parser(metaclass=_Parser): return None - def _parse_json_object(self) -> exp.JSONObject: + @t.overload + def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject: + ... + + @t.overload + def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg: + ... 
+ + def _parse_json_object(self, agg=False): star = self._parse_star() expressions = ( [star] @@ -4559,7 +4654,7 @@ class Parser(metaclass=_Parser): encoding = self._match_text_seq("ENCODING") and self._parse_var() return self.expression( - exp.JSONObject, + exp.JSONObjectAgg if agg else exp.JSONObject, expressions=expressions, null_handling=null_handling, unique_keys=unique_keys, @@ -4873,10 +4968,17 @@ class Parser(metaclass=_Parser): self._match_r_paren(aliases) return aliases - alias = self._parse_id_var(any_token) + alias = self._parse_id_var(any_token) or ( + self.STRING_ALIASES and self._parse_string_as_identifier() + ) if alias: - return self.expression(exp.Alias, comments=comments, this=this, alias=alias) + this = self.expression(exp.Alias, comments=comments, this=this, alias=alias) + + # Moves the comment next to the alias in `expr /* comment */ AS alias` + if not this.comments and this.this.comments: + this.comments = this.this.comments + this.this.comments = None return this @@ -4915,14 +5017,19 @@ class Parser(metaclass=_Parser): return self._parse_placeholder() def _parse_var( - self, any_token: bool = False, tokens: t.Optional[t.Collection[TokenType]] = None + self, + any_token: bool = False, + tokens: t.Optional[t.Collection[TokenType]] = None, + upper: bool = False, ) -> t.Optional[exp.Expression]: if ( (any_token and self._advance_any()) or self._match(TokenType.VAR) or (self._match_set(tokens) if tokens else False) ): - return self.expression(exp.Var, this=self._prev.text) + return self.expression( + exp.Var, this=self._prev.text.upper() if upper else self._prev.text + ) return self._parse_placeholder() def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]: @@ -5418,6 +5525,42 @@ class Parser(metaclass=_Parser): condition=condition, ) + def _parse_heredoc(self) -> t.Optional[exp.Heredoc]: + if self._match(TokenType.HEREDOC_STRING): + return self.expression(exp.Heredoc, this=self._prev.text) + + if not self._match_text_seq("$"): + return None + + tags = ["$"] + tag_text = None + + if self._is_connected(): + self._advance() + tags.append(self._prev.text.upper()) + else: + self.raise_error("No closing $ found") + + if tags[-1] != "$": + if self._is_connected() and self._match_text_seq("$"): + tag_text = tags[-1] + tags.append("$") + else: + self.raise_error("No closing $ found") + + heredoc_start = self._curr + + while self._curr: + if self._match_text_seq(*tags, advance=False): + this = self._find_sql(heredoc_start, self._prev) + self._advance(len(tags)) + return self.expression(exp.Heredoc, this=this, tag=tag_text) + + self._advance() + + self.raise_error(f"No closing {''.join(tags)} found") + return None + def _find_parser( self, parsers: t.Dict[str, t.Callable], trie: t.Dict ) -> t.Optional[t.Callable]: diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 54c08dd..8acd89f 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -215,12 +215,13 @@ class MappingSchema(AbstractMappingSchema, Schema): normalize: bool = True, ) -> None: self.dialect = dialect - self.visible = visible or {} + self.visible = {} if visible is None else visible self.normalize = normalize self._type_mapping_cache: t.Dict[str, exp.DataType] = {} self._depth = 0 + schema = {} if schema is None else schema - super().__init__(self._normalize(schema or {})) + super().__init__(self._normalize(schema) if self.normalize else schema) @classmethod def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 
de9d4c4..d8fb98b 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -147,6 +147,7 @@ class TokenType(AutoName): DATETIME = auto() DATETIME64 = auto() DATE = auto() + DATE32 = auto() INT4RANGE = auto() INT4MULTIRANGE = auto() INT8RANGE = auto() @@ -182,6 +183,8 @@ class TokenType(AutoName): INET = auto() IPADDRESS = auto() IPPREFIX = auto() + IPV4 = auto() + IPV6 = auto() ENUM = auto() ENUM8 = auto() ENUM16 = auto() @@ -296,6 +299,7 @@ class TokenType(AutoName): ON = auto() OPERATOR = auto() ORDER_BY = auto() + ORDER_SIBLINGS_BY = auto() ORDERED = auto() ORDINALITY = auto() OUTER = auto() diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 03acc2b..0da65b5 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -255,7 +255,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp if not arrays: if expression.args.get("from"): - expression.join(series, copy=False) + expression.join(series, copy=False, join_type="CROSS") else: expression.from_(series, copy=False)
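Note: the explode_to_unnest change above makes the generated lateral join an explicit CROSS JOIN when the query already has a FROM clause. A hedged way to observe it, assuming a Spark-to-Presto transpilation (the exact aliases sqlglot picks may differ by version):

import sqlglot

print(sqlglot.transpile("SELECT EXPLODE(arr) FROM t", read="spark", write="presto")[0])
# Something like: SELECT col FROM t CROSS JOIN UNNEST(arr) AS _u(col)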