author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-01-23 05:06:14 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-01-23 05:06:14 +0000
commit     38e6461a8afbd7cb83709ddb998f03d40ba87755 (patch)
tree       64b68a893a3b946111b9cab69503f83ca233c335 /sqlglot/dialects
parent     Releasing debian version 20.4.0-1. (diff)
Merging upstream version 20.9.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--  sqlglot/dialects/bigquery.py     91
-rw-r--r--  sqlglot/dialects/clickhouse.py  201
-rw-r--r--  sqlglot/dialects/databricks.py    2
-rw-r--r--  sqlglot/dialects/dialect.py     119
-rw-r--r--  sqlglot/dialects/doris.py         9
-rw-r--r--  sqlglot/dialects/drill.py         3
-rw-r--r--  sqlglot/dialects/duckdb.py      121
-rw-r--r--  sqlglot/dialects/hive.py          3
-rw-r--r--  sqlglot/dialects/mysql.py        34
-rw-r--r--  sqlglot/dialects/oracle.py       20
-rw-r--r--  sqlglot/dialects/postgres.py     49
-rw-r--r--  sqlglot/dialects/presto.py       40
-rw-r--r--  sqlglot/dialects/redshift.py     30
-rw-r--r--  sqlglot/dialects/snowflake.py   266
-rw-r--r--  sqlglot/dialects/spark2.py       14
-rw-r--r--  sqlglot/dialects/sqlite.py        1
-rw-r--r--  sqlglot/dialects/teradata.py      7
-rw-r--r--  sqlglot/dialects/trino.py         2
-rw-r--r--  sqlglot/dialects/tsql.py        128
19 files changed, 897 insertions, 243 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 7a573e7..0151e6c 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -16,20 +16,22 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
if_sql,
inline_array_sql,
- json_keyvalue_comma_sql,
max_or_greatest,
min_or_least,
no_ilike_sql,
parse_date_delta_with_interval,
+ path_to_jsonpath,
regexp_replace_sql,
rename_func,
timestrtotime_sql,
ts_or_ds_add_cast,
- ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
+if t.TYPE_CHECKING:
+ from typing_extensions import Literal
+
logger = logging.getLogger("sqlglot")
@@ -206,12 +208,17 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
- if scale == exp.UnixToTime.NANOS:
- # We need to cast to INT64 because that's what BQ expects
- return f"TIMESTAMP_MICROS(CAST({timestamp} / 1000 AS INT64))"
- self.unsupported(f"Unsupported scale for timestamp: {scale}.")
- return ""
+ return f"TIMESTAMP_SECONDS(CAST({timestamp} / POW(10, {scale}) AS INT64))"
+
+
+def _parse_time(args: t.List) -> exp.Func:
+ if len(args) == 1:
+ return exp.TsOrDsToTime(this=args[0])
+ if len(args) == 3:
+ return exp.TimeFromParts.from_arg_list(args)
+
+ return exp.Anonymous(this="TIME", expressions=args)
class BigQuery(Dialect):
@@ -329,7 +336,13 @@ class BigQuery(Dialect):
"DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"DIV": binary_from_function(exp.IntDiv),
+ "FORMAT_DATE": lambda args: exp.TimeToStr(
+ this=exp.TsOrDsToDate(this=seq_get(args, 1)), format=seq_get(args, 0)
+ ),
"GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
+ "JSON_EXTRACT_SCALAR": lambda args: exp.JSONExtractScalar(
+ this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$")
+ ),
"MD5": exp.MD5Digest.from_arg_list,
"TO_HEX": _parse_to_hex,
"PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
@@ -351,6 +364,7 @@ class BigQuery(Dialect):
this=seq_get(args, 0),
expression=seq_get(args, 1) or exp.Literal.string(","),
),
+ "TIME": _parse_time,
"TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
@@ -361,9 +375,7 @@ class BigQuery(Dialect):
"TIMESTAMP_MILLIS": lambda args: exp.UnixToTime(
this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
),
- "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(
- this=seq_get(args, 0), scale=exp.UnixToTime.SECONDS
- ),
+ "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(this=seq_get(args, 0)),
"TO_JSON_STRING": exp.JSONFormat.from_arg_list,
}
@@ -460,7 +472,15 @@ class BigQuery(Dialect):
return table
- def _parse_json_object(self) -> exp.JSONObject:
+ @t.overload
+ def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
+ ...
+
+ @t.overload
+ def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
+ ...
+
+ def _parse_json_object(self, agg=False):
json_object = super()._parse_json_object()
array_kv_pair = seq_get(json_object.expressions, 0)
@@ -513,6 +533,10 @@ class BigQuery(Dialect):
UNNEST_WITH_ORDINALITY = False
COLLATE_IS_FUNC = True
LIMIT_ONLY_LITERALS = True
+ SUPPORTS_TABLE_ALIAS_COLUMNS = False
+ UNPIVOT_ALIASES_ARE_IDENTIFIERS = False
+ JSON_KEY_VALUE_PAIR_SEP = ","
+ NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -525,6 +549,7 @@ class BigQuery(Dialect):
exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
if e.args.get("default")
else f"COLLATE {self.sql(e, 'this')}",
+ exp.CountIf: rename_func("COUNTIF"),
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
@@ -536,13 +561,13 @@ class BigQuery(Dialect):
exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"),
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
+ exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql(false_value="NULL"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
- exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
@@ -578,16 +603,17 @@ class BigQuery(Dialect):
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
),
exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
+ exp.TimeFromParts: rename_func("TIME"),
exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
- exp.TimeToStr: lambda self, e: f"FORMAT_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
+ exp.TsOrDsToTime: rename_func("TIME"),
exp.Unhex: rename_func("FROM_HEX"),
+ exp.UnixDate: rename_func("UNIX_DATE"),
exp.UnixToTime: _unix_to_time_sql,
exp.Values: _derived_table_values_to_unnest,
exp.VariancePop: rename_func("VAR_POP"),
@@ -724,6 +750,26 @@ class BigQuery(Dialect):
"within",
}
+ def timetostr_sql(self, expression: exp.TimeToStr) -> str:
+ if isinstance(expression.this, exp.TsOrDsToDate):
+ this: exp.Expression = expression.this
+ else:
+ this = expression
+
+ return f"FORMAT_DATE({self.format_time(expression)}, {self.sql(this, 'this')})"
+
+ def struct_sql(self, expression: exp.Struct) -> str:
+ args = []
+ for expr in expression.expressions:
+ if isinstance(expr, self.KEY_VALUE_DEFINITIONS):
+ arg = f"{self.sql(expr, 'expression')} AS {expr.this.name}"
+ else:
+ arg = self.sql(expr)
+
+ args.append(arg)
+
+ return self.func("STRUCT", *args)
+
def eq_sql(self, expression: exp.EQ) -> str:
# Operands of = cannot be NULL in BigQuery
if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null):
@@ -760,7 +806,20 @@ class BigQuery(Dialect):
return inline_array_sql(self, expression)
def bracket_sql(self, expression: exp.Bracket) -> str:
+ this = self.sql(expression, "this")
expressions = expression.expressions
+
+ if len(expressions) == 1:
+ arg = expressions[0]
+ if arg.type is None:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ arg = annotate_types(arg)
+
+ if arg.type and arg.type.this in exp.DataType.TEXT_TYPES:
+ # BQ doesn't support bracket syntax with string values
+ return f"{this}.{arg.name}"
+
expressions_sql = ", ".join(self.sql(e) for e in expressions)
offset = expression.args.get("offset")
@@ -768,13 +827,13 @@ class BigQuery(Dialect):
expressions_sql = f"OFFSET({expressions_sql})"
elif offset == 1:
expressions_sql = f"ORDINAL({expressions_sql})"
- else:
+ elif offset is not None:
self.unsupported(f"Unsupported array offset: {offset}")
if expression.args.get("safe"):
expressions_sql = f"SAFE_{expressions_sql}"
- return f"{self.sql(expression, 'this')}[{expressions_sql}]"
+ return f"{this}[{expressions_sql}]"
def transaction_sql(self, *_) -> str:
return "BEGIN TRANSACTION"
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 870f402..f2e4fe1 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arg_max_or_min_no_count,
+ date_delta_sql,
inline_array_sql,
no_pivot_sql,
rename_func,
@@ -22,16 +23,25 @@ def _lower_func(sql: str) -> str:
return sql[:index].lower() + sql[index:]
-def _quantile_sql(self, e):
+def _quantile_sql(self: ClickHouse.Generator, e: exp.Quantile) -> str:
quantile = e.args["quantile"]
args = f"({self.sql(e, 'this')})"
+
if isinstance(quantile, exp.Array):
func = self.func("quantiles", *quantile)
else:
func = self.func("quantile", quantile)
+
return func + args
+def _parse_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc:
+ if len(args) == 1:
+ return exp.CountIf(this=seq_get(args, 0))
+
+ return exp.CombinedAggFunc(this="countIf", expressions=args, parts=("count", "If"))
+
+
class ClickHouse(Dialect):
NORMALIZE_FUNCTIONS: bool | str = False
NULL_ORDERING = "nulls_are_last"
@@ -53,6 +63,7 @@ class ClickHouse(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ATTACH": TokenType.COMMAND,
+ "DATE32": TokenType.DATE32,
"DATETIME64": TokenType.DATETIME64,
"DICTIONARY": TokenType.DICTIONARY,
"ENUM": TokenType.ENUM,
@@ -75,6 +86,8 @@ class ClickHouse(Dialect):
"UINT32": TokenType.UINT,
"UINT64": TokenType.UBIGINT,
"UINT8": TokenType.UTINYINT,
+ "IPV4": TokenType.IPV4,
+ "IPV6": TokenType.IPV6,
}
SINGLE_TOKENS = {
@@ -91,6 +104,8 @@ class ClickHouse(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ANY": exp.AnyValue.from_arg_list,
+ "ARRAYSUM": exp.ArraySum.from_arg_list,
+ "COUNTIF": _parse_count_if,
"DATE_ADD": lambda args: exp.DateAdd(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
@@ -110,6 +125,138 @@ class ClickHouse(Dialect):
"XOR": lambda args: exp.Xor(expressions=args),
}
+ AGG_FUNCTIONS = {
+ "count",
+ "min",
+ "max",
+ "sum",
+ "avg",
+ "any",
+ "stddevPop",
+ "stddevSamp",
+ "varPop",
+ "varSamp",
+ "corr",
+ "covarPop",
+ "covarSamp",
+ "entropy",
+ "exponentialMovingAverage",
+ "intervalLengthSum",
+ "kolmogorovSmirnovTest",
+ "mannWhitneyUTest",
+ "median",
+ "rankCorr",
+ "sumKahan",
+ "studentTTest",
+ "welchTTest",
+ "anyHeavy",
+ "anyLast",
+ "boundingRatio",
+ "first_value",
+ "last_value",
+ "argMin",
+ "argMax",
+ "avgWeighted",
+ "topK",
+ "topKWeighted",
+ "deltaSum",
+ "deltaSumTimestamp",
+ "groupArray",
+ "groupArrayLast",
+ "groupUniqArray",
+ "groupArrayInsertAt",
+ "groupArrayMovingAvg",
+ "groupArrayMovingSum",
+ "groupArraySample",
+ "groupBitAnd",
+ "groupBitOr",
+ "groupBitXor",
+ "groupBitmap",
+ "groupBitmapAnd",
+ "groupBitmapOr",
+ "groupBitmapXor",
+ "sumWithOverflow",
+ "sumMap",
+ "minMap",
+ "maxMap",
+ "skewSamp",
+ "skewPop",
+ "kurtSamp",
+ "kurtPop",
+ "uniq",
+ "uniqExact",
+ "uniqCombined",
+ "uniqCombined64",
+ "uniqHLL12",
+ "uniqTheta",
+ "quantile",
+ "quantiles",
+ "quantileExact",
+ "quantilesExact",
+ "quantileExactLow",
+ "quantilesExactLow",
+ "quantileExactHigh",
+ "quantilesExactHigh",
+ "quantileExactWeighted",
+ "quantilesExactWeighted",
+ "quantileTiming",
+ "quantilesTiming",
+ "quantileTimingWeighted",
+ "quantilesTimingWeighted",
+ "quantileDeterministic",
+ "quantilesDeterministic",
+ "quantileTDigest",
+ "quantilesTDigest",
+ "quantileTDigestWeighted",
+ "quantilesTDigestWeighted",
+ "quantileBFloat16",
+ "quantilesBFloat16",
+ "quantileBFloat16Weighted",
+ "quantilesBFloat16Weighted",
+ "simpleLinearRegression",
+ "stochasticLinearRegression",
+ "stochasticLogisticRegression",
+ "categoricalInformationValue",
+ "contingency",
+ "cramersV",
+ "cramersVBiasCorrected",
+ "theilsU",
+ "maxIntersections",
+ "maxIntersectionsPosition",
+ "meanZTest",
+ "quantileInterpolatedWeighted",
+ "quantilesInterpolatedWeighted",
+ "quantileGK",
+ "quantilesGK",
+ "sparkBar",
+ "sumCount",
+ "largestTriangleThreeBuckets",
+ }
+
+ AGG_FUNCTIONS_SUFFIXES = [
+ "If",
+ "Array",
+ "ArrayIf",
+ "Map",
+ "SimpleState",
+ "State",
+ "Merge",
+ "MergeState",
+ "ForEach",
+ "Distinct",
+ "OrDefault",
+ "OrNull",
+ "Resample",
+ "ArgMin",
+ "ArgMax",
+ ]
+
+ AGG_FUNC_MAPPING = (
+ lambda functions, suffixes: {
+ f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
+ }
+ )(AGG_FUNCTIONS, AGG_FUNCTIONS_SUFFIXES)
+
FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"}
FUNCTION_PARSERS = {
@@ -272,9 +419,18 @@ class ClickHouse(Dialect):
)
if isinstance(func, exp.Anonymous):
+ parts = self.AGG_FUNC_MAPPING.get(func.this)
params = self._parse_func_params(func)
if params:
+ if parts and parts[1]:
+ return self.expression(
+ exp.CombinedParameterizedAgg,
+ this=func.this,
+ expressions=func.expressions,
+ params=params,
+ parts=parts,
+ )
return self.expression(
exp.ParameterizedAgg,
this=func.this,
@@ -282,6 +438,20 @@ class ClickHouse(Dialect):
params=params,
)
+ if parts:
+ if parts[1]:
+ return self.expression(
+ exp.CombinedAggFunc,
+ this=func.this,
+ expressions=func.expressions,
+ parts=parts,
+ )
+ return self.expression(
+ exp.AnonymousAggFunc,
+ this=func.this,
+ expressions=func.expressions,
+ )
+
return func
def _parse_func_params(
@@ -329,6 +499,9 @@ class ClickHouse(Dialect):
STRUCT_DELIMITER = ("(", ")")
NVL2_SUPPORTED = False
TABLESAMPLE_REQUIRES_PARENS = False
+ TABLESAMPLE_SIZE_IS_ROWS = False
+ TABLESAMPLE_KEYWORDS = "SAMPLE"
+ LAST_DAY_SUPPORTS_DATE_PART = False
STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
@@ -348,6 +521,7 @@ class ClickHouse(Dialect):
**STRING_TYPE_MAPPING,
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.BIGINT: "Int64",
+ exp.DataType.Type.DATE32: "Date32",
exp.DataType.Type.DATETIME64: "DateTime64",
exp.DataType.Type.DOUBLE: "Float64",
exp.DataType.Type.ENUM: "Enum",
@@ -372,24 +546,23 @@ class ClickHouse(Dialect):
exp.DataType.Type.UINT256: "UInt256",
exp.DataType.Type.USMALLINT: "UInt16",
exp.DataType.Type.UTINYINT: "UInt8",
+ exp.DataType.Type.IPV4: "IPv4",
+ exp.DataType.Type.IPV6: "IPv6",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
- exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.AnyValue: rename_func("any"),
exp.ApproxDistinct: rename_func("uniq"),
+ exp.ArraySum: rename_func("arraySum"),
exp.ArgMax: arg_max_or_min_no_count("argMax"),
exp.ArgMin: arg_max_or_min_no_count("argMin"),
exp.Array: inline_array_sql,
exp.CastToStrType: rename_func("CAST"),
+ exp.CountIf: rename_func("countIf"),
exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
- exp.DateAdd: lambda self, e: self.func(
- "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
- ),
- exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
- ),
+ exp.DateAdd: date_delta_sql("DATE_ADD"),
+ exp.DateDiff: date_delta_sql("DATE_DIFF"),
exp.Explode: rename_func("arrayJoin"),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.IsNan: rename_func("isNaN"),
@@ -400,6 +573,7 @@ class ClickHouse(Dialect):
exp.Quantile: _quantile_sql,
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
exp.Rand: rename_func("randCanonical"),
+ exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.StartsWith: rename_func("startsWith"),
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
@@ -485,10 +659,19 @@ class ClickHouse(Dialect):
else "",
]
- def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
+ def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str:
params = self.expressions(expression, key="params", flat=True)
return self.func(expression.name, *expression.expressions) + f"({params})"
+ def anonymousaggfunc_sql(self, expression: exp.AnonymousAggFunc) -> str:
+ return self.func(expression.name, *expression.expressions)
+
+ def combinedaggfunc_sql(self, expression: exp.CombinedAggFunc) -> str:
+ return self.anonymousaggfunc_sql(expression)
+
+ def combinedparameterizedagg_sql(self, expression: exp.CombinedParameterizedAgg) -> str:
+ return self.parameterizedagg_sql(expression)
+
def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 1c10a8b..8e55b6a 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -30,6 +30,8 @@ class Databricks(Spark):
}
class Generator(Spark.Generator):
+ TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+
TRANSFORMS = {
**Spark.Generator.TRANSFORMS,
exp.DateAdd: date_delta_sql("DATEADD"),
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index b7eef45..7664c40 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -153,6 +153,9 @@ class Dialect(metaclass=_Dialect):
ALIAS_POST_TABLESAMPLE = False
"""Determines whether or not the table alias comes after tablesample."""
+ TABLESAMPLE_SIZE_IS_PERCENT = False
+ """Determines whether or not a size in the table sample clause represents percentage."""
+
NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
"""Specifies the strategy according to which identifiers should be normalized."""
@@ -220,6 +223,24 @@ class Dialect(metaclass=_Dialect):
For example, such columns may be excluded from `SELECT *` queries.
"""
+ PREFER_CTE_ALIAS_COLUMN = False
+ """
+ Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
+ HAVING clause of the CTE. This flag will cause the CTE alias columns to override
+ any projection aliases in the subquery.
+
+ For example,
+ WITH y(c) AS (
+ SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
+ ) SELECT c FROM y;
+
+ will be rewritten as
+
+ WITH y(c) AS (
+ SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
+ ) SELECT c FROM y;
+ """
+
# --- Autofilled ---
tokenizer_class = Tokenizer
@@ -287,7 +308,13 @@ class Dialect(metaclass=_Dialect):
result = cls.get(dialect_name.strip())
if not result:
- raise ValueError(f"Unknown dialect '{dialect_name}'.")
+ from difflib import get_close_matches
+
+ similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
+ if similar:
+ similar = f" Did you mean {similar}?"
+
+ raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
return result(**kwargs)
@@ -506,7 +533,7 @@ def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
n = self.sql(expression, "this")
d = self.sql(expression, "expression")
- return f"IF({d} <> 0, {n} / {d}, NULL)"
+ return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
@@ -695,7 +722,7 @@ def date_add_interval_sql(
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
return self.func(
- "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
+ "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
)
@@ -801,22 +828,6 @@ def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
return self.func("STRPTIME", expression.this, self.format_time(expression))
-def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
- def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
- _dialect = Dialect.get_or_raise(dialect)
- time_format = self.format_time(expression)
- if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
- return self.sql(
- exp.cast(
- exp.StrToTime(this=expression.this, format=expression.args["format"]),
- "date",
- )
- )
- return self.sql(exp.cast(expression.this, "date"))
-
- return _ts_or_ds_to_date_sql
-
-
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
@@ -894,11 +905,6 @@ def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
-# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
-def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
- return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
-
-
def is_parse_json(expression: exp.Expression) -> bool:
return isinstance(expression, exp.ParseJSON) or (
isinstance(expression, exp.Cast) and expression.is_type("json")
@@ -946,7 +952,70 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
expression = ts_or_ds_add_cast(expression)
return self.func(
- name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
+ name,
+ exp.var(expression.text("unit").upper() or "DAY"),
+ expression.expression,
+ expression.this,
)
return _delta_sql
+
+
+def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
+ from sqlglot.optimizer.simplify import simplify
+
+ # Makes sure the path will be evaluated correctly at runtime to include the path root.
+ # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
+ path = expression.expression
+ path = exp.func(
+ "if",
+ exp.func("startswith", path, "'['"),
+ exp.func("concat", "'$'", path),
+ exp.func("concat", "'$.'", path),
+ )
+
+ expression.expression.replace(simplify(path))
+ return expression
+
+
+def path_to_jsonpath(
+ name: str = "JSON_EXTRACT",
+) -> t.Callable[[Generator, exp.GetPath], str]:
+ def _transform(self: Generator, expression: exp.GetPath) -> str:
+ return rename_func(name)(self, prepend_dollar_to_path(expression))
+
+ return _transform
+
+
+def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
+ trunc_curr_date = exp.func("date_trunc", "month", expression.this)
+ plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
+ minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
+
+ return self.sql(exp.cast(minus_one_day, "date"))
+
+
+def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
+ """Remove table refs from columns in when statements."""
+ alias = expression.this.args.get("alias")
+
+ normalize = (
+ lambda identifier: self.dialect.normalize_identifier(identifier).name
+ if identifier
+ else None
+ )
+
+ targets = {normalize(expression.this.this)}
+
+ if alias:
+ targets.add(normalize(alias.this))
+
+ for when in expression.expressions:
+ when.transform(
+ lambda node: exp.column(node.this)
+ if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
+ else node,
+ copy=False,
+ )
+
+ return self.merge_sql(expression)
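[Annotation, not part of the patch] Besides the shared merge_without_target_sql and path_to_jsonpath helpers, Dialect.get_or_raise now suggests a close match for an unknown dialect name via difflib. A quick demonstration; the message wording comes straight from the hunk above.

import sqlglot

try:
    sqlglot.transpile("SELECT 1", read="postgress")  # misspelled on purpose
except ValueError as e:
    print(e)  # Unknown dialect 'postgress'. Did you mean postgres?
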
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index 11af17b..6e229b3 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -22,6 +22,7 @@ class Doris(MySQL):
"COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
"DATE_TRUNC": parse_timestamp_trunc,
"REGEXP": exp.RegexpLike.from_arg_list,
+ "TO_DATE": exp.TsOrDsToDate.from_arg_list,
}
class Generator(MySQL.Generator):
@@ -34,21 +35,26 @@ class Doris(MySQL):
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
+ LAST_DAY_SUPPORTS_DATE_PART = False
+
TIMESTAMP_FUNC_TYPES = set()
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
+ exp.ArgMax: rename_func("MAX_BY"),
+ exp.ArgMin: rename_func("MIN_BY"),
exp.ArrayAgg: rename_func("COLLECT_LIST"),
+ exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.CurrentTimestamp: lambda *_: "NOW()",
exp.DateTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
),
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
+ exp.Map: rename_func("ARRAY_MAP"),
exp.RegexpLike: rename_func("REGEXP"),
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
- exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Split: rename_func("SPLIT_BY_STRING"),
exp.TimeStrToDate: rename_func("TO_DATE"),
@@ -63,5 +69,4 @@ class Doris(MySQL):
"FROM_UNIXTIME", e.this, time_format("doris")(self, e)
),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
- exp.Map: rename_func("ARRAY_MAP"),
}
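[Annotation, not part of the patch] With the ArgMax/ArgMin transforms added above, MAX_BY/MIN_BY should now survive transpilation into Doris. A hedged sketch; it assumes MAX_BY parses to exp.ArgMax in the source dialect.

import sqlglot

print(sqlglot.transpile("SELECT MAX_BY(x, y) FROM t", read="duckdb", write="doris")[0])
# expected: SELECT MAX_BY(x, y) FROM t
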
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index c9b31a0..6bca9e7 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -12,7 +12,6 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
timestrtotime_sql,
- ts_or_ds_to_date_sql,
)
@@ -99,6 +98,7 @@ class Drill(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
+ LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -150,7 +150,6 @@ class Drill(Dialect):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index cd9d529..2343b35 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -22,15 +22,15 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
no_timestamp_sql,
pivot_column_names,
+ prepend_dollar_to_path,
regexp_extract_sql,
rename_func,
str_position_sql,
str_to_time_sql,
timestamptrunc_sql,
timestrtotime_sql,
- ts_or_ds_to_date_sql,
)
-from sqlglot.helper import seq_get
+from sqlglot.helper import flatten, seq_get
from sqlglot.tokens import TokenType
@@ -141,11 +141,25 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
return f"EPOCH_MS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"MAKE_TIMESTAMP({timestamp})"
- if scale == exp.UnixToTime.NANOS:
- return f"TO_TIMESTAMP({timestamp} / 1000000000)"
- self.unsupported(f"Unsupported scale for timestamp: {scale}.")
- return ""
+ return f"TO_TIMESTAMP({timestamp} / POW(10, {scale}))"
+
+
+def _rename_unless_within_group(
+ a: str, b: str
+) -> t.Callable[[DuckDB.Generator, exp.Expression], str]:
+ return (
+ lambda self, expression: self.func(a, *flatten(expression.args.values()))
+ if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup)
+ else self.func(b, *flatten(expression.args.values()))
+ )
+
+
+def _parse_struct_pack(args: t.List) -> exp.Struct:
+ args_with_columns_as_identifiers = [
+ exp.PropertyEQ(this=arg.this.this, expression=arg.expression) for arg in args
+ ]
+ return exp.Struct.from_arg_list(args_with_columns_as_identifiers)
class DuckDB(Dialect):
@@ -183,6 +197,11 @@ class DuckDB(Dialect):
"TIMESTAMP_US": TokenType.TIMESTAMP,
}
+ SINGLE_TOKENS = {
+ **tokens.Tokenizer.SINGLE_TOKENS,
+ "$": TokenType.PARAMETER,
+ }
+
class Parser(parser.Parser):
BITWISE = {
**parser.Parser.BITWISE,
@@ -209,10 +228,12 @@ class DuckDB(Dialect):
"EPOCH_MS": lambda args: exp.UnixToTime(
this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
),
+ "JSON": exp.ParseJSON.from_arg_list,
"LIST_HAS": exp.ArrayContains.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
+ "MAKE_TIME": exp.TimeFromParts.from_arg_list,
"MAKE_TIMESTAMP": _parse_make_timestamp,
"MEDIAN": lambda args: exp.PercentileCont(
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
@@ -234,7 +255,7 @@ class DuckDB(Dialect):
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
- "STRUCT_PACK": exp.Struct.from_arg_list,
+ "STRUCT_PACK": _parse_struct_pack,
"STR_SPLIT": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
@@ -250,6 +271,13 @@ class DuckDB(Dialect):
TokenType.ANTI,
}
+ PLACEHOLDER_PARSERS = {
+ **parser.Parser.PLACEHOLDER_PARSERS,
+ TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
+ if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
+ else None,
+ }
+
def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
@@ -268,7 +296,7 @@ class DuckDB(Dialect):
return this
- def _parse_struct_types(self) -> t.Optional[exp.Expression]:
+ def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
return self._parse_field_def()
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
@@ -285,6 +313,10 @@ class DuckDB(Dialect):
RENAME_TABLE_WITH_DB = False
NVL2_SUPPORTED = False
SEMI_ANTI_JOIN_WITH_SIDE = False
+ TABLESAMPLE_KEYWORDS = "USING SAMPLE"
+ TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+ LAST_DAY_SUPPORTS_DATE_PART = False
+ JSON_KEY_VALUE_PAIR_SEP = ","
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -311,7 +343,7 @@ class DuckDB(Dialect):
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateSub: _date_delta_sql,
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", f"'{e.args.get('unit') or 'day'}'", e.expression, e.this
+ "DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
@@ -322,11 +354,11 @@ class DuckDB(Dialect):
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.IsInf: rename_func("ISINF"),
exp.IsNan: rename_func("ISNAN"),
+ exp.JSONBExtract: arrow_json_extract_sql,
+ exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONFormat: _json_format_sql,
- exp.JSONBExtract: arrow_json_extract_sql,
- exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.MonthsBetween: lambda self, e: self.func(
@@ -336,8 +368,8 @@ class DuckDB(Dialect):
exp.cast(e.this, "timestamp", copy=True),
),
exp.ParseJSON: rename_func("JSON"),
- exp.PercentileCont: rename_func("QUANTILE_CONT"),
- exp.PercentileDisc: rename_func("QUANTILE_DISC"),
+ exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"),
+ exp.PercentileDisc: _rename_unless_within_group("PERCENTILE_DISC", "QUANTILE_DISC"),
# DuckDB doesn't allow qualified columns inside of PIVOT expressions.
# See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62
exp.Pivot: transforms.preprocess([transforms.unqualify_columns]),
@@ -362,7 +394,9 @@ class DuckDB(Dialect):
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
exp.Timestamp: no_timestamp_sql,
- exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
+ exp.TimestampDiff: lambda self, e: self.func(
+ "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
+ ),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
@@ -373,11 +407,10 @@ class DuckDB(Dialect):
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: lambda self, e: self.func(
"DATE_DIFF",
- f"'{e.args.get('unit') or 'day'}'",
+ f"'{e.args.get('unit') or 'DAY'}'",
exp.cast(e.expression, "TIMESTAMP"),
exp.cast(e.this, "TIMESTAMP"),
),
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
@@ -410,6 +443,49 @@ class DuckDB(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
+ nano = expression.args.get("nano")
+ if nano is not None:
+ expression.set(
+ "sec", expression.args["sec"] + nano.pop() / exp.Literal.number(1000000000.0)
+ )
+
+ return rename_func("MAKE_TIME")(self, expression)
+
+ def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
+ sec = expression.args["sec"]
+
+ milli = expression.args.get("milli")
+ if milli is not None:
+ sec += milli.pop() / exp.Literal.number(1000.0)
+
+ nano = expression.args.get("nano")
+ if nano is not None:
+ sec += nano.pop() / exp.Literal.number(1000000000.0)
+
+ if milli or nano:
+ expression.set("sec", sec)
+
+ return rename_func("MAKE_TIMESTAMP")(self, expression)
+
+ def tablesample_sql(
+ self,
+ expression: exp.TableSample,
+ sep: str = " AS ",
+ tablesample_keyword: t.Optional[str] = None,
+ ) -> str:
+ if not isinstance(expression.parent, exp.Select):
+ # This sample clause only applies to a single source, not the entire resulting relation
+ tablesample_keyword = "TABLESAMPLE"
+
+ return super().tablesample_sql(
+ expression, sep=sep, tablesample_keyword=tablesample_keyword
+ )
+
+ def getpath_sql(self, expression: exp.GetPath) -> str:
+ expression = prepend_dollar_to_path(expression)
+ return f"{self.sql(expression, 'this')} -> {self.sql(expression, 'expression')}"
+
def interval_sql(self, expression: exp.Interval) -> str:
multiplier: t.Optional[int] = None
unit = expression.text("unit").lower()
@@ -420,11 +496,14 @@ class DuckDB(Dialect):
multiplier = 90
if multiplier:
- return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
+ return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('DAY')))})"
return super().interval_sql(expression)
- def tablesample_sql(
- self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
- ) -> str:
- return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)
+ def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
+ if isinstance(expression.parent, exp.UserDefinedFunction):
+ return self.sql(expression, "this")
+ return super().columndef_sql(expression, sep)
+
+ def placeholder_sql(self, expression: exp.Placeholder) -> str:
+ return f"${expression.name}" if expression.name else "?"
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 65c85bb..dffa41e 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -418,13 +418,13 @@ class Hive(Dialect):
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
- TABLESAMPLE_SIZE_IS_PERCENT = True
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
INDEX_ON = "ON TABLE"
EXTRACT_ALLOWS_QUOTES = False
NVL2_SUPPORTED = False
+ LAST_DAY_SUPPORTS_DATE_PART = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Insert,
@@ -523,7 +523,6 @@ class Hive(Dialect):
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
- exp.LastDateOfMonth: rename_func("LAST_DAY"),
exp.National: lambda self, e: self.national_sql(e, prefix=""),
exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 5fe3d82..21a9657 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
isnull_to_is_null,
- json_keyvalue_comma_sql,
locate_to_strposition,
max_or_greatest,
min_or_least,
@@ -21,6 +20,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
no_trycast_sql,
parse_date_delta_with_interval,
+ path_to_jsonpath,
rename_func,
strposition_to_locate_sql,
)
@@ -37,21 +37,21 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex
def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
expr = self.sql(expression, "this")
- unit = expression.text("unit")
+ unit = expression.text("unit").upper()
- if unit == "day":
+ if unit == "DAY":
return f"DATE({expr})"
- if unit == "week":
+ if unit == "WEEK":
concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
date_format = "%Y %u %w"
- elif unit == "month":
+ elif unit == "MONTH":
concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
date_format = "%Y %c %e"
- elif unit == "quarter":
+ elif unit == "QUARTER":
concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
date_format = "%Y %c %e"
- elif unit == "year":
+ elif unit == "YEAR":
concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e"
else:
@@ -292,9 +292,15 @@ class MySQL(Dialect):
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
+ "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"ISNULL": isnull_to_is_null,
"LOCATE": locate_to_strposition,
+ "MAKETIME": exp.TimeFromParts.from_arg_list,
+ "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"MONTHNAME": lambda args: exp.TimeToStr(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
format=exp.Literal.string("%B"),
@@ -308,11 +314,6 @@ class MySQL(Dialect):
)
+ 1
),
- "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
- "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
- "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
- "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
- "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"WEEK": lambda args: exp.Week(
this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1)
),
@@ -441,6 +442,7 @@ class MySQL(Dialect):
}
LOG_DEFAULTS_TO_LN = True
+ STRING_ALIASES = True
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
@@ -620,13 +622,15 @@ class MySQL(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
- NULL_ORDERING_SUPPORTED = False
+ NULL_ORDERING_SUPPORTED = None
JOIN_HINTS = False
TABLE_HINTS = True
DUPLICATE_KEY_UPDATE_WITH_SET = False
QUERY_HINT_SEP = " "
VALUES_AS_TABLE = False
NVL2_SUPPORTED = False
+ LAST_DAY_SUPPORTS_DATE_PART = False
+ JSON_KEY_VALUE_PAIR_SEP = ","
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -642,15 +646,16 @@ class MySQL(Dialect):
exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
+ exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
- exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Month: _remove_ts_or_ds_to_date(),
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}",
+ exp.ParseJSON: lambda self, e: self.sql(e, "this"),
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
[
@@ -665,6 +670,7 @@ class MySQL(Dialect):
exp.StrToTime: _str_to_date_sql,
exp.Stuff: rename_func("INSERT"),
exp.TableSample: no_tablesample_sql,
+ exp.TimeFromParts: rename_func("MAKETIME"),
exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
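[Annotation, not part of the patch] The relocated DAY/DAYOFWEEK/MONTH parsers wrap their argument in TsOrDsToDate so string dates are handled uniformly, and _remove_ts_or_ds_to_date strips the wrapper again on generation. A hedged round-trip sketch:

import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT DAYOFWEEK('2024-01-23')", read="mysql")
assert ast.find(exp.TsOrDsToDate) is not None  # wrapper added at parse time
print(ast.sql(dialect="mysql"))
# expected: SELECT DAYOFWEEK('2024-01-23')
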
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 51dbd53..6ad3718 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -53,6 +53,7 @@ def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar:
class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
LOCKING_READS_SUPPORTED = True
+ TABLESAMPLE_SIZE_IS_PERCENT = True
# See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -81,6 +82,7 @@ class Oracle(Dialect):
"WW": "%W", # Week of year (1-53)
"YY": "%y", # 15
"YYYY": "%Y", # 2015
+ "FF6": "%f", # only 6 digits are supported in python formats
}
class Parser(parser.Parser):
@@ -91,6 +93,8 @@ class Oracle(Dialect):
**parser.Parser.FUNCTIONS,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TO_CHAR": to_char,
+ "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "oracle"),
+ "TO_DATE": format_time_lambda(exp.StrToDate, "oracle"),
}
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@@ -107,6 +111,11 @@ class Oracle(Dialect):
"XMLTABLE": _parse_xml_table,
}
+ QUERY_MODIFIER_PARSERS = {
+ **parser.Parser.QUERY_MODIFIER_PARSERS,
+ TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()),
+ }
+
TYPE_LITERAL_PARSERS = {
exp.DataType.Type.DATE: lambda self, this, _: self.expression(
exp.DateStrToDate, this=this
@@ -153,8 +162,10 @@ class Oracle(Dialect):
COLUMN_JOIN_MARKS_SUPPORTED = True
DATA_TYPE_SPECIFIERS_ALLOWED = True
ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False
-
LIMIT_FETCH = "FETCH"
+ TABLESAMPLE_KEYWORDS = "SAMPLE"
+ LAST_DAY_SUPPORTS_DATE_PART = False
+ SUPPORTS_SELECT_INTO = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -186,6 +197,7 @@ class Oracle(Dialect):
]
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
@@ -201,6 +213,10 @@ class Oracle(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str:
+ this = expression.this
+ return self.func("CURRENT_TIMESTAMP", this) if this else "CURRENT_TIMESTAMP"
+
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
@@ -233,8 +249,10 @@ class Oracle(Dialect):
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NVARCHAR2": TokenType.NVARCHAR,
+ "ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY,
"SAMPLE": TokenType.TABLE_SAMPLE,
"START": TokenType.BEGIN,
+ "SYSDATE": TokenType.CURRENT_TIMESTAMP,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
}
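[Annotation, not part of the patch] TO_DATE/TO_TIMESTAMP are now parsed through Oracle's time-format mapping into StrToDate/StrToTime, and the matching StrToDate transform generates TO_DATE back. A hedged round-trip:

import sqlglot

print(
    sqlglot.transpile(
        "SELECT TO_DATE('2024-01-23', 'YYYY-MM-DD')", read="oracle", write="oracle"
    )[0]
)
# expected: SELECT TO_DATE('2024-01-23', 'YYYY-MM-DD')
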
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index e274877..1ca0a78 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -13,11 +13,12 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
max_or_greatest,
+ merge_without_target_sql,
min_or_least,
+ no_last_day_sql,
no_map_from_entries_sql,
no_paren_current_date_sql,
no_pivot_sql,
- no_tablesample_sql,
no_trycast_sql,
parse_timestamp_trunc,
rename_func,
@@ -27,7 +28,6 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
trim_sql,
ts_or_ds_add_cast,
- ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
@@ -188,36 +188,6 @@ def _to_timestamp(args: t.List) -> exp.Expression:
return format_time_lambda(exp.StrToTime, "postgres")(args)
-def _merge_sql(self: Postgres.Generator, expression: exp.Merge) -> str:
- def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
- """Remove table refs from columns in when statements."""
- if isinstance(expression, exp.Merge):
- alias = expression.this.args.get("alias")
-
- normalize = (
- lambda identifier: self.dialect.normalize_identifier(identifier).name
- if identifier
- else None
- )
-
- targets = {normalize(expression.this.this)}
-
- if alias:
- targets.add(normalize(alias.this))
-
- for when in expression.expressions:
- when.transform(
- lambda node: exp.column(node.this)
- if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
- else node,
- copy=False,
- )
-
- return expression
-
- return transforms.preprocess([_remove_target_from_merge])(self, expression)
-
-
class Postgres(Dialect):
INDEX_OFFSET = 1
TYPED_DIVISION = True
@@ -316,6 +286,8 @@ class Postgres(Dialect):
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
"GENERATE_SERIES": _generate_series,
+ "MAKE_TIME": exp.TimeFromParts.from_arg_list,
+ "MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"TO_TIMESTAMP": _to_timestamp,
@@ -387,12 +359,18 @@ class Postgres(Dialect):
class Generator(generator.Generator):
SINGLE_STRING_INTERVAL = True
+ RENAME_TABLE_WITH_DB = False
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
PARAMETER_TOKEN = "$"
+ TABLESAMPLE_SIZE_IS_ROWS = False
+ TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+ SUPPORTS_SELECT_INTO = True
+ # https://www.postgresql.org/docs/current/sql-createtable.html
+ SUPPORTS_UNLOGGED_TABLES = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -430,12 +408,13 @@ class Postgres(Dialect):
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
+ exp.LastDay: no_last_day_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Max: max_or_greatest,
exp.MapFromEntries: no_map_from_entries_sql,
exp.Min: min_or_least,
- exp.Merge: _merge_sql,
+ exp.Merge: merge_without_target_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.PercentileCont: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
@@ -458,16 +437,16 @@ class Postgres(Dialect):
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,
+ exp.TimeFromParts: rename_func("MAKE_TIME"),
+ exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.TableSample: no_tablesample_sql,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: _date_add_sql("+"),
exp.TsOrDsDiff: _date_diff_sql,
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.VariancePop: rename_func("VAR_POP"),
exp.Variance: rename_func("VAR_SAMP"),
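[Annotation, not part of the patch] _merge_sql moved to dialect.py as merge_without_target_sql so other dialects can reuse it; it strips the target table prefix from columns inside WHEN clauses, which Postgres requires. A hedged sketch:

import sqlglot

sql = """
MERGE INTO target AS t USING source AS s ON t.id = s.id
WHEN MATCHED THEN UPDATE SET t.name = s.name
"""
print(sqlglot.transpile(sql, read="postgres", write="postgres")[0])
# expected: the SET clause comes out as "SET name = s.name", since Postgres
# rejects target-qualified columns inside MERGE ... WHEN.
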
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 360ab65..9b421e7 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_safe_divide_sql,
no_timestamp_sql,
+ path_to_jsonpath,
regexp_extract_sql,
rename_func,
right_to_substring_sql,
@@ -99,14 +100,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
expression = ts_or_ds_add_cast(expression)
- unit = exp.Literal.string(expression.text("unit") or "day")
+ unit = exp.Literal.string(expression.text("unit") or "DAY")
return self.func("DATE_ADD", unit, expression.expression, expression.this)
def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
this = exp.cast(expression.this, "TIMESTAMP")
expr = exp.cast(expression.expression, "TIMESTAMP")
- unit = exp.Literal.string(expression.text("unit") or "day")
+ unit = exp.Literal.string(expression.text("unit") or "DAY")
return self.func("DATE_DIFF", unit, expr, this)
@@ -138,13 +139,6 @@ def _from_unixtime(args: t.List) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
-def _parse_element_at(args: t.List) -> exp.Bracket:
- this = seq_get(args, 0)
- index = seq_get(args, 1)
- assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
- return exp.Bracket(this=this, expressions=[index], offset=1, safe=True)
-
-
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Table):
if isinstance(expression.this, exp.GenerateSeries):
@@ -175,15 +169,8 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str
timestamp = self.sql(expression, "this")
if scale in (None, exp.UnixToTime.SECONDS):
return rename_func("FROM_UNIXTIME")(self, expression)
- if scale == exp.UnixToTime.MILLIS:
- return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)"
- if scale == exp.UnixToTime.MICROS:
- return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)"
- if scale == exp.UnixToTime.NANOS:
- return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)"
- self.unsupported(f"Unsupported scale for timestamp: {scale}.")
- return ""
+ return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))"
def _to_int(expression: exp.Expression) -> exp.Expression:
@@ -215,6 +202,7 @@ class Presto(Dialect):
STRICT_STRING_CONCAT = True
SUPPORTS_SEMI_ANTI_JOIN = False
TYPED_DIVISION = True
+ TABLESAMPLE_SIZE_IS_PERCENT = True
# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
@@ -258,7 +246,9 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
- "ELEMENT_AT": _parse_element_at,
+ "ELEMENT_AT": lambda args: exp.Bracket(
+ this=seq_get(args, 0), expressions=[seq_get(args, 1)], offset=1, safe=True
+ ),
"FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"FROM_UTF8": lambda args: exp.Decode(
@@ -344,20 +334,20 @@ class Presto(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD",
- exp.Literal.string(e.text("unit") or "day"),
+ exp.Literal.string(e.text("unit") or "DAY"),
_to_int(
e.expression,
),
e.this,
),
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+ "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
- exp.Literal.string(e.text("unit") or "day"),
+ exp.Literal.string(e.text("unit") or "DAY"),
_to_int(e.expression * -1),
e.this,
),
@@ -366,6 +356,7 @@ class Presto(Dialect):
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
+ exp.GetPath: path_to_jsonpath(),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
@@ -376,6 +367,7 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Last: _first_last_sql,
+ exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
exp.Lateral: _explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
@@ -446,7 +438,7 @@ class Presto(Dialect):
return super().bracket_sql(expression)
def struct_sql(self, expression: exp.Struct) -> str:
- if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions):
+ if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions):
self.unsupported("Struct with key-value definitions is unsupported.")
return self.function_fallback_sql(expression)
@@ -454,8 +446,8 @@ class Presto(Dialect):
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
- if expression.this and unit.lower().startswith("week"):
- return f"({expression.this.name} * INTERVAL '7' day)"
+ if expression.this and unit.startswith("WEEK"):
+ return f"({expression.this.name} * INTERVAL '7' DAY)"
return super().interval_sql(expression)
def transaction_sql(self, expression: exp.Transaction) -> str:
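[Annotation, not part of the patch] ELEMENT_AT is now parsed inline into a one-indexed, "safe" Bracket node (replacing _parse_element_at), and the generator is expected to turn such brackets back into ELEMENT_AT. A hedged round-trip:

import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT ELEMENT_AT(arr, 1)", read="presto")
bracket = ast.find(exp.Bracket)
assert bracket is not None and bracket.args.get("safe") and bracket.args.get("offset") == 1
print(ast.sql(dialect="presto"))  # expected: SELECT ELEMENT_AT(arr, 1)
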
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 7382e7c..7194d81 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -9,8 +9,8 @@ from sqlglot.dialects.dialect import (
concat_ws_to_dpipe_sql,
date_delta_sql,
generatedasidentitycolumnconstraint_sql,
+ no_tablesample_sql,
rename_func,
- ts_or_ds_to_date_sql,
)
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
@@ -123,6 +123,27 @@ class Redshift(Postgres):
self._retreat(index)
return None
+ def _parse_query_modifiers(
+ self, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ this = super()._parse_query_modifiers(this)
+
+ if this:
+ refs = set()
+
+ for i, join in enumerate(this.args.get("joins", [])):
+ refs.add(
+ (
+ this.args["from"] if i == 0 else this.args["joins"][i - 1]
+ ).alias_or_name.lower()
+ )
+ table = join.this
+
+ if isinstance(table, exp.Table):
+ if table.parts[0].name.lower() in refs:
+ table.replace(table.to_column())
+ return this
+
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []
@@ -144,11 +165,11 @@ class Redshift(Postgres):
class Generator(Postgres.Generator):
LOCKING_READS_SUPPORTED = False
- RENAME_TABLE_WITH_DB = False
QUERY_HINTS = False
VALUES_AS_TABLE = False
TZ_TO_WITH_TIME_ZONE = True
NVL2_SUPPORTED = True
+ LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@@ -184,9 +205,9 @@ class Redshift(Postgres):
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
+ exp.TableSample: no_tablesample_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"),
}
# Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots
@@ -198,6 +219,9 @@ class Redshift(Postgres):
# Redshift supports ANY_VALUE(..)
TRANSFORMS.pop(exp.AnyValue)
+ # Redshift supports LAST_DAY(..)
+ TRANSFORMS.pop(exp.LastDay)
+
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
def with_properties(self, properties: exp.Properties) -> str:
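[Annotation, not part of the patch] The new _parse_query_modifiers override rewrites a join target such as t.attrs into a column reference when t is an earlier FROM/JOIN alias, which is how Redshift's comma-join syntax navigates SUPER values. A hedged sketch; the exact output shape is an assumption.

import sqlglot

ast = sqlglot.parse_one("SELECT a FROM t, t.attrs AS a", read="redshift")
# Per the hunk above, the second join target is replaced by a column node
# (table.to_column()) because "t" is already a known relation alias.
print(ast.sql(dialect="redshift"))
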
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 8925181..a8e4a42 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -19,7 +19,6 @@ from sqlglot.dialects.dialect import (
rename_func,
timestamptrunc_sql,
timestrtotime_sql,
- ts_or_ds_to_date_sql,
var_map_sql,
)
from sqlglot.expressions import Literal
@@ -40,21 +39,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
if second_arg.is_string:
# case: <string_expr> [ , <format> ]
return format_time_lambda(exp.StrToTime, "snowflake")(args)
-
- # case: <numeric_expr> [ , <scale> ]
- if second_arg.name not in ["0", "3", "9"]:
- raise ValueError(
- f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
- )
-
- if second_arg.name == "0":
- timescale = exp.UnixToTime.SECONDS
- elif second_arg.name == "3":
- timescale = exp.UnixToTime.MILLIS
- elif second_arg.name == "9":
- timescale = exp.UnixToTime.NANOS
-
- return exp.UnixToTime(this=first_arg, scale=timescale)
+ return exp.UnixToTime(this=first_arg, scale=second_arg)
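With the scale validation removed, any second argument is now forwarded to UnixToTime untouched. A sketch of the parse result (attribute access assumed):

    import sqlglot
    from sqlglot import exp

    node = sqlglot.parse_one("SELECT TO_TIMESTAMP(1659981729, 9)", read="snowflake")
    unix = node.find(exp.UnixToTime)
    print(unix.args["scale"])  # the literal 9, passed through as-is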
from sqlglot.optimizer.simplify import simplify_literals
@@ -91,23 +76,9 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
def _parse_datediff(args: t.List) -> exp.DateDiff:
- return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
-
-
-def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
- scale = expression.args.get("scale")
- timestamp = self.sql(expression, "this")
- if scale in (None, exp.UnixToTime.SECONDS):
- return f"TO_TIMESTAMP({timestamp})"
- if scale == exp.UnixToTime.MILLIS:
- return f"TO_TIMESTAMP({timestamp}, 3)"
- if scale == exp.UnixToTime.MICROS:
- return f"TO_TIMESTAMP({timestamp} / 1000, 3)"
- if scale == exp.UnixToTime.NANOS:
- return f"TO_TIMESTAMP({timestamp}, 9)"
-
- self.unsupported(f"Unsupported scale for timestamp: {scale}.")
- return ""
+ return exp.DateDiff(
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
+ )
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
@@ -120,14 +91,15 @@ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
self._match(TokenType.COMMA)
expression = self._parse_bitwise()
-
+ this = _map_date_part(this)
name = this.name.upper()
+
if name.startswith("EPOCH"):
- if name.startswith("EPOCH_MILLISECOND"):
+ if name == "EPOCH_MILLISECOND":
scale = 10**3
- elif name.startswith("EPOCH_MICROSECOND"):
+ elif name == "EPOCH_MICROSECOND":
scale = 10**6
- elif name.startswith("EPOCH_NANOSECOND"):
+ elif name == "EPOCH_NANOSECOND":
scale = 10**9
else:
scale = None
@@ -204,6 +176,159 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser]
return _parse
+DATE_PART_MAPPING = {
+ "Y": "YEAR",
+ "YY": "YEAR",
+ "YYY": "YEAR",
+ "YYYY": "YEAR",
+ "YR": "YEAR",
+ "YEARS": "YEAR",
+ "YRS": "YEAR",
+ "MM": "MONTH",
+ "MON": "MONTH",
+ "MONS": "MONTH",
+ "MONTHS": "MONTH",
+ "D": "DAY",
+ "DD": "DAY",
+ "DAYS": "DAY",
+ "DAYOFMONTH": "DAY",
+ "WEEKDAY": "DAYOFWEEK",
+ "DOW": "DAYOFWEEK",
+ "DW": "DAYOFWEEK",
+ "WEEKDAY_ISO": "DAYOFWEEKISO",
+ "DOW_ISO": "DAYOFWEEKISO",
+ "DW_ISO": "DAYOFWEEKISO",
+ "YEARDAY": "DAYOFYEAR",
+ "DOY": "DAYOFYEAR",
+ "DY": "DAYOFYEAR",
+ "W": "WEEK",
+ "WK": "WEEK",
+ "WEEKOFYEAR": "WEEK",
+ "WOY": "WEEK",
+ "WY": "WEEK",
+ "WEEK_ISO": "WEEKISO",
+ "WEEKOFYEARISO": "WEEKISO",
+ "WEEKOFYEAR_ISO": "WEEKISO",
+ "Q": "QUARTER",
+ "QTR": "QUARTER",
+ "QTRS": "QUARTER",
+ "QUARTERS": "QUARTER",
+ "H": "HOUR",
+ "HH": "HOUR",
+ "HR": "HOUR",
+ "HOURS": "HOUR",
+ "HRS": "HOUR",
+ "M": "MINUTE",
+ "MI": "MINUTE",
+ "MIN": "MINUTE",
+ "MINUTES": "MINUTE",
+ "MINS": "MINUTE",
+ "S": "SECOND",
+ "SEC": "SECOND",
+ "SECONDS": "SECOND",
+ "SECS": "SECOND",
+ "MS": "MILLISECOND",
+ "MSEC": "MILLISECOND",
+ "MILLISECONDS": "MILLISECOND",
+ "US": "MICROSECOND",
+ "USEC": "MICROSECOND",
+ "MICROSECONDS": "MICROSECOND",
+ "NS": "NANOSECOND",
+ "NSEC": "NANOSECOND",
+ "NANOSEC": "NANOSECOND",
+ "NSECOND": "NANOSECOND",
+ "NSECONDS": "NANOSECOND",
+ "NANOSECS": "NANOSECOND",
+ "NSECONDS": "NANOSECOND",
+ "EPOCH": "EPOCH_SECOND",
+ "EPOCH_SECONDS": "EPOCH_SECOND",
+ "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
+ "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
+ "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
+ "TZH": "TIMEZONE_HOUR",
+ "TZM": "TIMEZONE_MINUTE",
+}
+
+
+@t.overload
+def _map_date_part(part: exp.Expression) -> exp.Var:
+ pass
+
+
+@t.overload
+def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ pass
+
+
+def _map_date_part(part):
+ mapped = DATE_PART_MAPPING.get(part.name.upper()) if part else None
+ return exp.var(mapped) if mapped else part
+
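Together, DATE_PART_MAPPING and _map_date_part canonicalize Snowflake's many date-part spellings before they reach DATEADD, DATEDIFF, DATE_TRUNC and LAST_DAY. A round-trip sketch (output illustrative):

    import sqlglot

    # "yy" is normalized to YEAR on its way through the Snowflake parser:
    print(sqlglot.transpile("SELECT DATEADD(yy, 1, d)", read="snowflake", write="snowflake")[0])
    # e.g. SELECT DATEADD(YEAR, 1, d)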
+
+def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
+ trunc = date_trunc_to_time(args)
+ trunc.set("unit", _map_date_part(trunc.args["unit"]))
+ return trunc
+
+
+def _parse_colon_get_path(
+ self: parser.Parser, this: t.Optional[exp.Expression]
+) -> t.Optional[exp.Expression]:
+ while True:
+ path = self._parse_bitwise()
+
+ # The cast :: operator has a lower precedence than the extraction operator :, so
+ # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
+ if isinstance(path, exp.Cast):
+ target_type = path.to
+ path = path.this
+ else:
+ target_type = None
+
+ if isinstance(path, exp.Expression):
+ path = exp.Literal.string(path.sql(dialect="snowflake"))
+
+ # The extraction operator : is left-associative
+ this = self.expression(exp.GetPath, this=this, expression=path)
+
+ if target_type:
+ this = exp.cast(this, target_type)
+
+ if not self._match(TokenType.COLON):
+ break
+
+ if self._match_set(self.RANGE_PARSERS):
+ this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this
+
+ return this
+
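Because the cast operator :: binds after the extraction operator :, the loop above strips a trailing Cast off each parsed path and re-applies it around the accumulated GET_PATH, so the cast targets the extracted value rather than the path literal. A sketch (node shapes assumed):

    import sqlglot
    from sqlglot import exp

    ast = sqlglot.parse_one("SELECT col:a.b::VARCHAR FROM t", read="snowflake")
    cast = ast.find(exp.Cast)
    print(type(cast.this).__name__)  # the extraction sits inside the cast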
+
+def _parse_timestamp_from_parts(args: t.List) -> exp.Func:
+ if len(args) == 2:
+ # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
+ # so we parse this into Anonymous for now instead of introducing complexity
+ return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
+
+ return exp.TimestampFromParts.from_arg_list(args)
+
+
+def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
+ """
+ Snowflake doesn't allow columns referenced in UNPIVOT to be qualified,
+ so we need to unqualify them.
+
+ Example:
+ >>> from sqlglot import parse_one
+ >>> expr = parse_one("SELECT * FROM m_sales UNPIVOT(sales FOR month IN (m_sales.jan, feb, mar, april))")
+ >>> print(_unqualify_unpivot_columns(expr).sql(dialect="snowflake"))
+ SELECT * FROM m_sales UNPIVOT(sales FOR month IN (jan, feb, mar, april))
+ """
+ if isinstance(expression, exp.Pivot) and expression.unpivot:
+ expression = transforms.unqualify_columns(expression)
+
+ return expression
+
+
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -211,6 +336,8 @@ class Snowflake(Dialect):
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False
+ PREFER_CTE_ALIAS_COLUMN = True
+ TABLESAMPLE_SIZE_IS_PERCENT = True
TIME_MAPPING = {
"YYYY": "%Y",
@@ -276,14 +403,19 @@ class Snowflake(Dialect):
"BIT_XOR": binary_from_function(exp.BitwiseXor),
"BOOLXOR": binary_from_function(exp.Xor),
"CONVERT_TIMEZONE": _parse_convert_timezone,
- "DATE_TRUNC": date_trunc_to_time,
+ "DATE_TRUNC": _date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
- this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=_map_date_part(seq_get(args, 0)),
),
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"FLATTEN": exp.Explode.from_arg_list,
"IFF": exp.If.from_arg_list,
+ "LAST_DAY": lambda args: exp.LastDay(
+ this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
+ ),
"LISTAGG": exp.GroupConcat.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": _parse_object_construct,
@@ -293,6 +425,8 @@ class Snowflake(Dialect):
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TIMEDIFF": _parse_datediff,
"TIMESTAMPDIFF": _parse_datediff,
+ "TIMESTAMPFROMPARTS": _parse_timestamp_from_parts,
+ "TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts,
"TO_TIMESTAMP": _parse_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _zeroifnull_to_if,
@@ -301,22 +435,17 @@ class Snowflake(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
+ "OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(),
}
FUNCTION_PARSERS.pop("TRIM")
- COLUMN_OPERATORS = {
- **parser.Parser.COLUMN_OPERATORS,
- TokenType.COLON: lambda self, this, path: self.expression(
- exp.Bracket, this=this, expressions=[path]
- ),
- }
-
TIMESTAMPS = parser.Parser.TIMESTAMPS - {TokenType.TIME}
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
+ TokenType.COLON: _parse_colon_get_path,
}
ALTER_PARSERS = {
@@ -344,6 +473,7 @@ class Snowflake(Dialect):
SHOW_PARSERS = {
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
+ "COLUMNS": _show_parser("COLUMNS"),
}
STAGED_FILE_SINGLE_TOKENS = {
@@ -351,8 +481,18 @@ class Snowflake(Dialect):
TokenType.MOD,
TokenType.SLASH,
}
+
FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]
+ def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
+ if is_map:
+ # Keys are strings in Snowflake's objects, see also:
+ # - https://docs.snowflake.com/en/sql-reference/data-types-semistructured
+ # - https://docs.snowflake.com/en/sql-reference/functions/object_construct
+ return self._parse_slice(self._parse_string())
+
+ return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))
+
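Since keys in Snowflake's semi-structured objects are strings, the map branch above only accepts string keys. A sketch, assuming '{...}' literals route through this hook:

    import sqlglot

    # String keys parse; the value side still allows arbitrary expressions:
    print(repr(sqlglot.parse_one("SELECT {'a': 1 + 2}", read="snowflake")))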
def _parse_lateral(self) -> t.Optional[exp.Lateral]:
lateral = super()._parse_lateral()
if not lateral:
@@ -440,6 +580,8 @@ class Snowflake(Dialect):
scope = None
scope_kind = None
+ like = self._parse_string() if self._match(TokenType.LIKE) else None
+
if self._match(TokenType.IN):
if self._match_text_seq("ACCOUNT"):
scope_kind = "ACCOUNT"
@@ -451,7 +593,9 @@ class Snowflake(Dialect):
scope_kind = "TABLE"
scope = self._parse_table()
- return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind)
+ return self.expression(
+ exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind
+ )
def _parse_alter_table_swap(self) -> exp.SwapTable:
self._match_text_seq("WITH")
@@ -489,8 +633,12 @@ class Snowflake(Dialect):
"MINUS": TokenType.EXCEPT,
"NCHAR VARYING": TokenType.VARCHAR,
"PUT": TokenType.COMMAND,
+ "REMOVE": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
+ "RM": TokenType.COMMAND,
"SAMPLE": TokenType.TABLE_SAMPLE,
+ "SQL_DOUBLE": TokenType.DOUBLE,
+ "SQL_VARCHAR": TokenType.VARCHAR,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@@ -518,6 +666,8 @@ class Snowflake(Dialect):
SUPPORTS_TABLE_COPY = False
COLLATE_IS_FUNC = True
LIMIT_ONLY_LITERALS = True
+ JSON_KEY_VALUE_PAIR_SEP = ","
+ INSERT_OVERWRITE = " OVERWRITE INTO"
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -545,6 +695,8 @@ class Snowflake(Dialect):
),
exp.GroupConcat: rename_func("LISTAGG"),
exp.If: if_sql(name="IFF", false_value="NULL"),
+ exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]",
+ exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@@ -557,6 +709,7 @@ class Snowflake(Dialect):
exp.PercentileDisc: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
),
+ exp.Pivot: transforms.preprocess([_unqualify_unpivot_columns]),
exp.RegexpILike: _regexpilike_sql,
exp.Rand: rename_func("RANDOM"),
exp.Select: transforms.preprocess(
@@ -578,6 +731,9 @@ class Snowflake(Dialect):
*(arg for expression in e.expressions for arg in expression.flatten()),
),
exp.Stuff: rename_func("INSERT"),
+ exp.TimestampDiff: lambda self, e: self.func(
+ "TIMESTAMPDIFF", e.unit, e.expression, e.this
+ ),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func(
@@ -589,8 +745,7 @@ class Snowflake(Dialect):
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
- exp.UnixToTime: _unix_to_time_sql,
+ exp.UnixToTime: rename_func("TO_TIMESTAMP"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.Xor: rename_func("BOOLXOR"),
@@ -612,6 +767,14 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
+ milli = expression.args.get("milli")
+ if milli is not None:
+ milli_to_nano = milli.pop() * exp.Literal.number(1000000)
+ expression.set("nano", milli_to_nano)
+
+ return rename_func("TIMESTAMP_FROM_PARTS")(self, expression)
+
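Snowflake's TIMESTAMP_FROM_PARTS takes nanoseconds rather than milliseconds, hence the milli-to-nano rewrite above. A sketch of transpiling T-SQL's DATETIMEFROMPARTS (output illustrative):

    import sqlglot

    # The 7th DATETIMEFROMPARTS argument is milliseconds; Snowflake wants
    # nanoseconds, so it is multiplied by 1000000 on the way out:
    print(sqlglot.transpile(
        "SELECT DATETIMEFROMPARTS(2024, 1, 23, 5, 6, 14, 500)",
        read="tsql",
        write="snowflake",
    )[0])
    # e.g. SELECT TIMESTAMP_FROM_PARTS(2024, 1, 23, 5, 6, 14, 500 * 1000000)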
def trycast_sql(self, expression: exp.TryCast) -> str:
value = expression.this
@@ -657,6 +820,9 @@ class Snowflake(Dialect):
return f"{explode}{alias}"
def show_sql(self, expression: exp.Show) -> str:
+ like = self.sql(expression, "like")
+ like = f" LIKE {like}" if like else ""
+
scope = self.sql(expression, "scope")
scope = f" {scope}" if scope else ""
@@ -664,7 +830,7 @@ class Snowflake(Dialect):
if scope_kind:
scope_kind = f" IN {scope_kind}"
- return f"SHOW {expression.name}{scope_kind}{scope}"
+ return f"SHOW {expression.name}{like}{scope_kind}{scope}"
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index aa09f53..e27ba18 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -48,11 +48,8 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
- if scale == exp.UnixToTime.NANOS:
- return f"TIMESTAMP_SECONDS({timestamp} / 1000000000)"
- self.unsupported(f"Unsupported scale for timestamp: {scale}.")
- return ""
+ return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))"
def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
@@ -93,12 +90,7 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))
"""
if isinstance(expression, exp.Pivot):
- expression.args["field"].transform(
- lambda node: exp.column(node.output_name, quoted=node.this.quoted)
- if isinstance(node, exp.Column)
- else node,
- copy=False,
- )
+ expression.set("field", transforms.unqualify_columns(expression.args["field"]))
return expression
@@ -234,7 +226,7 @@ class Spark2(Hive):
def struct_sql(self, expression: exp.Struct) -> str:
args = []
for arg in expression.expressions:
- if isinstance(arg, self.KEY_VALUE_DEFINITONS):
+ if isinstance(arg, self.KEY_VALUE_DEFINITIONS):
if isinstance(arg, exp.Bracket):
args.append(exp.alias_(arg.this, arg.expressions[0].name))
else:
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 9bac51c..244a96e 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -78,6 +78,7 @@ class SQLite(Dialect):
**parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
+ STRING_ALIASES = True
class Generator(generator.Generator):
JOIN_HINTS = False
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 0ccc567..6dbad15 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -175,6 +175,8 @@ class Teradata(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
+ TABLESAMPLE_KEYWORDS = "SAMPLE"
+ LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -214,7 +216,10 @@ class Teradata(Dialect):
return self.cast_sql(expression, safe_prefix="TRY")
def tablesample_sql(
- self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
+ self,
+ expression: exp.TableSample,
+ sep: str = " AS ",
+ tablesample_keyword: t.Optional[str] = None,
) -> str:
return f"{self.sql(expression, 'this')} SAMPLE {self.expressions(expression)}"
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index 3682ac7..eddb70a 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from sqlglot import exp
+from sqlglot.dialects.dialect import merge_without_target_sql
from sqlglot.dialects.presto import Presto
@@ -11,6 +12,7 @@ class Trino(Presto):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
+ exp.Merge: merge_without_target_sql,
}
class Tokenizer(Presto.Tokenizer):
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 165a703..b9c347c 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -14,9 +14,10 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
parse_date_delta,
+ path_to_jsonpath,
rename_func,
timestrtotime_sql,
- ts_or_ds_to_date_sql,
+ trim_sql,
)
from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
@@ -105,18 +106,17 @@ def _parse_format(args: t.List) -> exp.Expression:
return exp.TimeToStr(this=this, format=fmt, culture=culture)
-def _parse_eomonth(args: t.List) -> exp.Expression:
- date = seq_get(args, 0)
+def _parse_eomonth(args: t.List) -> exp.LastDay:
+ date = exp.TsOrDsToDate(this=seq_get(args, 0))
month_lag = seq_get(args, 1)
- unit = DATE_DELTA_INTERVAL.get("month")
if month_lag is None:
- return exp.LastDateOfMonth(this=date)
+ this: exp.Expression = date
+ else:
+ unit = DATE_DELTA_INTERVAL.get("month")
+ this = exp.DateAdd(this=date, expression=month_lag, unit=unit and exp.var(unit))
- # Remove month lag argument in parser as its compared with the number of arguments of the resulting class
- args.remove(month_lag)
-
- return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
+ return exp.LastDay(this=this)
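EOMONTH is now parsed into the generic LastDay expression, with the optional month lag folded into a DateAdd first, so other dialects can render it with their own LAST_DAY support. Sketch (outputs illustrative):

    import sqlglot

    print(sqlglot.transpile("SELECT EOMONTH(d)", read="tsql", write="snowflake")[0])
    print(sqlglot.transpile("SELECT EOMONTH(d, 2)", read="tsql", write="snowflake")[0])
    # e.g. LAST_DAY(CAST(d AS DATE)) and LAST_DAY(DATEADD(MONTH, 2, CAST(d AS DATE)))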
def _parse_hashbytes(args: t.List) -> exp.Expression:
@@ -137,26 +137,27 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
return exp.func("HASHBYTES", *args)
-DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"}
+DATEPART_ONLY_FORMATS = {"DW", "HOUR", "QUARTER"}
def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
- fmt = (
- expression.args["format"]
- if isinstance(expression, exp.NumberToStr)
- else exp.Literal.string(
- format_time(
- expression.text("format"),
- t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING),
- )
- )
- )
+ fmt = expression.args["format"]
- # There is no format for "quarter"
- if fmt.name.lower() in DATEPART_ONLY_FORMATS:
- return self.func("DATEPART", fmt.name, expression.this)
+ if not isinstance(expression, exp.NumberToStr):
+ if fmt.is_string:
+ mapped_fmt = format_time(fmt.name, TSQL.INVERSE_TIME_MAPPING)
- return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))
+ name = (mapped_fmt or "").upper()
+ if name in DATEPART_ONLY_FORMATS:
+ return self.func("DATEPART", name, expression.this)
+
+ fmt_sql = self.sql(exp.Literal.string(mapped_fmt))
+ else:
+ fmt_sql = self.format_time(expression) or self.sql(fmt)
+ else:
+ fmt_sql = self.sql(fmt)
+
+ return self.func("FORMAT", expression.this, fmt_sql, expression.args.get("culture"))
def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
@@ -239,6 +240,30 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
return expression
+# https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax
+def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts:
+ return exp.TimestampFromParts(
+ year=seq_get(args, 0),
+ month=seq_get(args, 1),
+ day=seq_get(args, 2),
+ hour=seq_get(args, 3),
+ min=seq_get(args, 4),
+ sec=seq_get(args, 5),
+ milli=seq_get(args, 6),
+ )
+
+
+# https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax
+def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
+ return exp.TimeFromParts(
+ hour=seq_get(args, 0),
+ min=seq_get(args, 1),
+ sec=seq_get(args, 2),
+ fractions=seq_get(args, 3),
+ precision=seq_get(args, 4),
+ )
+
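Both helpers bind T-SQL's positional arguments to the named args of the richer expressions, which the generator later re-emits with defaults filled in. Parse-side sketch:

    import sqlglot
    from sqlglot import exp

    node = sqlglot.parse_one("SELECT TIMEFROMPARTS(23, 59, 59, 5, 1)", read="tsql")
    tfp = node.find(exp.TimeFromParts)
    print(tfp.args["hour"], tfp.args["fractions"])  # 23 and 5, per the mapping above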
+
class TSQL(Dialect):
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
@@ -352,7 +377,7 @@ class TSQL(Dialect):
}
class Tokenizer(tokens.Tokenizer):
- IDENTIFIERS = ['"', ("[", "]")]
+ IDENTIFIERS = [("[", "]"), '"']
QUOTES = ["'", '"']
HEX_STRINGS = [("0x", ""), ("0X", "")]
VAR_SINGLE_TOKENS = {"@", "$", "#"}
@@ -362,6 +387,7 @@ class TSQL(Dialect):
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"DECLARE": TokenType.COMMAND,
+ "EXEC": TokenType.COMMAND,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"NTEXT": TokenType.TEXT,
@@ -397,6 +423,7 @@ class TSQL(Dialect):
"DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
+ "DATETIMEFROMPARTS": _parse_datetimefromparts,
"EOMONTH": _parse_eomonth,
"FORMAT": _parse_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
@@ -411,6 +438,7 @@ class TSQL(Dialect):
"SUSER_NAME": exp.CurrentUser.from_arg_list,
"SUSER_SNAME": exp.CurrentUser.from_arg_list,
"SYSTEM_USER": exp.CurrentUser.from_arg_list,
+ "TIMEFROMPARTS": _parse_timefromparts,
}
JOIN_HINTS = {
@@ -440,6 +468,7 @@ class TSQL(Dialect):
LOG_DEFAULTS_TO_LN = True
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
+ STRING_ALIASES = True
def _parse_projections(self) -> t.List[exp.Expression]:
"""
@@ -630,8 +659,10 @@ class TSQL(Dialect):
COMPUTED_COLUMN_WITH_TYPE = False
CTE_RECURSIVE_KEYWORD_REQUIRED = False
ENSURE_BOOLS = True
- NULL_ORDERING_SUPPORTED = False
+ NULL_ORDERING_SUPPORTED = None
SUPPORTS_SINGLE_ARG_CONCAT = False
+ TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+ SUPPORTS_SELECT_INTO = True
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
@@ -667,13 +698,16 @@ class TSQL(Dialect):
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
+ exp.GetPath: path_to_jsonpath("JSON_VALUE"),
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
+ exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
exp.Length: rename_func("LEN"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
exp.NumberToStr: _format_sql,
+ exp.ParseJSON: lambda self, e: self.sql(e, "this"),
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
@@ -689,9 +723,9 @@ class TSQL(Dialect):
exp.TemporaryProperty: lambda self, e: "",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: _format_sql,
+ exp.Trim: trim_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"),
}
TRANSFORMS.pop(exp.ReturnsProperty)
@@ -701,6 +735,46 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def lateral_op(self, expression: exp.Lateral) -> str:
+ cross_apply = expression.args.get("cross_apply")
+ if cross_apply is True:
+ return "CROSS APPLY"
+ if cross_apply is False:
+ return "OUTER APPLY"
+
+ # TODO: perhaps we can check if the parent is a Join and transpile it appropriately
+ self.unsupported("LATERAL clause is not supported.")
+ return "LATERAL"
+
+ def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
+ nano = expression.args.get("nano")
+ if nano is not None:
+ nano.pop()
+ self.unsupported("Specifying nanoseconds is not supported in TIMEFROMPARTS.")
+
+ if expression.args.get("fractions") is None:
+ expression.set("fractions", exp.Literal.number(0))
+ if expression.args.get("precision") is None:
+ expression.set("precision", exp.Literal.number(0))
+
+ return rename_func("TIMEFROMPARTS")(self, expression)
+
+ def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
+ zone = expression.args.get("zone")
+ if zone is not None:
+ zone.pop()
+ self.unsupported("Time zone is not supported in DATETIMEFROMPARTS.")
+
+ nano = expression.args.get("nano")
+ if nano is not None:
+ nano.pop()
+ self.unsupported("Specifying nanoseconds is not supported in DATETIMEFROMPARTS.")
+
+ if expression.args.get("milli") is None:
+ expression.set("milli", exp.Literal.number(0))
+
+ return rename_func("DATETIMEFROMPARTS")(self, expression)
+
def set_operation(self, expression: exp.Union, op: str) -> str:
limit = expression.args.get("limit")
if limit: