From 38e6461a8afbd7cb83709ddb998f03d40ba87755 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 23 Jan 2024 06:06:14 +0100 Subject: Merging upstream version 20.9.0. Signed-off-by: Daniel Baumann --- sqlglot/dialects/bigquery.py | 91 +++++++++++--- sqlglot/dialects/clickhouse.py | 201 +++++++++++++++++++++++++++++-- sqlglot/dialects/databricks.py | 2 + sqlglot/dialects/dialect.py | 119 ++++++++++++++---- sqlglot/dialects/doris.py | 9 +- sqlglot/dialects/drill.py | 3 +- sqlglot/dialects/duckdb.py | 121 +++++++++++++++---- sqlglot/dialects/hive.py | 3 +- sqlglot/dialects/mysql.py | 34 +++--- sqlglot/dialects/oracle.py | 20 +++- sqlglot/dialects/postgres.py | 49 +++----- sqlglot/dialects/presto.py | 40 +++---- sqlglot/dialects/redshift.py | 30 ++++- sqlglot/dialects/snowflake.py | 266 +++++++++++++++++++++++++++++++++-------- sqlglot/dialects/spark2.py | 14 +-- sqlglot/dialects/sqlite.py | 1 + sqlglot/dialects/teradata.py | 7 +- sqlglot/dialects/trino.py | 2 + sqlglot/dialects/tsql.py | 128 +++++++++++++++----- 19 files changed, 897 insertions(+), 243 deletions(-) (limited to 'sqlglot/dialects') diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 7a573e7..0151e6c 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -16,20 +16,22 @@ from sqlglot.dialects.dialect import ( format_time_lambda, if_sql, inline_array_sql, - json_keyvalue_comma_sql, max_or_greatest, min_or_least, no_ilike_sql, parse_date_delta_with_interval, + path_to_jsonpath, regexp_replace_sql, rename_func, timestrtotime_sql, ts_or_ds_add_cast, - ts_or_ds_to_date_sql, ) from sqlglot.helper import seq_get, split_num_words from sqlglot.tokens import TokenType +if t.TYPE_CHECKING: + from typing_extensions import Literal + logger = logging.getLogger("sqlglot") @@ -206,12 +208,17 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s return f"TIMESTAMP_MILLIS({timestamp})" if scale == exp.UnixToTime.MICROS: return f"TIMESTAMP_MICROS({timestamp})" - if scale == exp.UnixToTime.NANOS: - # We need to cast to INT64 because that's what BQ expects - return f"TIMESTAMP_MICROS(CAST({timestamp} / 1000 AS INT64))" - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return f"TIMESTAMP_SECONDS(CAST({timestamp} / POW(10, {scale}) AS INT64))" + + +def _parse_time(args: t.List) -> exp.Func: + if len(args) == 1: + return exp.TsOrDsToTime(this=args[0]) + if len(args) == 3: + return exp.TimeFromParts.from_arg_list(args) + + return exp.Anonymous(this="TIME", expressions=args) class BigQuery(Dialect): @@ -329,7 +336,13 @@ class BigQuery(Dialect): "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd), "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub), "DIV": binary_from_function(exp.IntDiv), + "FORMAT_DATE": lambda args: exp.TimeToStr( + this=exp.TsOrDsToDate(this=seq_get(args, 1)), format=seq_get(args, 0) + ), "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list, + "JSON_EXTRACT_SCALAR": lambda args: exp.JSONExtractScalar( + this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$") + ), "MD5": exp.MD5Digest.from_arg_list, "TO_HEX": _parse_to_hex, "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")( @@ -351,6 +364,7 @@ class BigQuery(Dialect): this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string(","), ), + "TIME": _parse_time, "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd), "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub), "TIMESTAMP_ADD": 
parse_date_delta_with_interval(exp.TimestampAdd), @@ -361,9 +375,7 @@ class BigQuery(Dialect): "TIMESTAMP_MILLIS": lambda args: exp.UnixToTime( this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS ), - "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime( - this=seq_get(args, 0), scale=exp.UnixToTime.SECONDS - ), + "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(this=seq_get(args, 0)), "TO_JSON_STRING": exp.JSONFormat.from_arg_list, } @@ -460,7 +472,15 @@ class BigQuery(Dialect): return table - def _parse_json_object(self) -> exp.JSONObject: + @t.overload + def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject: + ... + + @t.overload + def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg: + ... + + def _parse_json_object(self, agg=False): json_object = super()._parse_json_object() array_kv_pair = seq_get(json_object.expressions, 0) @@ -513,6 +533,10 @@ class BigQuery(Dialect): UNNEST_WITH_ORDINALITY = False COLLATE_IS_FUNC = True LIMIT_ONLY_LITERALS = True + SUPPORTS_TABLE_ALIAS_COLUMNS = False + UNPIVOT_ALIASES_ARE_IDENTIFIERS = False + JSON_KEY_VALUE_PAIR_SEP = "," + NULL_ORDERING_SUPPORTED = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -525,6 +549,7 @@ class BigQuery(Dialect): exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}" if e.args.get("default") else f"COLLATE {self.sql(e, 'this')}", + exp.CountIf: rename_func("COUNTIF"), exp.Create: _create_sql, exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), exp.DateAdd: date_add_interval_sql("DATE", "ADD"), @@ -536,13 +561,13 @@ class BigQuery(Dialect): exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"), exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")), exp.GenerateSeries: rename_func("GENERATE_ARRAY"), + exp.GetPath: path_to_jsonpath(), exp.GroupConcat: rename_func("STRING_AGG"), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql(false_value="NULL"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), exp.JSONFormat: rename_func("TO_JSON_STRING"), - exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)), exp.MD5Digest: rename_func("MD5"), @@ -578,16 +603,17 @@ class BigQuery(Dialect): "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone") ), exp.TimeAdd: date_add_interval_sql("TIME", "ADD"), + exp.TimeFromParts: rename_func("TIME"), exp.TimeSub: date_add_interval_sql("TIME", "SUB"), exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"), exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, - exp.TimeToStr: lambda self, e: f"FORMAT_DATE({self.format_time(e)}, {self.sql(e, 'this')})", exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression), exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsDiff: _ts_or_ds_diff_sql, - exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"), + exp.TsOrDsToTime: rename_func("TIME"), exp.Unhex: rename_func("FROM_HEX"), + exp.UnixDate: rename_func("UNIX_DATE"), exp.UnixToTime: _unix_to_time_sql, exp.Values: _derived_table_values_to_unnest, exp.VariancePop: rename_func("VAR_POP"), @@ -724,6 +750,26 @@ class BigQuery(Dialect): "within", } + def timetostr_sql(self, expression: exp.TimeToStr) -> str: + if isinstance(expression.this, exp.TsOrDsToDate): + this: exp.Expression = expression.this + else: + this = expression + + return f"FORMAT_DATE({self.format_time(expression)}, {self.sql(this, 'this')})" + + def struct_sql(self, expression: exp.Struct) -> str: + 
args = [] + for expr in expression.expressions: + if isinstance(expr, self.KEY_VALUE_DEFINITIONS): + arg = f"{self.sql(expr, 'expression')} AS {expr.this.name}" + else: + arg = self.sql(expr) + + args.append(arg) + + return self.func("STRUCT", *args) + def eq_sql(self, expression: exp.EQ) -> str: # Operands of = cannot be NULL in BigQuery if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null): @@ -760,7 +806,20 @@ class BigQuery(Dialect): return inline_array_sql(self, expression) def bracket_sql(self, expression: exp.Bracket) -> str: + this = self.sql(expression, "this") expressions = expression.expressions + + if len(expressions) == 1: + arg = expressions[0] + if arg.type is None: + from sqlglot.optimizer.annotate_types import annotate_types + + arg = annotate_types(arg) + + if arg.type and arg.type.this in exp.DataType.TEXT_TYPES: + # BQ doesn't support bracket syntax with string values + return f"{this}.{arg.name}" + expressions_sql = ", ".join(self.sql(e) for e in expressions) offset = expression.args.get("offset") @@ -768,13 +827,13 @@ class BigQuery(Dialect): expressions_sql = f"OFFSET({expressions_sql})" elif offset == 1: expressions_sql = f"ORDINAL({expressions_sql})" - else: + elif offset is not None: self.unsupported(f"Unsupported array offset: {offset}") if expression.args.get("safe"): expressions_sql = f"SAFE_{expressions_sql}" - return f"{self.sql(expression, 'this')}[{expressions_sql}]" + return f"{this}[{expressions_sql}]" def transaction_sql(self, *_) -> str: return "BEGIN TRANSACTION" diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 870f402..f2e4fe1 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, arg_max_or_min_no_count, + date_delta_sql, inline_array_sql, no_pivot_sql, rename_func, @@ -22,16 +23,25 @@ def _lower_func(sql: str) -> str: return sql[:index].lower() + sql[index:] -def _quantile_sql(self, e): +def _quantile_sql(self: ClickHouse.Generator, e: exp.Quantile) -> str: quantile = e.args["quantile"] args = f"({self.sql(e, 'this')})" + if isinstance(quantile, exp.Array): func = self.func("quantiles", *quantile) else: func = self.func("quantile", quantile) + return func + args +def _parse_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc: + if len(args) == 1: + return exp.CountIf(this=seq_get(args, 0)) + + return exp.CombinedAggFunc(this="countIf", expressions=args, parts=("count", "If")) + + class ClickHouse(Dialect): NORMALIZE_FUNCTIONS: bool | str = False NULL_ORDERING = "nulls_are_last" @@ -53,6 +63,7 @@ class ClickHouse(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ATTACH": TokenType.COMMAND, + "DATE32": TokenType.DATE32, "DATETIME64": TokenType.DATETIME64, "DICTIONARY": TokenType.DICTIONARY, "ENUM": TokenType.ENUM, @@ -75,6 +86,8 @@ class ClickHouse(Dialect): "UINT32": TokenType.UINT, "UINT64": TokenType.UBIGINT, "UINT8": TokenType.UTINYINT, + "IPV4": TokenType.IPV4, + "IPV6": TokenType.IPV6, } SINGLE_TOKENS = { @@ -91,6 +104,8 @@ class ClickHouse(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "ANY": exp.AnyValue.from_arg_list, + "ARRAYSUM": exp.ArraySum.from_arg_list, + "COUNTIF": _parse_count_if, "DATE_ADD": lambda args: exp.DateAdd( this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), @@ -110,6 +125,138 @@ class ClickHouse(Dialect): "XOR": lambda args: exp.Xor(expressions=args), } + AGG_FUNCTIONS = 
{ + "count", + "min", + "max", + "sum", + "avg", + "any", + "stddevPop", + "stddevSamp", + "varPop", + "varSamp", + "corr", + "covarPop", + "covarSamp", + "entropy", + "exponentialMovingAverage", + "intervalLengthSum", + "kolmogorovSmirnovTest", + "mannWhitneyUTest", + "median", + "rankCorr", + "sumKahan", + "studentTTest", + "welchTTest", + "anyHeavy", + "anyLast", + "boundingRatio", + "first_value", + "last_value", + "argMin", + "argMax", + "avgWeighted", + "topK", + "topKWeighted", + "deltaSum", + "deltaSumTimestamp", + "groupArray", + "groupArrayLast", + "groupUniqArray", + "groupArrayInsertAt", + "groupArrayMovingAvg", + "groupArrayMovingSum", + "groupArraySample", + "groupBitAnd", + "groupBitOr", + "groupBitXor", + "groupBitmap", + "groupBitmapAnd", + "groupBitmapOr", + "groupBitmapXor", + "sumWithOverflow", + "sumMap", + "minMap", + "maxMap", + "skewSamp", + "skewPop", + "kurtSamp", + "kurtPop", + "uniq", + "uniqExact", + "uniqCombined", + "uniqCombined64", + "uniqHLL12", + "uniqTheta", + "quantile", + "quantiles", + "quantileExact", + "quantilesExact", + "quantileExactLow", + "quantilesExactLow", + "quantileExactHigh", + "quantilesExactHigh", + "quantileExactWeighted", + "quantilesExactWeighted", + "quantileTiming", + "quantilesTiming", + "quantileTimingWeighted", + "quantilesTimingWeighted", + "quantileDeterministic", + "quantilesDeterministic", + "quantileTDigest", + "quantilesTDigest", + "quantileTDigestWeighted", + "quantilesTDigestWeighted", + "quantileBFloat16", + "quantilesBFloat16", + "quantileBFloat16Weighted", + "quantilesBFloat16Weighted", + "simpleLinearRegression", + "stochasticLinearRegression", + "stochasticLogisticRegression", + "categoricalInformationValue", + "contingency", + "cramersV", + "cramersVBiasCorrected", + "theilsU", + "maxIntersections", + "maxIntersectionsPosition", + "meanZTest", + "quantileInterpolatedWeighted", + "quantilesInterpolatedWeighted", + "quantileGK", + "quantilesGK", + "sparkBar", + "sumCount", + "largestTriangleThreeBuckets", + } + + AGG_FUNCTIONS_SUFFIXES = [ + "If", + "Array", + "ArrayIf", + "Map", + "SimpleState", + "State", + "Merge", + "MergeState", + "ForEach", + "Distinct", + "OrDefault", + "OrNull", + "Resample", + "ArgMin", + "ArgMax", + ] + + AGG_FUNC_MAPPING = ( + lambda functions, suffixes: { + f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions + } + )(AGG_FUNCTIONS, AGG_FUNCTIONS_SUFFIXES) + FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"} FUNCTION_PARSERS = { @@ -272,9 +419,18 @@ class ClickHouse(Dialect): ) if isinstance(func, exp.Anonymous): + parts = self.AGG_FUNC_MAPPING.get(func.this) params = self._parse_func_params(func) if params: + if parts and parts[1]: + return self.expression( + exp.CombinedParameterizedAgg, + this=func.this, + expressions=func.expressions, + params=params, + parts=parts, + ) return self.expression( exp.ParameterizedAgg, this=func.this, @@ -282,6 +438,20 @@ class ClickHouse(Dialect): params=params, ) + if parts: + if parts[1]: + return self.expression( + exp.CombinedAggFunc, + this=func.this, + expressions=func.expressions, + parts=parts, + ) + return self.expression( + exp.AnonymousAggFunc, + this=func.this, + expressions=func.expressions, + ) + return func def _parse_func_params( @@ -329,6 +499,9 @@ class ClickHouse(Dialect): STRUCT_DELIMITER = ("(", ")") NVL2_SUPPORTED = False TABLESAMPLE_REQUIRES_PARENS = False + TABLESAMPLE_SIZE_IS_ROWS = False + TABLESAMPLE_KEYWORDS = "SAMPLE" + LAST_DAY_SUPPORTS_DATE_PART = False STRING_TYPE_MAPPING = 
{ exp.DataType.Type.CHAR: "String", @@ -348,6 +521,7 @@ class ClickHouse(Dialect): **STRING_TYPE_MAPPING, exp.DataType.Type.ARRAY: "Array", exp.DataType.Type.BIGINT: "Int64", + exp.DataType.Type.DATE32: "Date32", exp.DataType.Type.DATETIME64: "DateTime64", exp.DataType.Type.DOUBLE: "Float64", exp.DataType.Type.ENUM: "Enum", @@ -372,24 +546,23 @@ class ClickHouse(Dialect): exp.DataType.Type.UINT256: "UInt256", exp.DataType.Type.USMALLINT: "UInt16", exp.DataType.Type.UTINYINT: "UInt8", + exp.DataType.Type.IPV4: "IPv4", + exp.DataType.Type.IPV6: "IPv6", } TRANSFORMS = { **generator.Generator.TRANSFORMS, - exp.Select: transforms.preprocess([transforms.eliminate_qualify]), exp.AnyValue: rename_func("any"), exp.ApproxDistinct: rename_func("uniq"), + exp.ArraySum: rename_func("arraySum"), exp.ArgMax: arg_max_or_min_no_count("argMax"), exp.ArgMin: arg_max_or_min_no_count("argMin"), exp.Array: inline_array_sql, exp.CastToStrType: rename_func("CAST"), + exp.CountIf: rename_func("countIf"), exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"), - exp.DateAdd: lambda self, e: self.func( - "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this - ), - exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this - ), + exp.DateAdd: date_delta_sql("DATE_ADD"), + exp.DateDiff: date_delta_sql("DATE_DIFF"), exp.Explode: rename_func("arrayJoin"), exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.IsNan: rename_func("isNaN"), @@ -400,6 +573,7 @@ class ClickHouse(Dialect): exp.Quantile: _quantile_sql, exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", exp.Rand: rename_func("randCanonical"), + exp.Select: transforms.preprocess([transforms.eliminate_qualify]), exp.StartsWith: rename_func("startsWith"), exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), @@ -485,10 +659,19 @@ class ClickHouse(Dialect): else "", ] - def parameterizedagg_sql(self, expression: exp.Anonymous) -> str: + def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str: params = self.expressions(expression, key="params", flat=True) return self.func(expression.name, *expression.expressions) + f"({params})" + def anonymousaggfunc_sql(self, expression: exp.AnonymousAggFunc) -> str: + return self.func(expression.name, *expression.expressions) + + def combinedaggfunc_sql(self, expression: exp.CombinedAggFunc) -> str: + return self.anonymousaggfunc_sql(expression) + + def combinedparameterizedagg_sql(self, expression: exp.CombinedParameterizedAgg) -> str: + return self.parameterizedagg_sql(expression) + def placeholder_sql(self, expression: exp.Placeholder) -> str: return f"{{{expression.name}: {self.sql(expression, 'kind')}}}" diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 1c10a8b..8e55b6a 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -30,6 +30,8 @@ class Databricks(Spark): } class Generator(Spark.Generator): + TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" + TRANSFORMS = { **Spark.Generator.TRANSFORMS, exp.DateAdd: date_delta_sql("DATEADD"), diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index b7eef45..7664c40 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -153,6 +153,9 @@ class Dialect(metaclass=_Dialect): ALIAS_POST_TABLESAMPLE = False """Determines 
whether or not the table alias comes after tablesample.""" + TABLESAMPLE_SIZE_IS_PERCENT = False + """Determines whether or not a size in the table sample clause represents percentage.""" + NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE """Specifies the strategy according to which identifiers should be normalized.""" @@ -220,6 +223,24 @@ class Dialect(metaclass=_Dialect): For example, such columns may be excluded from `SELECT *` queries. """ + PREFER_CTE_ALIAS_COLUMN = False + """ + Some dialects, such as Snowflake, allow you to reference a CTE column alias in the + HAVING clause of the CTE. This flag will cause the CTE alias columns to override + any projection aliases in the subquery. + + For example, + WITH y(c) AS ( + SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 + ) SELECT c FROM y; + + will be rewritten as + + WITH y(c) AS ( + SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 + ) SELECT c FROM y; + """ + # --- Autofilled --- tokenizer_class = Tokenizer @@ -287,7 +308,13 @@ class Dialect(metaclass=_Dialect): result = cls.get(dialect_name.strip()) if not result: - raise ValueError(f"Unknown dialect '{dialect_name}'.") + from difflib import get_close_matches + + similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" + if similar: + similar = f" Did you mean {similar}?" + + raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") return result(**kwargs) @@ -506,7 +533,7 @@ def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: n = self.sql(expression, "this") d = self.sql(expression, "expression") - return f"IF({d} <> 0, {n} / {d}, NULL)" + return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: @@ -695,7 +722,7 @@ def date_add_interval_sql( def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: return self.func( - "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this + "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this ) @@ -801,22 +828,6 @@ def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: return self.func("STRPTIME", expression.this, self.format_time(expression)) -def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: - def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str: - _dialect = Dialect.get_or_raise(dialect) - time_format = self.format_time(expression) - if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT): - return self.sql( - exp.cast( - exp.StrToTime(this=expression.this, format=expression.args["format"]), - "date", - ) - ) - return self.sql(exp.cast(expression.this, "date")) - - return _ts_or_ds_to_date_sql - - def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) @@ -894,11 +905,6 @@ def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" -# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon -def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str: - return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}" - - def is_parse_json(expression: exp.Expression) -> bool: return isinstance(expression, exp.ParseJSON) or ( isinstance(expression, exp.Cast) 
and expression.is_type("json") @@ -946,7 +952,70 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE expression = ts_or_ds_add_cast(expression) return self.func( - name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this + name, + exp.var(expression.text("unit").upper() or "DAY"), + expression.expression, + expression.this, ) return _delta_sql + + +def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath: + from sqlglot.optimizer.simplify import simplify + + # Makes sure the path will be evaluated correctly at runtime to include the path root. + # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`. + path = expression.expression + path = exp.func( + "if", + exp.func("startswith", path, "'['"), + exp.func("concat", "'$'", path), + exp.func("concat", "'$.'", path), + ) + + expression.expression.replace(simplify(path)) + return expression + + +def path_to_jsonpath( + name: str = "JSON_EXTRACT", +) -> t.Callable[[Generator, exp.GetPath], str]: + def _transform(self: Generator, expression: exp.GetPath) -> str: + return rename_func(name)(self, prepend_dollar_to_path(expression)) + + return _transform + + +def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: + trunc_curr_date = exp.func("date_trunc", "month", expression.this) + plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") + minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") + + return self.sql(exp.cast(minus_one_day, "date")) + + +def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: + """Remove table refs from columns in when statements.""" + alias = expression.this.args.get("alias") + + normalize = ( + lambda identifier: self.dialect.normalize_identifier(identifier).name + if identifier + else None + ) + + targets = {normalize(expression.this.this)} + + if alias: + targets.add(normalize(alias.this)) + + for when in expression.expressions: + when.transform( + lambda node: exp.column(node.this) + if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets + else node, + copy=False, + ) + + return self.merge_sql(expression) diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py index 11af17b..6e229b3 100644 --- a/sqlglot/dialects/doris.py +++ b/sqlglot/dialects/doris.py @@ -22,6 +22,7 @@ class Doris(MySQL): "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list, "DATE_TRUNC": parse_timestamp_trunc, "REGEXP": exp.RegexpLike.from_arg_list, + "TO_DATE": exp.TsOrDsToDate.from_arg_list, } class Generator(MySQL.Generator): @@ -34,21 +35,26 @@ class Doris(MySQL): exp.DataType.Type.TIMESTAMPTZ: "DATETIME", } + LAST_DAY_SUPPORTS_DATE_PART = False + TIMESTAMP_FUNC_TYPES = set() TRANSFORMS = { **MySQL.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, + exp.ArgMax: rename_func("MAX_BY"), + exp.ArgMin: rename_func("MIN_BY"), exp.ArrayAgg: rename_func("COLLECT_LIST"), + exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), exp.CurrentTimestamp: lambda *_: "NOW()", exp.DateTrunc: lambda self, e: self.func( "DATE_TRUNC", e.this, "'" + e.text("unit") + "'" ), exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, + exp.Map: rename_func("ARRAY_MAP"), exp.RegexpLike: rename_func("REGEXP"), exp.RegexpSplit: rename_func("SPLIT_BY_STRING"), - exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Split: rename_func("SPLIT_BY_STRING"), 
exp.TimeStrToDate: rename_func("TO_DATE"), @@ -63,5 +69,4 @@ class Doris(MySQL): "FROM_UNIXTIME", e.this, time_format("doris")(self, e) ), exp.UnixToTime: rename_func("FROM_UNIXTIME"), - exp.Map: rename_func("ARRAY_MAP"), } diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index c9b31a0..6bca9e7 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -12,7 +12,6 @@ from sqlglot.dialects.dialect import ( rename_func, str_position_sql, timestrtotime_sql, - ts_or_ds_to_date_sql, ) @@ -99,6 +98,7 @@ class Drill(Dialect): TABLE_HINTS = False QUERY_HINTS = False NVL2_SUPPORTED = False + LAST_DAY_SUPPORTS_DATE_PART = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -150,7 +150,6 @@ class Drill(Dialect): exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.TryCast: no_trycast_sql, exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})", - exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", } diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index cd9d529..2343b35 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -22,15 +22,15 @@ from sqlglot.dialects.dialect import ( no_safe_divide_sql, no_timestamp_sql, pivot_column_names, + prepend_dollar_to_path, regexp_extract_sql, rename_func, str_position_sql, str_to_time_sql, timestamptrunc_sql, timestrtotime_sql, - ts_or_ds_to_date_sql, ) -from sqlglot.helper import seq_get +from sqlglot.helper import flatten, seq_get from sqlglot.tokens import TokenType @@ -141,11 +141,25 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str return f"EPOCH_MS({timestamp})" if scale == exp.UnixToTime.MICROS: return f"MAKE_TIMESTAMP({timestamp})" - if scale == exp.UnixToTime.NANOS: - return f"TO_TIMESTAMP({timestamp} / 1000000000)" - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return f"TO_TIMESTAMP({timestamp} / POW(10, {scale}))" + + +def _rename_unless_within_group( + a: str, b: str +) -> t.Callable[[DuckDB.Generator, exp.Expression], str]: + return ( + lambda self, expression: self.func(a, *flatten(expression.args.values())) + if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup) + else self.func(b, *flatten(expression.args.values())) + ) + + +def _parse_struct_pack(args: t.List) -> exp.Struct: + args_with_columns_as_identifiers = [ + exp.PropertyEQ(this=arg.this.this, expression=arg.expression) for arg in args + ] + return exp.Struct.from_arg_list(args_with_columns_as_identifiers) class DuckDB(Dialect): @@ -183,6 +197,11 @@ class DuckDB(Dialect): "TIMESTAMP_US": TokenType.TIMESTAMP, } + SINGLE_TOKENS = { + **tokens.Tokenizer.SINGLE_TOKENS, + "$": TokenType.PARAMETER, + } + class Parser(parser.Parser): BITWISE = { **parser.Parser.BITWISE, @@ -209,10 +228,12 @@ class DuckDB(Dialect): "EPOCH_MS": lambda args: exp.UnixToTime( this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS ), + "JSON": exp.ParseJSON.from_arg_list, "LIST_HAS": exp.ArrayContains.from_arg_list, "LIST_REVERSE_SORT": _sort_array_reverse, "LIST_SORT": exp.SortArray.from_arg_list, "LIST_VALUE": exp.Array.from_arg_list, + "MAKE_TIME": exp.TimeFromParts.from_arg_list, "MAKE_TIMESTAMP": _parse_make_timestamp, "MEDIAN": lambda args: exp.PercentileCont( this=seq_get(args, 0), expression=exp.Literal.number(0.5) @@ 
-234,7 +255,7 @@ class DuckDB(Dialect): "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "STRING_TO_ARRAY": exp.Split.from_arg_list, "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"), - "STRUCT_PACK": exp.Struct.from_arg_list, + "STRUCT_PACK": _parse_struct_pack, "STR_SPLIT": exp.Split.from_arg_list, "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "TO_TIMESTAMP": exp.UnixToTime.from_arg_list, @@ -250,6 +271,13 @@ class DuckDB(Dialect): TokenType.ANTI, } + PLACEHOLDER_PARSERS = { + **parser.Parser.PLACEHOLDER_PARSERS, + TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text) + if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) + else None, + } + def _parse_types( self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: @@ -268,7 +296,7 @@ class DuckDB(Dialect): return this - def _parse_struct_types(self) -> t.Optional[exp.Expression]: + def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]: return self._parse_field_def() def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: @@ -285,6 +313,10 @@ class DuckDB(Dialect): RENAME_TABLE_WITH_DB = False NVL2_SUPPORTED = False SEMI_ANTI_JOIN_WITH_SIDE = False + TABLESAMPLE_KEYWORDS = "USING SAMPLE" + TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" + LAST_DAY_SUPPORTS_DATE_PART = False + JSON_KEY_VALUE_PAIR_SEP = "," TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -311,7 +343,7 @@ class DuckDB(Dialect): exp.DateFromParts: rename_func("MAKE_DATE"), exp.DateSub: _date_delta_sql, exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", f"'{e.args.get('unit') or 'day'}'", e.expression, e.this + "DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this ), exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)", @@ -322,11 +354,11 @@ class DuckDB(Dialect): exp.IntDiv: lambda self, e: self.binary(e, "//"), exp.IsInf: rename_func("ISINF"), exp.IsNan: rename_func("ISNAN"), + exp.JSONBExtract: arrow_json_extract_sql, + exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONFormat: _json_format_sql, - exp.JSONBExtract: arrow_json_extract_sql, - exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), exp.MonthsBetween: lambda self, e: self.func( @@ -336,8 +368,8 @@ class DuckDB(Dialect): exp.cast(e.this, "timestamp", copy=True), ), exp.ParseJSON: rename_func("JSON"), - exp.PercentileCont: rename_func("QUANTILE_CONT"), - exp.PercentileDisc: rename_func("QUANTILE_DISC"), + exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"), + exp.PercentileDisc: _rename_unless_within_group("PERCENTILE_DISC", "QUANTILE_DISC"), # DuckDB doesn't allow qualified columns inside of PIVOT expressions. 
# See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62 exp.Pivot: transforms.preprocess([transforms.unqualify_columns]), @@ -362,7 +394,9 @@ class DuckDB(Dialect): exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", exp.Struct: _struct_sql, exp.Timestamp: no_timestamp_sql, - exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"), + exp.TimestampDiff: lambda self, e: self.func( + "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this + ), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, @@ -373,11 +407,10 @@ class DuckDB(Dialect): exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsDiff: lambda self, e: self.func( "DATE_DIFF", - f"'{e.args.get('unit') or 'day'}'", + f"'{e.args.get('unit') or 'DAY'}'", exp.cast(e.expression, "TIMESTAMP"), exp.cast(e.this, "TIMESTAMP"), ), - exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"), exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: _unix_to_time_sql, exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", @@ -410,6 +443,49 @@ class DuckDB(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def timefromparts_sql(self, expression: exp.TimeFromParts) -> str: + nano = expression.args.get("nano") + if nano is not None: + expression.set( + "sec", expression.args["sec"] + nano.pop() / exp.Literal.number(1000000000.0) + ) + + return rename_func("MAKE_TIME")(self, expression) + + def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: + sec = expression.args["sec"] + + milli = expression.args.get("milli") + if milli is not None: + sec += milli.pop() / exp.Literal.number(1000.0) + + nano = expression.args.get("nano") + if nano is not None: + sec += nano.pop() / exp.Literal.number(1000000000.0) + + if milli or nano: + expression.set("sec", sec) + + return rename_func("MAKE_TIMESTAMP")(self, expression) + + def tablesample_sql( + self, + expression: exp.TableSample, + sep: str = " AS ", + tablesample_keyword: t.Optional[str] = None, + ) -> str: + if not isinstance(expression.parent, exp.Select): + # This sample clause only applies to a single source, not the entire resulting relation + tablesample_keyword = "TABLESAMPLE" + + return super().tablesample_sql( + expression, sep=sep, tablesample_keyword=tablesample_keyword + ) + + def getpath_sql(self, expression: exp.GetPath) -> str: + expression = prepend_dollar_to_path(expression) + return f"{self.sql(expression, 'this')} -> {self.sql(expression, 'expression')}" + def interval_sql(self, expression: exp.Interval) -> str: multiplier: t.Optional[int] = None unit = expression.text("unit").lower() @@ -420,11 +496,14 @@ class DuckDB(Dialect): multiplier = 90 if multiplier: - return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})" + return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('DAY')))})" return super().interval_sql(expression) - def tablesample_sql( - self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS " - ) -> str: - return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep) + def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: + if isinstance(expression.parent, exp.UserDefinedFunction): 
+ return self.sql(expression, "this") + return super().columndef_sql(expression, sep) + + def placeholder_sql(self, expression: exp.Placeholder) -> str: + return f"${expression.name}" if expression.name else "?" diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 65c85bb..dffa41e 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -418,13 +418,13 @@ class Hive(Dialect): class Generator(generator.Generator): LIMIT_FETCH = "LIMIT" TABLESAMPLE_WITH_METHOD = False - TABLESAMPLE_SIZE_IS_PERCENT = True JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False INDEX_ON = "ON TABLE" EXTRACT_ALLOWS_QUOTES = False NVL2_SUPPORTED = False + LAST_DAY_SUPPORTS_DATE_PART = False EXPRESSIONS_WITHOUT_NESTED_CTES = { exp.Insert, @@ -523,7 +523,6 @@ class Hive(Dialect): exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}", exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), - exp.LastDateOfMonth: rename_func("LAST_DAY"), exp.National: lambda self, e: self.national_sql(e, prefix=""), exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})", exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})", diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 5fe3d82..21a9657 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, isnull_to_is_null, - json_keyvalue_comma_sql, locate_to_strposition, max_or_greatest, min_or_least, @@ -21,6 +20,7 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, no_trycast_sql, parse_date_delta_with_interval, + path_to_jsonpath, rename_func, strposition_to_locate_sql, ) @@ -37,21 +37,21 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: expr = self.sql(expression, "this") - unit = expression.text("unit") + unit = expression.text("unit").upper() - if unit == "day": + if unit == "DAY": return f"DATE({expr})" - if unit == "week": + if unit == "WEEK": concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')" date_format = "%Y %u %w" - elif unit == "month": + elif unit == "MONTH": concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')" date_format = "%Y %c %e" - elif unit == "quarter": + elif unit == "QUARTER": concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')" date_format = "%Y %c %e" - elif unit == "year": + elif unit == "YEAR": concat = f"CONCAT(YEAR({expr}), ' 1 1')" date_format = "%Y %c %e" else: @@ -292,9 +292,15 @@ class MySQL(Dialect): "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), + "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)), "ISNULL": isnull_to_is_null, "LOCATE": locate_to_strposition, + "MAKETIME": exp.TimeFromParts.from_arg_list, + "MONTH": lambda args: 
exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "MONTHNAME": lambda args: exp.TimeToStr( this=exp.TsOrDsToDate(this=seq_get(args, 0)), format=exp.Literal.string("%B"), @@ -308,11 +314,6 @@ class MySQL(Dialect): ) + 1 ), - "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "WEEK": lambda args: exp.Week( this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1) ), @@ -441,6 +442,7 @@ class MySQL(Dialect): } LOG_DEFAULTS_TO_LN = True + STRING_ALIASES = True def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() @@ -620,13 +622,15 @@ class MySQL(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True - NULL_ORDERING_SUPPORTED = False + NULL_ORDERING_SUPPORTED = None JOIN_HINTS = False TABLE_HINTS = True DUPLICATE_KEY_UPDATE_WITH_SET = False QUERY_HINT_SEP = " " VALUES_AS_TABLE = False NVL2_SUPPORTED = False + LAST_DAY_SUPPORTS_DATE_PART = False + JSON_KEY_VALUE_PAIR_SEP = "," TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -642,15 +646,16 @@ class MySQL(Dialect): exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")), exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")), exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")), + exp.GetPath: path_to_jsonpath(), exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.ILike: no_ilike_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, - exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, exp.Month: _remove_ts_or_ds_to_date(), exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}", + exp.ParseJSON: lambda self, e: self.sql(e, "this"), exp.Pivot: no_pivot_sql, exp.Select: transforms.preprocess( [ @@ -665,6 +670,7 @@ class MySQL(Dialect): exp.StrToTime: _str_to_date_sql, exp.Stuff: rename_func("INSERT"), exp.TableSample: no_tablesample_sql, + exp.TimeFromParts: rename_func("MAKETIME"), exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"), exp.TimestampSub: date_add_interval_sql("DATE", "SUB"), exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 51dbd53..6ad3718 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -53,6 +53,7 @@ def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar: class Oracle(Dialect): ALIAS_POST_TABLESAMPLE = True LOCKING_READS_SUPPORTED = True + TABLESAMPLE_SIZE_IS_PERCENT = True # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE @@ -81,6 +82,7 @@ class Oracle(Dialect): "WW": "%W", # Week of year (1-53) "YY": "%y", # 15 "YYYY": "%Y", # 2015 + "FF6": "%f", # only 6 digits are supported in python formats } class Parser(parser.Parser): @@ -91,6 +93,8 @@ class Oracle(Dialect): **parser.Parser.FUNCTIONS, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TO_CHAR": to_char, + "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "oracle"), + 
"TO_DATE": format_time_lambda(exp.StrToDate, "oracle"), } FUNCTION_PARSERS: t.Dict[str, t.Callable] = { @@ -107,6 +111,11 @@ class Oracle(Dialect): "XMLTABLE": _parse_xml_table, } + QUERY_MODIFIER_PARSERS = { + **parser.Parser.QUERY_MODIFIER_PARSERS, + TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()), + } + TYPE_LITERAL_PARSERS = { exp.DataType.Type.DATE: lambda self, this, _: self.expression( exp.DateStrToDate, this=this @@ -153,8 +162,10 @@ class Oracle(Dialect): COLUMN_JOIN_MARKS_SUPPORTED = True DATA_TYPE_SPECIFIERS_ALLOWED = True ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False - LIMIT_FETCH = "FETCH" + TABLESAMPLE_KEYWORDS = "SAMPLE" + LAST_DAY_SUPPORTS_DATE_PART = False + SUPPORTS_SELECT_INTO = True TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -186,6 +197,7 @@ class Oracle(Dialect): ] ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), exp.Substring: rename_func("SUBSTR"), exp.Table: lambda self, e: self.table_sql(e, sep=" "), @@ -201,6 +213,10 @@ class Oracle(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str: + this = expression.this + return self.func("CURRENT_TIMESTAMP", this) if this else "CURRENT_TIMESTAMP" + def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" @@ -233,8 +249,10 @@ class Oracle(Dialect): "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "MINUS": TokenType.EXCEPT, "NVARCHAR2": TokenType.NVARCHAR, + "ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY, "SAMPLE": TokenType.TABLE_SAMPLE, "START": TokenType.BEGIN, + "SYSDATE": TokenType.CURRENT_TIMESTAMP, "TOP": TokenType.TOP, "VARCHAR2": TokenType.VARCHAR, } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index e274877..1ca0a78 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -13,11 +13,12 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, max_or_greatest, + merge_without_target_sql, min_or_least, + no_last_day_sql, no_map_from_entries_sql, no_paren_current_date_sql, no_pivot_sql, - no_tablesample_sql, no_trycast_sql, parse_timestamp_trunc, rename_func, @@ -27,7 +28,6 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, trim_sql, ts_or_ds_add_cast, - ts_or_ds_to_date_sql, ) from sqlglot.helper import seq_get from sqlglot.parser import binary_range_parser @@ -188,36 +188,6 @@ def _to_timestamp(args: t.List) -> exp.Expression: return format_time_lambda(exp.StrToTime, "postgres")(args) -def _merge_sql(self: Postgres.Generator, expression: exp.Merge) -> str: - def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression: - """Remove table refs from columns in when statements.""" - if isinstance(expression, exp.Merge): - alias = expression.this.args.get("alias") - - normalize = ( - lambda identifier: self.dialect.normalize_identifier(identifier).name - if identifier - else None - ) - - targets = {normalize(expression.this.this)} - - if alias: - targets.add(normalize(alias.this)) - - for when in expression.expressions: - when.transform( - lambda node: exp.column(node.this) - if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets - else node, - copy=False, - ) - - return expression - - return 
transforms.preprocess([_remove_target_from_merge])(self, expression) - - class Postgres(Dialect): INDEX_OFFSET = 1 TYPED_DIVISION = True @@ -316,6 +286,8 @@ class Postgres(Dialect): **parser.Parser.FUNCTIONS, "DATE_TRUNC": parse_timestamp_trunc, "GENERATE_SERIES": _generate_series, + "MAKE_TIME": exp.TimeFromParts.from_arg_list, + "MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list, "NOW": exp.CurrentTimestamp.from_arg_list, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), "TO_TIMESTAMP": _to_timestamp, @@ -387,12 +359,18 @@ class Postgres(Dialect): class Generator(generator.Generator): SINGLE_STRING_INTERVAL = True + RENAME_TABLE_WITH_DB = False LOCKING_READS_SUPPORTED = True JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False NVL2_SUPPORTED = False PARAMETER_TOKEN = "$" + TABLESAMPLE_SIZE_IS_ROWS = False + TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" + SUPPORTS_SELECT_INTO = True + # https://www.postgresql.org/docs/current/sql-createtable.html + SUPPORTS_UNLOGGED_TABLES = True TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -430,12 +408,13 @@ class Postgres(Dialect): exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"), exp.JSONBContains: lambda self, e: self.binary(e, "?"), + exp.LastDay: no_last_day_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), exp.Max: max_or_greatest, exp.MapFromEntries: no_map_from_entries_sql, exp.Min: min_or_least, - exp.Merge: _merge_sql, + exp.Merge: merge_without_target_sql, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.PercentileCont: transforms.preprocess( [transforms.add_within_group_for_percentiles] @@ -458,16 +437,16 @@ class Postgres(Dialect): exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.StructExtract: struct_extract_sql, exp.Substring: _substring_sql, + exp.TimeFromParts: rename_func("MAKE_TIME"), + exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", - exp.TableSample: no_tablesample_sql, exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, exp.TryCast: no_trycast_sql, exp.TsOrDsAdd: _date_add_sql("+"), exp.TsOrDsDiff: _date_diff_sql, - exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"), exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.VariancePop: rename_func("VAR_POP"), exp.Variance: rename_func("VAR_SAMP"), diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 360ab65..9b421e7 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import ( no_pivot_sql, no_safe_divide_sql, no_timestamp_sql, + path_to_jsonpath, regexp_extract_sql, rename_func, right_to_substring_sql, @@ -99,14 +100,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: expression = ts_or_ds_add_cast(expression) - unit = exp.Literal.string(expression.text("unit") or "day") + unit = exp.Literal.string(expression.text("unit") or "DAY") return self.func("DATE_ADD", unit, expression.expression, expression.this) def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str: this = exp.cast(expression.this, "TIMESTAMP") expr = 
exp.cast(expression.expression, "TIMESTAMP") - unit = exp.Literal.string(expression.text("unit") or "day") + unit = exp.Literal.string(expression.text("unit") or "DAY") return self.func("DATE_DIFF", unit, expr, this) @@ -138,13 +139,6 @@ def _from_unixtime(args: t.List) -> exp.Expression: return exp.UnixToTime.from_arg_list(args) -def _parse_element_at(args: t.List) -> exp.Bracket: - this = seq_get(args, 0) - index = seq_get(args, 1) - assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression) - return exp.Bracket(this=this, expressions=[index], offset=1, safe=True) - - def _unnest_sequence(expression: exp.Expression) -> exp.Expression: if isinstance(expression, exp.Table): if isinstance(expression.this, exp.GenerateSeries): @@ -175,15 +169,8 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str timestamp = self.sql(expression, "this") if scale in (None, exp.UnixToTime.SECONDS): return rename_func("FROM_UNIXTIME")(self, expression) - if scale == exp.UnixToTime.MILLIS: - return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)" - if scale == exp.UnixToTime.MICROS: - return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)" - if scale == exp.UnixToTime.NANOS: - return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)" - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))" def _to_int(expression: exp.Expression) -> exp.Expression: @@ -215,6 +202,7 @@ class Presto(Dialect): STRICT_STRING_CONCAT = True SUPPORTS_SEMI_ANTI_JOIN = False TYPED_DIVISION = True + TABLESAMPLE_SIZE_IS_PERCENT = True # https://github.com/trinodb/trino/issues/17 # https://github.com/trinodb/trino/issues/12289 @@ -258,7 +246,9 @@ class Presto(Dialect): "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), "DATE_TRUNC": date_trunc_to_time, - "ELEMENT_AT": _parse_element_at, + "ELEMENT_AT": lambda args: exp.Bracket( + this=seq_get(args, 0), expressions=[seq_get(args, 1)], offset=1, safe=True + ), "FROM_HEX": exp.Unhex.from_arg_list, "FROM_UNIXTIME": _from_unixtime, "FROM_UTF8": lambda args: exp.Decode( @@ -344,20 +334,20 @@ class Presto(Dialect): exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: lambda self, e: self.func( "DATE_ADD", - exp.Literal.string(e.text("unit") or "day"), + exp.Literal.string(e.text("unit") or "DAY"), _to_int( e.expression, ), e.this, ), exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this ), exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", exp.DateSub: lambda self, e: self.func( "DATE_ADD", - exp.Literal.string(e.text("unit") or "day"), + exp.Literal.string(e.text("unit") or "DAY"), _to_int(e.expression * -1), e.this, ), @@ -366,6 +356,7 @@ class Presto(Dialect): exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.First: _first_last_sql, + exp.GetPath: path_to_jsonpath(), exp.Group: transforms.preprocess([transforms.unalias_group]), exp.GroupConcat: lambda self, e: self.func( "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator") @@ -376,6 +367,7 @@ class Presto(Dialect): exp.Initcap: _initcap_sql, exp.ParseJSON: 
rename_func("JSON_PARSE"), exp.Last: _first_last_sql, + exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this), exp.Lateral: _explode_to_unnest_sql, exp.Left: left_to_substring_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), @@ -446,7 +438,7 @@ class Presto(Dialect): return super().bracket_sql(expression) def struct_sql(self, expression: exp.Struct) -> str: - if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions): + if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions): self.unsupported("Struct with key-value definitions is unsupported.") return self.function_fallback_sql(expression) @@ -454,8 +446,8 @@ class Presto(Dialect): def interval_sql(self, expression: exp.Interval) -> str: unit = self.sql(expression, "unit") - if expression.this and unit.lower().startswith("week"): - return f"({expression.this.name} * INTERVAL '7' day)" + if expression.this and unit.startswith("WEEK"): + return f"({expression.this.name} * INTERVAL '7' DAY)" return super().interval_sql(expression) def transaction_sql(self, expression: exp.Transaction) -> str: diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 7382e7c..7194d81 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -9,8 +9,8 @@ from sqlglot.dialects.dialect import ( concat_ws_to_dpipe_sql, date_delta_sql, generatedasidentitycolumnconstraint_sql, + no_tablesample_sql, rename_func, - ts_or_ds_to_date_sql, ) from sqlglot.dialects.postgres import Postgres from sqlglot.helper import seq_get @@ -123,6 +123,27 @@ class Redshift(Postgres): self._retreat(index) return None + def _parse_query_modifiers( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + this = super()._parse_query_modifiers(this) + + if this: + refs = set() + + for i, join in enumerate(this.args.get("joins", [])): + refs.add( + ( + this.args["from"] if i == 0 else this.args["joins"][i - 1] + ).alias_or_name.lower() + ) + table = join.this + + if isinstance(table, exp.Table): + if table.parts[0].name.lower() in refs: + table.replace(table.to_column()) + return this + class Tokenizer(Postgres.Tokenizer): BIT_STRINGS = [] HEX_STRINGS = [] @@ -144,11 +165,11 @@ class Redshift(Postgres): class Generator(Postgres.Generator): LOCKING_READS_SUPPORTED = False - RENAME_TABLE_WITH_DB = False QUERY_HINTS = False VALUES_AS_TABLE = False TZ_TO_WITH_TIME_ZONE = True NVL2_SUPPORTED = True + LAST_DAY_SUPPORTS_DATE_PART = False TYPE_MAPPING = { **Postgres.Generator.TYPE_MAPPING, @@ -184,9 +205,9 @@ class Redshift(Postgres): [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] ), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", + exp.TableSample: no_tablesample_sql, exp.TsOrDsAdd: date_delta_sql("DATEADD"), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), - exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"), } # Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots @@ -198,6 +219,9 @@ class Redshift(Postgres): # Redshift supports ANY_VALUE(..) TRANSFORMS.pop(exp.AnyValue) + # Redshift supports LAST_DAY(..) 
+ TRANSFORMS.pop(exp.LastDay) + RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"} def with_properties(self, properties: exp.Properties) -> str: diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 8925181..a8e4a42 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -19,7 +19,6 @@ from sqlglot.dialects.dialect import ( rename_func, timestamptrunc_sql, timestrtotime_sql, - ts_or_ds_to_date_sql, var_map_sql, ) from sqlglot.expressions import Literal @@ -40,21 +39,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, if second_arg.is_string: # case: [ , ] return format_time_lambda(exp.StrToTime, "snowflake")(args) - - # case: [ , ] - if second_arg.name not in ["0", "3", "9"]: - raise ValueError( - f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9" - ) - - if second_arg.name == "0": - timescale = exp.UnixToTime.SECONDS - elif second_arg.name == "3": - timescale = exp.UnixToTime.MILLIS - elif second_arg.name == "9": - timescale = exp.UnixToTime.NANOS - - return exp.UnixToTime(this=first_arg, scale=timescale) + return exp.UnixToTime(this=first_arg, scale=second_arg) from sqlglot.optimizer.simplify import simplify_literals @@ -91,23 +76,9 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: def _parse_datediff(args: t.List) -> exp.DateDiff: - return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) - - -def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str: - scale = expression.args.get("scale") - timestamp = self.sql(expression, "this") - if scale in (None, exp.UnixToTime.SECONDS): - return f"TO_TIMESTAMP({timestamp})" - if scale == exp.UnixToTime.MILLIS: - return f"TO_TIMESTAMP({timestamp}, 3)" - if scale == exp.UnixToTime.MICROS: - return f"TO_TIMESTAMP({timestamp} / 1000, 3)" - if scale == exp.UnixToTime.NANOS: - return f"TO_TIMESTAMP({timestamp}, 9)" - - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return exp.DateDiff( + this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0)) + ) # https://docs.snowflake.com/en/sql-reference/functions/date_part.html @@ -120,14 +91,15 @@ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]: self._match(TokenType.COMMA) expression = self._parse_bitwise() - + this = _map_date_part(this) name = this.name.upper() + if name.startswith("EPOCH"): - if name.startswith("EPOCH_MILLISECOND"): + if name == "EPOCH_MILLISECOND": scale = 10**3 - elif name.startswith("EPOCH_MICROSECOND"): + elif name == "EPOCH_MICROSECOND": scale = 10**6 - elif name.startswith("EPOCH_NANOSECOND"): + elif name == "EPOCH_NANOSECOND": scale = 10**9 else: scale = None @@ -204,6 +176,159 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser] return _parse +DATE_PART_MAPPING = { + "Y": "YEAR", + "YY": "YEAR", + "YYY": "YEAR", + "YYYY": "YEAR", + "YR": "YEAR", + "YEARS": "YEAR", + "YRS": "YEAR", + "MM": "MONTH", + "MON": "MONTH", + "MONS": "MONTH", + "MONTHS": "MONTH", + "D": "DAY", + "DD": "DAY", + "DAYS": "DAY", + "DAYOFMONTH": "DAY", + "WEEKDAY": "DAYOFWEEK", + "DOW": "DAYOFWEEK", + "DW": "DAYOFWEEK", + "WEEKDAY_ISO": "DAYOFWEEKISO", + "DOW_ISO": "DAYOFWEEKISO", + "DW_ISO": "DAYOFWEEKISO", + "YEARDAY": "DAYOFYEAR", + "DOY": "DAYOFYEAR", + "DY": "DAYOFYEAR", + "W": "WEEK", + "WK": "WEEK", + "WEEKOFYEAR": "WEEK", + "WOY": "WEEK", + "WY": 
"WEEK", + "WEEK_ISO": "WEEKISO", + "WEEKOFYEARISO": "WEEKISO", + "WEEKOFYEAR_ISO": "WEEKISO", + "Q": "QUARTER", + "QTR": "QUARTER", + "QTRS": "QUARTER", + "QUARTERS": "QUARTER", + "H": "HOUR", + "HH": "HOUR", + "HR": "HOUR", + "HOURS": "HOUR", + "HRS": "HOUR", + "M": "MINUTE", + "MI": "MINUTE", + "MIN": "MINUTE", + "MINUTES": "MINUTE", + "MINS": "MINUTE", + "S": "SECOND", + "SEC": "SECOND", + "SECONDS": "SECOND", + "SECS": "SECOND", + "MS": "MILLISECOND", + "MSEC": "MILLISECOND", + "MILLISECONDS": "MILLISECOND", + "US": "MICROSECOND", + "USEC": "MICROSECOND", + "MICROSECONDS": "MICROSECOND", + "NS": "NANOSECOND", + "NSEC": "NANOSECOND", + "NANOSEC": "NANOSECOND", + "NSECOND": "NANOSECOND", + "NSECONDS": "NANOSECOND", + "NANOSECS": "NANOSECOND", + "NSECONDS": "NANOSECOND", + "EPOCH": "EPOCH_SECOND", + "EPOCH_SECONDS": "EPOCH_SECOND", + "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", + "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", + "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", + "TZH": "TIMEZONE_HOUR", + "TZM": "TIMEZONE_MINUTE", +} + + +@t.overload +def _map_date_part(part: exp.Expression) -> exp.Var: + pass + + +@t.overload +def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + pass + + +def _map_date_part(part): + mapped = DATE_PART_MAPPING.get(part.name.upper()) if part else None + return exp.var(mapped) if mapped else part + + +def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: + trunc = date_trunc_to_time(args) + trunc.set("unit", _map_date_part(trunc.args["unit"])) + return trunc + + +def _parse_colon_get_path( + self: parser.Parser, this: t.Optional[exp.Expression] +) -> t.Optional[exp.Expression]: + while True: + path = self._parse_bitwise() + + # The cast :: operator has a lower precedence than the extraction operator :, so + # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH + if isinstance(path, exp.Cast): + target_type = path.to + path = path.this + else: + target_type = None + + if isinstance(path, exp.Expression): + path = exp.Literal.string(path.sql(dialect="snowflake")) + + # The extraction operator : is left-associative + this = self.expression(exp.GetPath, this=this, expression=path) + + if target_type: + this = exp.cast(this, target_type) + + if not self._match(TokenType.COLON): + break + + if self._match_set(self.RANGE_PARSERS): + this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this + + return this + + +def _parse_timestamp_from_parts(args: t.List) -> exp.Func: + if len(args) == 2: + # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, + # so we parse this into Anonymous for now instead of introducing complexity + return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) + + return exp.TimestampFromParts.from_arg_list(args) + + +def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression: + """ + Snowflake doesn't allow columns referenced in UNPIVOT to be qualified, + so we need to unqualify them. 
+ + Example: + >>> from sqlglot import parse_one + >>> expr = parse_one("SELECT * FROM m_sales UNPIVOT(sales FOR month IN (m_sales.jan, feb, mar, april))") + >>> print(_unqualify_unpivot_columns(expr).sql(dialect="snowflake")) + SELECT * FROM m_sales UNPIVOT(sales FOR month IN (jan, feb, mar, april)) + """ + if isinstance(expression, exp.Pivot) and expression.unpivot: + expression = transforms.unqualify_columns(expression) + + return expression + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE @@ -211,6 +336,8 @@ class Snowflake(Dialect): TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" SUPPORTS_USER_DEFINED_TYPES = False SUPPORTS_SEMI_ANTI_JOIN = False + PREFER_CTE_ALIAS_COLUMN = True + TABLESAMPLE_SIZE_IS_PERCENT = True TIME_MAPPING = { "YYYY": "%Y", @@ -276,14 +403,19 @@ class Snowflake(Dialect): "BIT_XOR": binary_from_function(exp.BitwiseXor), "BOOLXOR": binary_from_function(exp.Xor), "CONVERT_TIMEZONE": _parse_convert_timezone, - "DATE_TRUNC": date_trunc_to_time, + "DATE_TRUNC": _date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( - this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=_map_date_part(seq_get(args, 0)), ), "DATEDIFF": _parse_datediff, "DIV0": _div0_to_if, "FLATTEN": exp.Explode.from_arg_list, "IFF": exp.If.from_arg_list, + "LAST_DAY": lambda args: exp.LastDay( + this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1)) + ), "LISTAGG": exp.GroupConcat.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, @@ -293,6 +425,8 @@ class Snowflake(Dialect): "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TIMEDIFF": _parse_datediff, "TIMESTAMPDIFF": _parse_datediff, + "TIMESTAMPFROMPARTS": _parse_timestamp_from_parts, + "TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts, "TO_TIMESTAMP": _parse_to_timestamp, "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": _zeroifnull_to_if, @@ -301,22 +435,17 @@ class Snowflake(Dialect): FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, "DATE_PART": _parse_date_part, + "OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(), } FUNCTION_PARSERS.pop("TRIM") - COLUMN_OPERATORS = { - **parser.Parser.COLUMN_OPERATORS, - TokenType.COLON: lambda self, this, path: self.expression( - exp.Bracket, this=this, expressions=[path] - ), - } - TIMESTAMPS = parser.Parser.TIMESTAMPS - {TokenType.TIME} RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny), TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny), + TokenType.COLON: _parse_colon_get_path, } ALTER_PARSERS = { @@ -344,6 +473,7 @@ class Snowflake(Dialect): SHOW_PARSERS = { "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), + "COLUMNS": _show_parser("COLUMNS"), } STAGED_FILE_SINGLE_TOKENS = { @@ -351,8 +481,18 @@ class Snowflake(Dialect): TokenType.MOD, TokenType.SLASH, } + FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"] + def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]: + if is_map: + # Keys are strings in Snowflake's objects, see also: + # - https://docs.snowflake.com/en/sql-reference/data-types-semistructured + # - https://docs.snowflake.com/en/sql-reference/functions/object_construct + return self._parse_slice(self._parse_string()) + + return 
self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True)) + def _parse_lateral(self) -> t.Optional[exp.Lateral]: lateral = super()._parse_lateral() if not lateral: @@ -440,6 +580,8 @@ class Snowflake(Dialect): scope = None scope_kind = None + like = self._parse_string() if self._match(TokenType.LIKE) else None + if self._match(TokenType.IN): if self._match_text_seq("ACCOUNT"): scope_kind = "ACCOUNT" @@ -451,7 +593,9 @@ class Snowflake(Dialect): scope_kind = "TABLE" scope = self._parse_table() - return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind) + return self.expression( + exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind + ) def _parse_alter_table_swap(self) -> exp.SwapTable: self._match_text_seq("WITH") @@ -489,8 +633,12 @@ class Snowflake(Dialect): "MINUS": TokenType.EXCEPT, "NCHAR VARYING": TokenType.VARCHAR, "PUT": TokenType.COMMAND, + "REMOVE": TokenType.COMMAND, "RENAME": TokenType.REPLACE, + "RM": TokenType.COMMAND, "SAMPLE": TokenType.TABLE_SAMPLE, + "SQL_DOUBLE": TokenType.DOUBLE, + "SQL_VARCHAR": TokenType.VARCHAR, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, @@ -518,6 +666,8 @@ class Snowflake(Dialect): SUPPORTS_TABLE_COPY = False COLLATE_IS_FUNC = True LIMIT_ONLY_LITERALS = True + JSON_KEY_VALUE_PAIR_SEP = "," + INSERT_OVERWRITE = " OVERWRITE INTO" TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -545,6 +695,8 @@ class Snowflake(Dialect): ), exp.GroupConcat: rename_func("LISTAGG"), exp.If: if_sql(name="IFF", false_value="NULL"), + exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]", + exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions), exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), @@ -557,6 +709,7 @@ class Snowflake(Dialect): exp.PercentileDisc: transforms.preprocess( [transforms.add_within_group_for_percentiles] ), + exp.Pivot: transforms.preprocess([_unqualify_unpivot_columns]), exp.RegexpILike: _regexpilike_sql, exp.Rand: rename_func("RANDOM"), exp.Select: transforms.preprocess( @@ -578,6 +731,9 @@ class Snowflake(Dialect): *(arg for expression in e.expressions for arg in expression.flatten()), ), exp.Stuff: rename_func("INSERT"), + exp.TimestampDiff: lambda self, e: self.func( + "TIMESTAMPDIFF", e.unit, e.expression, e.this + ), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: lambda self, e: self.func( @@ -589,8 +745,7 @@ class Snowflake(Dialect): exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), - exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), - exp.UnixToTime: _unix_to_time_sql, + exp.UnixToTime: rename_func("TO_TIMESTAMP"), exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.WeekOfYear: rename_func("WEEKOFYEAR"), exp.Xor: rename_func("BOOLXOR"), @@ -612,6 +767,14 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: + milli = expression.args.get("milli") + if milli is not None: + milli_to_nano = milli.pop() * exp.Literal.number(1000000) + expression.set("nano", milli_to_nano) + + return rename_func("TIMESTAMP_FROM_PARTS")(self, expression) + def 
trycast_sql(self, expression: exp.TryCast) -> str: value = expression.this @@ -657,6 +820,9 @@ class Snowflake(Dialect): return f"{explode}{alias}" def show_sql(self, expression: exp.Show) -> str: + like = self.sql(expression, "like") + like = f" LIKE {like}" if like else "" + scope = self.sql(expression, "scope") scope = f" {scope}" if scope else "" @@ -664,7 +830,7 @@ class Snowflake(Dialect): if scope_kind: scope_kind = f" IN {scope_kind}" - return f"SHOW {expression.name}{scope_kind}{scope}" + return f"SHOW {expression.name}{like}{scope_kind}{scope}" def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: # Other dialects don't support all of the following parameters, so we need to diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index aa09f53..e27ba18 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -48,11 +48,8 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str return f"TIMESTAMP_MILLIS({timestamp})" if scale == exp.UnixToTime.MICROS: return f"TIMESTAMP_MICROS({timestamp})" - if scale == exp.UnixToTime.NANOS: - return f"TIMESTAMP_SECONDS({timestamp} / 1000000000)" - self.unsupported(f"Unsupported scale for timestamp: {scale}.") - return "" + return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))" def _unalias_pivot(expression: exp.Expression) -> exp.Expression: @@ -93,12 +90,7 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q1')) """ if isinstance(expression, exp.Pivot): - expression.args["field"].transform( - lambda node: exp.column(node.output_name, quoted=node.this.quoted) - if isinstance(node, exp.Column) - else node, - copy=False, - ) + expression.set("field", transforms.unqualify_columns(expression.args["field"])) return expression @@ -234,7 +226,7 @@ class Spark2(Hive): def struct_sql(self, expression: exp.Struct) -> str: args = [] for arg in expression.expressions: - if isinstance(arg, self.KEY_VALUE_DEFINITONS): + if isinstance(arg, self.KEY_VALUE_DEFINITIONS): if isinstance(arg, exp.Bracket): args.append(exp.alias_(arg.this, arg.expressions[0].name)) else: diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 9bac51c..244a96e 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -78,6 +78,7 @@ class SQLite(Dialect): **parser.Parser.FUNCTIONS, "EDITDIST3": exp.Levenshtein.from_arg_list, } + STRING_ALIASES = True class Generator(generator.Generator): JOIN_HINTS = False diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 0ccc567..6dbad15 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -175,6 +175,8 @@ class Teradata(Dialect): JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False + TABLESAMPLE_KEYWORDS = "SAMPLE" + LAST_DAY_SUPPORTS_DATE_PART = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -214,7 +216,10 @@ class Teradata(Dialect): return self.cast_sql(expression, safe_prefix="TRY") def tablesample_sql( - self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " + self, + expression: exp.TableSample, + sep: str = " AS ", + tablesample_keyword: t.Optional[str] = None, ) -> str: return f"{self.sql(expression, 'this')} SAMPLE {self.expressions(expression)}" diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index 3682ac7..eddb70a 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -1,6 +1,7 @@ from __future__ import 
annotations from sqlglot import exp +from sqlglot.dialects.dialect import merge_without_target_sql from sqlglot.dialects.presto import Presto @@ -11,6 +12,7 @@ class Trino(Presto): TRANSFORMS = { **Presto.Generator.TRANSFORMS, exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + exp.Merge: merge_without_target_sql, } class Tokenizer(Presto.Tokenizer): diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 165a703..b9c347c 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -14,9 +14,10 @@ from sqlglot.dialects.dialect import ( max_or_greatest, min_or_least, parse_date_delta, + path_to_jsonpath, rename_func, timestrtotime_sql, - ts_or_ds_to_date_sql, + trim_sql, ) from sqlglot.expressions import DataType from sqlglot.helper import seq_get @@ -105,18 +106,17 @@ def _parse_format(args: t.List) -> exp.Expression: return exp.TimeToStr(this=this, format=fmt, culture=culture) -def _parse_eomonth(args: t.List) -> exp.Expression: - date = seq_get(args, 0) +def _parse_eomonth(args: t.List) -> exp.LastDay: + date = exp.TsOrDsToDate(this=seq_get(args, 0)) month_lag = seq_get(args, 1) - unit = DATE_DELTA_INTERVAL.get("month") if month_lag is None: - return exp.LastDateOfMonth(this=date) + this: exp.Expression = date + else: + unit = DATE_DELTA_INTERVAL.get("month") + this = exp.DateAdd(this=date, expression=month_lag, unit=unit and exp.var(unit)) - # Remove month lag argument in parser as its compared with the number of arguments of the resulting class - args.remove(month_lag) - - return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit)) + return exp.LastDay(this=this) def _parse_hashbytes(args: t.List) -> exp.Expression: @@ -137,26 +137,27 @@ def _parse_hashbytes(args: t.List) -> exp.Expression: return exp.func("HASHBYTES", *args) -DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"} +DATEPART_ONLY_FORMATS = {"DW", "HOUR", "QUARTER"} def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: - fmt = ( - expression.args["format"] - if isinstance(expression, exp.NumberToStr) - else exp.Literal.string( - format_time( - expression.text("format"), - t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING), - ) - ) - ) + fmt = expression.args["format"] - # There is no format for "quarter" - if fmt.name.lower() in DATEPART_ONLY_FORMATS: - return self.func("DATEPART", fmt.name, expression.this) + if not isinstance(expression, exp.NumberToStr): + if fmt.is_string: + mapped_fmt = format_time(fmt.name, TSQL.INVERSE_TIME_MAPPING) - return self.func("FORMAT", expression.this, fmt, expression.args.get("culture")) + name = (mapped_fmt or "").upper() + if name in DATEPART_ONLY_FORMATS: + return self.func("DATEPART", name, expression.this) + + fmt_sql = self.sql(exp.Literal.string(mapped_fmt)) + else: + fmt_sql = self.format_time(expression) or self.sql(fmt) + else: + fmt_sql = self.sql(fmt) + + return self.func("FORMAT", expression.this, fmt_sql, expression.args.get("culture")) def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: @@ -239,6 +240,30 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: return expression +# https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax +def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts: + return exp.TimestampFromParts( + year=seq_get(args, 0), + month=seq_get(args, 1), + day=seq_get(args, 2), + 
hour=seq_get(args, 3), + min=seq_get(args, 4), + sec=seq_get(args, 5), + milli=seq_get(args, 6), + ) + + +# https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax +def _parse_timefromparts(args: t.List) -> exp.TimeFromParts: + return exp.TimeFromParts( + hour=seq_get(args, 0), + min=seq_get(args, 1), + sec=seq_get(args, 2), + fractions=seq_get(args, 3), + precision=seq_get(args, 4), + ) + + class TSQL(Dialect): NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" @@ -352,7 +377,7 @@ class TSQL(Dialect): } class Tokenizer(tokens.Tokenizer): - IDENTIFIERS = ['"', ("[", "]")] + IDENTIFIERS = [("[", "]"), '"'] QUOTES = ["'", '"'] HEX_STRINGS = [("0x", ""), ("0X", "")] VAR_SINGLE_TOKENS = {"@", "$", "#"} @@ -362,6 +387,7 @@ class TSQL(Dialect): "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, "DECLARE": TokenType.COMMAND, + "EXEC": TokenType.COMMAND, "IMAGE": TokenType.IMAGE, "MONEY": TokenType.MONEY, "NTEXT": TokenType.TEXT, @@ -397,6 +423,7 @@ class TSQL(Dialect): "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATEPART": _format_time_lambda(exp.TimeToStr), + "DATETIMEFROMPARTS": _parse_datetimefromparts, "EOMONTH": _parse_eomonth, "FORMAT": _parse_format, "GETDATE": exp.CurrentTimestamp.from_arg_list, @@ -411,6 +438,7 @@ class TSQL(Dialect): "SUSER_NAME": exp.CurrentUser.from_arg_list, "SUSER_SNAME": exp.CurrentUser.from_arg_list, "SYSTEM_USER": exp.CurrentUser.from_arg_list, + "TIMEFROMPARTS": _parse_timefromparts, } JOIN_HINTS = { @@ -440,6 +468,7 @@ class TSQL(Dialect): LOG_DEFAULTS_TO_LN = True ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False + STRING_ALIASES = True def _parse_projections(self) -> t.List[exp.Expression]: """ @@ -630,8 +659,10 @@ class TSQL(Dialect): COMPUTED_COLUMN_WITH_TYPE = False CTE_RECURSIVE_KEYWORD_REQUIRED = False ENSURE_BOOLS = True - NULL_ORDERING_SUPPORTED = False + NULL_ORDERING_SUPPORTED = None SUPPORTS_SINGLE_ARG_CONCAT = False + TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" + SUPPORTS_SELECT_INTO = True EXPRESSIONS_WITHOUT_NESTED_CTES = { exp.Delete, @@ -667,13 +698,16 @@ class TSQL(Dialect): exp.CurrentTimestamp: rename_func("GETDATE"), exp.Extract: rename_func("DATEPART"), exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, + exp.GetPath: path_to_jsonpath("JSON_VALUE"), exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), + exp.LastDay: lambda self, e: self.func("EOMONTH", e.this), exp.Length: rename_func("LEN"), exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, exp.NumberToStr: _format_sql, + exp.ParseJSON: lambda self, e: self.sql(e, "this"), exp.Select: transforms.preprocess( [ transforms.eliminate_distinct_on, @@ -689,9 +723,9 @@ class TSQL(Dialect): exp.TemporaryProperty: lambda self, e: "", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: _format_sql, + exp.Trim: trim_sql, exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), - exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"), } TRANSFORMS.pop(exp.ReturnsProperty) @@ -701,6 +735,46 @@ class TSQL(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def lateral_op(self, expression: exp.Lateral) -> str: + cross_apply = expression.args.get("cross_apply") + if cross_apply is True: + return "CROSS 
APPLY" + if cross_apply is False: + return "OUTER APPLY" + + # TODO: perhaps we can check if the parent is a Join and transpile it appropriately + self.unsupported("LATERAL clause is not supported.") + return "LATERAL" + + def timefromparts_sql(self, expression: exp.TimeFromParts) -> str: + nano = expression.args.get("nano") + if nano is not None: + nano.pop() + self.unsupported("Specifying nanoseconds is not supported in TIMEFROMPARTS.") + + if expression.args.get("fractions") is None: + expression.set("fractions", exp.Literal.number(0)) + if expression.args.get("precision") is None: + expression.set("precision", exp.Literal.number(0)) + + return rename_func("TIMEFROMPARTS")(self, expression) + + def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: + zone = expression.args.get("zone") + if zone is not None: + zone.pop() + self.unsupported("Time zone is not supported in DATETIMEFROMPARTS.") + + nano = expression.args.get("nano") + if nano is not None: + nano.pop() + self.unsupported("Specifying nanoseconds is not supported in DATETIMEFROMPARTS.") + + if expression.args.get("milli") is None: + expression.set("milli", exp.Literal.number(0)) + + return rename_func("DATETIMEFROMPARTS")(self, expression) + def set_operation(self, expression: exp.Union, op: str) -> str: limit = expression.args.get("limit") if limit: -- cgit v1.2.3