author    Daniel Baumann <daniel.baumann@progress-linux.org>    2024-01-23 05:06:14 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org>    2024-01-23 05:06:14 +0000
commit    38e6461a8afbd7cb83709ddb998f03d40ba87755 (patch)
tree      64b68a893a3b946111b9cab69503f83ca233c335 /sqlglot
parent    Releasing debian version 20.4.0-1. (diff)
Merging upstream version 20.9.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dataframe/sql/functions.py        6
-rw-r--r--  sqlglot/dialects/bigquery.py             91
-rw-r--r--  sqlglot/dialects/clickhouse.py          201
-rw-r--r--  sqlglot/dialects/databricks.py            2
-rw-r--r--  sqlglot/dialects/dialect.py             119
-rw-r--r--  sqlglot/dialects/doris.py                 9
-rw-r--r--  sqlglot/dialects/drill.py                 3
-rw-r--r--  sqlglot/dialects/duckdb.py              121
-rw-r--r--  sqlglot/dialects/hive.py                  3
-rw-r--r--  sqlglot/dialects/mysql.py                34
-rw-r--r--  sqlglot/dialects/oracle.py               20
-rw-r--r--  sqlglot/dialects/postgres.py             49
-rw-r--r--  sqlglot/dialects/presto.py               40
-rw-r--r--  sqlglot/dialects/redshift.py             30
-rw-r--r--  sqlglot/dialects/snowflake.py           266
-rw-r--r--  sqlglot/dialects/spark2.py               14
-rw-r--r--  sqlglot/dialects/sqlite.py                1
-rw-r--r--  sqlglot/dialects/teradata.py              7
-rw-r--r--  sqlglot/dialects/trino.py                 2
-rw-r--r--  sqlglot/dialects/tsql.py                128
-rw-r--r--  sqlglot/executor/env.py                   6
-rw-r--r--  sqlglot/expressions.py                  346
-rw-r--r--  sqlglot/generator.py                    259
-rw-r--r--  sqlglot/lineage.py                        7
-rw-r--r--  sqlglot/optimizer/annotate_types.py      24
-rw-r--r--  sqlglot/optimizer/pushdown_predicates.py 50
-rw-r--r--  sqlglot/optimizer/pushdown_projections.py 7
-rw-r--r--  sqlglot/optimizer/qualify.py             16
-rw-r--r--  sqlglot/optimizer/qualify_columns.py    170
-rw-r--r--  sqlglot/optimizer/qualify_tables.py       6
-rw-r--r--  sqlglot/optimizer/scope.py               16
-rw-r--r--  sqlglot/optimizer/simplify.py            23
-rw-r--r--  sqlglot/parser.py                       315
-rw-r--r--  sqlglot/schema.py                         5
-rw-r--r--  sqlglot/tokens.py                         4
-rw-r--r--  sqlglot/transforms.py                     2
36 files changed, 1843 insertions, 559 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 6658287..141a302 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -574,13 +574,13 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
return Column.invoke_expression_over_column(
- col, expression.DateAdd, expression=days, unit=expression.Var(this="day")
+ col, expression.DateAdd, expression=days, unit=expression.Var(this="DAY")
)
def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column:
return Column.invoke_expression_over_column(
- col, expression.DateSub, expression=days, unit=expression.Var(this="day")
+ col, expression.DateSub, expression=days, unit=expression.Var(this="DAY")
)
@@ -635,7 +635,7 @@ def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
def last_day(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "LAST_DAY")
+ return Column.invoke_expression_over_column(col, expression.LastDay)
def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
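
A quick sketch of what the last_day change above enables (my example, not part of the diff): building a typed exp.LastDay node instead of an anonymous call lets each dialect render it natively or fall back; the outputs in the comments are expectations, not verified results.

import sqlglot.expressions as exp

node = exp.LastDay(this=exp.column("dt"))
print(node.sql(dialect="presto"))    # expected to use LAST_DAY_OF_MONTH per the Presto hunk further down
print(node.sql(dialect="postgres"))  # expected to expand via the no_last_day_sql date arithmetic fallback
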
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 7a573e7..0151e6c 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -16,20 +16,22 @@ from sqlglot.dialects.dialect import (
format_time_lambda,
if_sql,
inline_array_sql,
- json_keyvalue_comma_sql,
max_or_greatest,
min_or_least,
no_ilike_sql,
parse_date_delta_with_interval,
+ path_to_jsonpath,
regexp_replace_sql,
rename_func,
timestrtotime_sql,
ts_or_ds_add_cast,
- ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
+if t.TYPE_CHECKING:
+ from typing_extensions import Literal
+
logger = logging.getLogger("sqlglot")
@@ -206,12 +208,17 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
- if scale == exp.UnixToTime.NANOS:
- # We need to cast to INT64 because that's what BQ expects
- return f"TIMESTAMP_MICROS(CAST({timestamp} / 1000 AS INT64))"
- self.unsupported(f"Unsupported scale for timestamp: {scale}.")
- return ""
+ return f"TIMESTAMP_SECONDS(CAST({timestamp} / POW(10, {scale}) AS INT64))"
+
+
+def _parse_time(args: t.List) -> exp.Func:
+ if len(args) == 1:
+ return exp.TsOrDsToTime(this=args[0])
+ if len(args) == 3:
+ return exp.TimeFromParts.from_arg_list(args)
+
+ return exp.Anonymous(this="TIME", expressions=args)
class BigQuery(Dialect):
@@ -329,7 +336,13 @@ class BigQuery(Dialect):
"DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"DIV": binary_from_function(exp.IntDiv),
+ "FORMAT_DATE": lambda args: exp.TimeToStr(
+ this=exp.TsOrDsToDate(this=seq_get(args, 1)), format=seq_get(args, 0)
+ ),
"GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
+ "JSON_EXTRACT_SCALAR": lambda args: exp.JSONExtractScalar(
+ this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$")
+ ),
"MD5": exp.MD5Digest.from_arg_list,
"TO_HEX": _parse_to_hex,
"PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
@@ -351,6 +364,7 @@ class BigQuery(Dialect):
this=seq_get(args, 0),
expression=seq_get(args, 1) or exp.Literal.string(","),
),
+ "TIME": _parse_time,
"TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
"TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
"TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
@@ -361,9 +375,7 @@ class BigQuery(Dialect):
"TIMESTAMP_MILLIS": lambda args: exp.UnixToTime(
this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
),
- "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(
- this=seq_get(args, 0), scale=exp.UnixToTime.SECONDS
- ),
+ "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(this=seq_get(args, 0)),
"TO_JSON_STRING": exp.JSONFormat.from_arg_list,
}
@@ -460,7 +472,15 @@ class BigQuery(Dialect):
return table
- def _parse_json_object(self) -> exp.JSONObject:
+ @t.overload
+ def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
+ ...
+
+ @t.overload
+ def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
+ ...
+
+ def _parse_json_object(self, agg=False):
json_object = super()._parse_json_object()
array_kv_pair = seq_get(json_object.expressions, 0)
@@ -513,6 +533,10 @@ class BigQuery(Dialect):
UNNEST_WITH_ORDINALITY = False
COLLATE_IS_FUNC = True
LIMIT_ONLY_LITERALS = True
+ SUPPORTS_TABLE_ALIAS_COLUMNS = False
+ UNPIVOT_ALIASES_ARE_IDENTIFIERS = False
+ JSON_KEY_VALUE_PAIR_SEP = ","
+ NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -525,6 +549,7 @@ class BigQuery(Dialect):
exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
if e.args.get("default")
else f"COLLATE {self.sql(e, 'this')}",
+ exp.CountIf: rename_func("COUNTIF"),
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
@@ -536,13 +561,13 @@ class BigQuery(Dialect):
exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"),
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
+ exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql(false_value="NULL"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
- exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
@@ -578,16 +603,17 @@ class BigQuery(Dialect):
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
),
exp.TimeAdd: date_add_interval_sql("TIME", "ADD"),
+ exp.TimeFromParts: rename_func("TIME"),
exp.TimeSub: date_add_interval_sql("TIME", "SUB"),
exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
- exp.TimeToStr: lambda self, e: f"FORMAT_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
+ exp.TsOrDsToTime: rename_func("TIME"),
exp.Unhex: rename_func("FROM_HEX"),
+ exp.UnixDate: rename_func("UNIX_DATE"),
exp.UnixToTime: _unix_to_time_sql,
exp.Values: _derived_table_values_to_unnest,
exp.VariancePop: rename_func("VAR_POP"),
@@ -724,6 +750,26 @@ class BigQuery(Dialect):
"within",
}
+ def timetostr_sql(self, expression: exp.TimeToStr) -> str:
+ if isinstance(expression.this, exp.TsOrDsToDate):
+ this: exp.Expression = expression.this
+ else:
+ this = expression
+
+ return f"FORMAT_DATE({self.format_time(expression)}, {self.sql(this, 'this')})"
+
+ def struct_sql(self, expression: exp.Struct) -> str:
+ args = []
+ for expr in expression.expressions:
+ if isinstance(expr, self.KEY_VALUE_DEFINITIONS):
+ arg = f"{self.sql(expr, 'expression')} AS {expr.this.name}"
+ else:
+ arg = self.sql(expr)
+
+ args.append(arg)
+
+ return self.func("STRUCT", *args)
+
def eq_sql(self, expression: exp.EQ) -> str:
# Operands of = cannot be NULL in BigQuery
if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null):
@@ -760,7 +806,20 @@ class BigQuery(Dialect):
return inline_array_sql(self, expression)
def bracket_sql(self, expression: exp.Bracket) -> str:
+ this = self.sql(expression, "this")
expressions = expression.expressions
+
+ if len(expressions) == 1:
+ arg = expressions[0]
+ if arg.type is None:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ arg = annotate_types(arg)
+
+ if arg.type and arg.type.this in exp.DataType.TEXT_TYPES:
+ # BQ doesn't support bracket syntax with string values
+ return f"{this}.{arg.name}"
+
expressions_sql = ", ".join(self.sql(e) for e in expressions)
offset = expression.args.get("offset")
@@ -768,13 +827,13 @@ class BigQuery(Dialect):
expressions_sql = f"OFFSET({expressions_sql})"
elif offset == 1:
expressions_sql = f"ORDINAL({expressions_sql})"
- else:
+ elif offset is not None:
self.unsupported(f"Unsupported array offset: {offset}")
if expression.args.get("safe"):
expressions_sql = f"SAFE_{expressions_sql}"
- return f"{self.sql(expression, 'this')}[{expressions_sql}]"
+ return f"{this}[{expressions_sql}]"
def transaction_sql(self, *_) -> str:
return "BEGIN TRANSACTION"
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 870f402..f2e4fe1 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arg_max_or_min_no_count,
+ date_delta_sql,
inline_array_sql,
no_pivot_sql,
rename_func,
@@ -22,16 +23,25 @@ def _lower_func(sql: str) -> str:
return sql[:index].lower() + sql[index:]
-def _quantile_sql(self, e):
+def _quantile_sql(self: ClickHouse.Generator, e: exp.Quantile) -> str:
quantile = e.args["quantile"]
args = f"({self.sql(e, 'this')})"
+
if isinstance(quantile, exp.Array):
func = self.func("quantiles", *quantile)
else:
func = self.func("quantile", quantile)
+
return func + args
+def _parse_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc:
+ if len(args) == 1:
+ return exp.CountIf(this=seq_get(args, 0))
+
+ return exp.CombinedAggFunc(this="countIf", expressions=args, parts=("count", "If"))
+
+
class ClickHouse(Dialect):
NORMALIZE_FUNCTIONS: bool | str = False
NULL_ORDERING = "nulls_are_last"
@@ -53,6 +63,7 @@ class ClickHouse(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ATTACH": TokenType.COMMAND,
+ "DATE32": TokenType.DATE32,
"DATETIME64": TokenType.DATETIME64,
"DICTIONARY": TokenType.DICTIONARY,
"ENUM": TokenType.ENUM,
@@ -75,6 +86,8 @@ class ClickHouse(Dialect):
"UINT32": TokenType.UINT,
"UINT64": TokenType.UBIGINT,
"UINT8": TokenType.UTINYINT,
+ "IPV4": TokenType.IPV4,
+ "IPV6": TokenType.IPV6,
}
SINGLE_TOKENS = {
@@ -91,6 +104,8 @@ class ClickHouse(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ANY": exp.AnyValue.from_arg_list,
+ "ARRAYSUM": exp.ArraySum.from_arg_list,
+ "COUNTIF": _parse_count_if,
"DATE_ADD": lambda args: exp.DateAdd(
this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
),
@@ -110,6 +125,138 @@ class ClickHouse(Dialect):
"XOR": lambda args: exp.Xor(expressions=args),
}
+ AGG_FUNCTIONS = {
+ "count",
+ "min",
+ "max",
+ "sum",
+ "avg",
+ "any",
+ "stddevPop",
+ "stddevSamp",
+ "varPop",
+ "varSamp",
+ "corr",
+ "covarPop",
+ "covarSamp",
+ "entropy",
+ "exponentialMovingAverage",
+ "intervalLengthSum",
+ "kolmogorovSmirnovTest",
+ "mannWhitneyUTest",
+ "median",
+ "rankCorr",
+ "sumKahan",
+ "studentTTest",
+ "welchTTest",
+ "anyHeavy",
+ "anyLast",
+ "boundingRatio",
+ "first_value",
+ "last_value",
+ "argMin",
+ "argMax",
+ "avgWeighted",
+ "topK",
+ "topKWeighted",
+ "deltaSum",
+ "deltaSumTimestamp",
+ "groupArray",
+ "groupArrayLast",
+ "groupUniqArray",
+ "groupArrayInsertAt",
+ "groupArrayMovingAvg",
+ "groupArrayMovingSum",
+ "groupArraySample",
+ "groupBitAnd",
+ "groupBitOr",
+ "groupBitXor",
+ "groupBitmap",
+ "groupBitmapAnd",
+ "groupBitmapOr",
+ "groupBitmapXor",
+ "sumWithOverflow",
+ "sumMap",
+ "minMap",
+ "maxMap",
+ "skewSamp",
+ "skewPop",
+ "kurtSamp",
+ "kurtPop",
+ "uniq",
+ "uniqExact",
+ "uniqCombined",
+ "uniqCombined64",
+ "uniqHLL12",
+ "uniqTheta",
+ "quantile",
+ "quantiles",
+ "quantileExact",
+ "quantilesExact",
+ "quantileExactLow",
+ "quantilesExactLow",
+ "quantileExactHigh",
+ "quantilesExactHigh",
+ "quantileExactWeighted",
+ "quantilesExactWeighted",
+ "quantileTiming",
+ "quantilesTiming",
+ "quantileTimingWeighted",
+ "quantilesTimingWeighted",
+ "quantileDeterministic",
+ "quantilesDeterministic",
+ "quantileTDigest",
+ "quantilesTDigest",
+ "quantileTDigestWeighted",
+ "quantilesTDigestWeighted",
+ "quantileBFloat16",
+ "quantilesBFloat16",
+ "quantileBFloat16Weighted",
+ "quantilesBFloat16Weighted",
+ "simpleLinearRegression",
+ "stochasticLinearRegression",
+ "stochasticLogisticRegression",
+ "categoricalInformationValue",
+ "contingency",
+ "cramersV",
+ "cramersVBiasCorrected",
+ "theilsU",
+ "maxIntersections",
+ "maxIntersectionsPosition",
+ "meanZTest",
+ "quantileInterpolatedWeighted",
+ "quantilesInterpolatedWeighted",
+ "quantileGK",
+ "quantilesGK",
+ "sparkBar",
+ "sumCount",
+ "largestTriangleThreeBuckets",
+ }
+
+ AGG_FUNCTIONS_SUFFIXES = [
+ "If",
+ "Array",
+ "ArrayIf",
+ "Map",
+ "SimpleState",
+ "State",
+ "Merge",
+ "MergeState",
+ "ForEach",
+ "Distinct",
+ "OrDefault",
+ "OrNull",
+ "Resample",
+ "ArgMin",
+ "ArgMax",
+ ]
+
+ AGG_FUNC_MAPPING = (
+ lambda functions, suffixes: {
+ f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
+ }
+ )(AGG_FUNCTIONS, AGG_FUNCTIONS_SUFFIXES)
+
FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"}
FUNCTION_PARSERS = {
@@ -272,9 +419,18 @@ class ClickHouse(Dialect):
)
if isinstance(func, exp.Anonymous):
+ parts = self.AGG_FUNC_MAPPING.get(func.this)
params = self._parse_func_params(func)
if params:
+ if parts and parts[1]:
+ return self.expression(
+ exp.CombinedParameterizedAgg,
+ this=func.this,
+ expressions=func.expressions,
+ params=params,
+ parts=parts,
+ )
return self.expression(
exp.ParameterizedAgg,
this=func.this,
@@ -282,6 +438,20 @@ class ClickHouse(Dialect):
params=params,
)
+ if parts:
+ if parts[1]:
+ return self.expression(
+ exp.CombinedAggFunc,
+ this=func.this,
+ expressions=func.expressions,
+ parts=parts,
+ )
+ return self.expression(
+ exp.AnonymousAggFunc,
+ this=func.this,
+ expressions=func.expressions,
+ )
+
return func
def _parse_func_params(
@@ -329,6 +499,9 @@ class ClickHouse(Dialect):
STRUCT_DELIMITER = ("(", ")")
NVL2_SUPPORTED = False
TABLESAMPLE_REQUIRES_PARENS = False
+ TABLESAMPLE_SIZE_IS_ROWS = False
+ TABLESAMPLE_KEYWORDS = "SAMPLE"
+ LAST_DAY_SUPPORTS_DATE_PART = False
STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
@@ -348,6 +521,7 @@ class ClickHouse(Dialect):
**STRING_TYPE_MAPPING,
exp.DataType.Type.ARRAY: "Array",
exp.DataType.Type.BIGINT: "Int64",
+ exp.DataType.Type.DATE32: "Date32",
exp.DataType.Type.DATETIME64: "DateTime64",
exp.DataType.Type.DOUBLE: "Float64",
exp.DataType.Type.ENUM: "Enum",
@@ -372,24 +546,23 @@ class ClickHouse(Dialect):
exp.DataType.Type.UINT256: "UInt256",
exp.DataType.Type.USMALLINT: "UInt16",
exp.DataType.Type.UTINYINT: "UInt8",
+ exp.DataType.Type.IPV4: "IPv4",
+ exp.DataType.Type.IPV6: "IPv6",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
- exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.AnyValue: rename_func("any"),
exp.ApproxDistinct: rename_func("uniq"),
+ exp.ArraySum: rename_func("arraySum"),
exp.ArgMax: arg_max_or_min_no_count("argMax"),
exp.ArgMin: arg_max_or_min_no_count("argMin"),
exp.Array: inline_array_sql,
exp.CastToStrType: rename_func("CAST"),
+ exp.CountIf: rename_func("countIf"),
exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
- exp.DateAdd: lambda self, e: self.func(
- "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
- ),
- exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
- ),
+ exp.DateAdd: date_delta_sql("DATE_ADD"),
+ exp.DateDiff: date_delta_sql("DATE_DIFF"),
exp.Explode: rename_func("arrayJoin"),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.IsNan: rename_func("isNaN"),
@@ -400,6 +573,7 @@ class ClickHouse(Dialect):
exp.Quantile: _quantile_sql,
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
exp.Rand: rename_func("randCanonical"),
+ exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.StartsWith: rename_func("startsWith"),
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
@@ -485,10 +659,19 @@ class ClickHouse(Dialect):
else "",
]
- def parameterizedagg_sql(self, expression: exp.Anonymous) -> str:
+ def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str:
params = self.expressions(expression, key="params", flat=True)
return self.func(expression.name, *expression.expressions) + f"({params})"
+ def anonymousaggfunc_sql(self, expression: exp.AnonymousAggFunc) -> str:
+ return self.func(expression.name, *expression.expressions)
+
+ def combinedaggfunc_sql(self, expression: exp.CombinedAggFunc) -> str:
+ return self.anonymousaggfunc_sql(expression)
+
+ def combinedparameterizedagg_sql(self, expression: exp.CombinedParameterizedAgg) -> str:
+ return self.parameterizedagg_sql(expression)
+
def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 1c10a8b..8e55b6a 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -30,6 +30,8 @@ class Databricks(Spark):
}
class Generator(Spark.Generator):
+ TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+
TRANSFORMS = {
**Spark.Generator.TRANSFORMS,
exp.DateAdd: date_delta_sql("DATEADD"),
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index b7eef45..7664c40 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -153,6 +153,9 @@ class Dialect(metaclass=_Dialect):
ALIAS_POST_TABLESAMPLE = False
"""Determines whether or not the table alias comes after tablesample."""
+ TABLESAMPLE_SIZE_IS_PERCENT = False
+ """Determines whether or not a size in the table sample clause represents percentage."""
+
NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
"""Specifies the strategy according to which identifiers should be normalized."""
@@ -220,6 +223,24 @@ class Dialect(metaclass=_Dialect):
For example, such columns may be excluded from `SELECT *` queries.
"""
+ PREFER_CTE_ALIAS_COLUMN = False
+ """
+ Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
+ HAVING clause of the CTE. This flag will cause the CTE alias columns to override
+ any projection aliases in the subquery.
+
+ For example,
+ WITH y(c) AS (
+ SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
+ ) SELECT c FROM y;
+
+ will be rewritten as
+
+ WITH y(c) AS (
+ SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
+ ) SELECT c FROM y;
+ """
+
# --- Autofilled ---
tokenizer_class = Tokenizer
@@ -287,7 +308,13 @@ class Dialect(metaclass=_Dialect):
result = cls.get(dialect_name.strip())
if not result:
- raise ValueError(f"Unknown dialect '{dialect_name}'.")
+ from difflib import get_close_matches
+
+ similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
+ if similar:
+ similar = f" Did you mean {similar}?"
+
+ raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")
return result(**kwargs)
@@ -506,7 +533,7 @@ def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str:
def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str:
n = self.sql(expression, "this")
d = self.sql(expression, "expression")
- return f"IF({d} <> 0, {n} / {d}, NULL)"
+ return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)"
def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str:
@@ -695,7 +722,7 @@ def date_add_interval_sql(
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
return self.func(
- "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
+ "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
)
@@ -801,22 +828,6 @@ def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
return self.func("STRPTIME", expression.this, self.format_time(expression))
-def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
- def _ts_or_ds_to_date_sql(self: Generator, expression: exp.TsOrDsToDate) -> str:
- _dialect = Dialect.get_or_raise(dialect)
- time_format = self.format_time(expression)
- if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
- return self.sql(
- exp.cast(
- exp.StrToTime(this=expression.this, format=expression.args["format"]),
- "date",
- )
- )
- return self.sql(exp.cast(expression.this, "date"))
-
- return _ts_or_ds_to_date_sql
-
-
def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))
@@ -894,11 +905,6 @@ def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"
-# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon
-def json_keyvalue_comma_sql(self: Generator, expression: exp.JSONKeyValue) -> str:
- return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}"
-
-
def is_parse_json(expression: exp.Expression) -> bool:
return isinstance(expression, exp.ParseJSON) or (
isinstance(expression, exp.Cast) and expression.is_type("json")
@@ -946,7 +952,70 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
expression = ts_or_ds_add_cast(expression)
return self.func(
- name, exp.var(expression.text("unit") or "day"), expression.expression, expression.this
+ name,
+ exp.var(expression.text("unit").upper() or "DAY"),
+ expression.expression,
+ expression.this,
)
return _delta_sql
+
+
+def prepend_dollar_to_path(expression: exp.GetPath) -> exp.GetPath:
+ from sqlglot.optimizer.simplify import simplify
+
+ # Makes sure the path will be evaluated correctly at runtime to include the path root.
+ # For example, `[0].foo` will become `$[0].foo`, and `foo` will become `$.foo`.
+ path = expression.expression
+ path = exp.func(
+ "if",
+ exp.func("startswith", path, "'['"),
+ exp.func("concat", "'$'", path),
+ exp.func("concat", "'$.'", path),
+ )
+
+ expression.expression.replace(simplify(path))
+ return expression
+
+
+def path_to_jsonpath(
+ name: str = "JSON_EXTRACT",
+) -> t.Callable[[Generator, exp.GetPath], str]:
+ def _transform(self: Generator, expression: exp.GetPath) -> str:
+ return rename_func(name)(self, prepend_dollar_to_path(expression))
+
+ return _transform
+
+
+def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
+ trunc_curr_date = exp.func("date_trunc", "month", expression.this)
+ plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
+ minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
+
+ return self.sql(exp.cast(minus_one_day, "date"))
+
+
+def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
+ """Remove table refs from columns in when statements."""
+ alias = expression.this.args.get("alias")
+
+ normalize = (
+ lambda identifier: self.dialect.normalize_identifier(identifier).name
+ if identifier
+ else None
+ )
+
+ targets = {normalize(expression.this.this)}
+
+ if alias:
+ targets.add(normalize(alias.this))
+
+ for when in expression.expressions:
+ when.transform(
+ lambda node: exp.column(node.this)
+ if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
+ else node,
+ copy=False,
+ )
+
+ return self.merge_sql(expression)
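
A small check (mine) of the friendlier unknown-dialect error introduced in Dialect.get_or_raise above; the exact wording of the suggestion is an expectation, not a verified output.

import sqlglot

try:
    sqlglot.transpile("SELECT 1", read="postgress")  # deliberate typo
except ValueError as error:
    print(error)  # expected: Unknown dialect 'postgress'. Did you mean postgres?
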
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index 11af17b..6e229b3 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -22,6 +22,7 @@ class Doris(MySQL):
"COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
"DATE_TRUNC": parse_timestamp_trunc,
"REGEXP": exp.RegexpLike.from_arg_list,
+ "TO_DATE": exp.TsOrDsToDate.from_arg_list,
}
class Generator(MySQL.Generator):
@@ -34,21 +35,26 @@ class Doris(MySQL):
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
+ LAST_DAY_SUPPORTS_DATE_PART = False
+
TIMESTAMP_FUNC_TYPES = set()
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
+ exp.ArgMax: rename_func("MAX_BY"),
+ exp.ArgMin: rename_func("MIN_BY"),
exp.ArrayAgg: rename_func("COLLECT_LIST"),
+ exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.CurrentTimestamp: lambda *_: "NOW()",
exp.DateTrunc: lambda self, e: self.func(
"DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
),
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
+ exp.Map: rename_func("ARRAY_MAP"),
exp.RegexpLike: rename_func("REGEXP"),
exp.RegexpSplit: rename_func("SPLIT_BY_STRING"),
- exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Split: rename_func("SPLIT_BY_STRING"),
exp.TimeStrToDate: rename_func("TO_DATE"),
@@ -63,5 +69,4 @@ class Doris(MySQL):
"FROM_UNIXTIME", e.this, time_format("doris")(self, e)
),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
- exp.Map: rename_func("ARRAY_MAP"),
}
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index c9b31a0..6bca9e7 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -12,7 +12,6 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
timestrtotime_sql,
- ts_or_ds_to_date_sql,
)
@@ -99,6 +98,7 @@ class Drill(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
+ LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -150,7 +150,6 @@ class Drill(Dialect):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index cd9d529..2343b35 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -22,15 +22,15 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
no_timestamp_sql,
pivot_column_names,
+ prepend_dollar_to_path,
regexp_extract_sql,
rename_func,
str_position_sql,
str_to_time_sql,
timestamptrunc_sql,
timestrtotime_sql,
- ts_or_ds_to_date_sql,
)
-from sqlglot.helper import seq_get
+from sqlglot.helper import flatten, seq_get
from sqlglot.tokens import TokenType
@@ -141,11 +141,25 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
return f"EPOCH_MS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"MAKE_TIMESTAMP({timestamp})"
- if scale == exp.UnixToTime.NANOS:
- return f"TO_TIMESTAMP({timestamp} / 1000000000)"
- self.unsupported(f"Unsupported scale for timestamp: {scale}.")
- return ""
+ return f"TO_TIMESTAMP({timestamp} / POW(10, {scale}))"
+
+
+def _rename_unless_within_group(
+ a: str, b: str
+) -> t.Callable[[DuckDB.Generator, exp.Expression], str]:
+ return (
+ lambda self, expression: self.func(a, *flatten(expression.args.values()))
+ if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup)
+ else self.func(b, *flatten(expression.args.values()))
+ )
+
+
+def _parse_struct_pack(args: t.List) -> exp.Struct:
+ args_with_columns_as_identifiers = [
+ exp.PropertyEQ(this=arg.this.this, expression=arg.expression) for arg in args
+ ]
+ return exp.Struct.from_arg_list(args_with_columns_as_identifiers)
class DuckDB(Dialect):
@@ -183,6 +197,11 @@ class DuckDB(Dialect):
"TIMESTAMP_US": TokenType.TIMESTAMP,
}
+ SINGLE_TOKENS = {
+ **tokens.Tokenizer.SINGLE_TOKENS,
+ "$": TokenType.PARAMETER,
+ }
+
class Parser(parser.Parser):
BITWISE = {
**parser.Parser.BITWISE,
@@ -209,10 +228,12 @@ class DuckDB(Dialect):
"EPOCH_MS": lambda args: exp.UnixToTime(
this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
),
+ "JSON": exp.ParseJSON.from_arg_list,
"LIST_HAS": exp.ArrayContains.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
"LIST_VALUE": exp.Array.from_arg_list,
+ "MAKE_TIME": exp.TimeFromParts.from_arg_list,
"MAKE_TIMESTAMP": _parse_make_timestamp,
"MEDIAN": lambda args: exp.PercentileCont(
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
@@ -234,7 +255,7 @@ class DuckDB(Dialect):
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"STRING_TO_ARRAY": exp.Split.from_arg_list,
"STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
- "STRUCT_PACK": exp.Struct.from_arg_list,
+ "STRUCT_PACK": _parse_struct_pack,
"STR_SPLIT": exp.Split.from_arg_list,
"STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
@@ -250,6 +271,13 @@ class DuckDB(Dialect):
TokenType.ANTI,
}
+ PLACEHOLDER_PARSERS = {
+ **parser.Parser.PLACEHOLDER_PARSERS,
+ TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
+ if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
+ else None,
+ }
+
def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
@@ -268,7 +296,7 @@ class DuckDB(Dialect):
return this
- def _parse_struct_types(self) -> t.Optional[exp.Expression]:
+ def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
return self._parse_field_def()
def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
@@ -285,6 +313,10 @@ class DuckDB(Dialect):
RENAME_TABLE_WITH_DB = False
NVL2_SUPPORTED = False
SEMI_ANTI_JOIN_WITH_SIDE = False
+ TABLESAMPLE_KEYWORDS = "USING SAMPLE"
+ TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+ LAST_DAY_SUPPORTS_DATE_PART = False
+ JSON_KEY_VALUE_PAIR_SEP = ","
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -311,7 +343,7 @@ class DuckDB(Dialect):
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateSub: _date_delta_sql,
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", f"'{e.args.get('unit') or 'day'}'", e.expression, e.this
+ "DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
@@ -322,11 +354,11 @@ class DuckDB(Dialect):
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.IsInf: rename_func("ISINF"),
exp.IsNan: rename_func("ISNAN"),
+ exp.JSONBExtract: arrow_json_extract_sql,
+ exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONFormat: _json_format_sql,
- exp.JSONBExtract: arrow_json_extract_sql,
- exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.MonthsBetween: lambda self, e: self.func(
@@ -336,8 +368,8 @@ class DuckDB(Dialect):
exp.cast(e.this, "timestamp", copy=True),
),
exp.ParseJSON: rename_func("JSON"),
- exp.PercentileCont: rename_func("QUANTILE_CONT"),
- exp.PercentileDisc: rename_func("QUANTILE_DISC"),
+ exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"),
+ exp.PercentileDisc: _rename_unless_within_group("PERCENTILE_DISC", "QUANTILE_DISC"),
# DuckDB doesn't allow qualified columns inside of PIVOT expressions.
# See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62
exp.Pivot: transforms.preprocess([transforms.unqualify_columns]),
@@ -362,7 +394,9 @@ class DuckDB(Dialect):
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
exp.Timestamp: no_timestamp_sql,
- exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
+ exp.TimestampDiff: lambda self, e: self.func(
+ "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
+ ),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
@@ -373,11 +407,10 @@ class DuckDB(Dialect):
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: lambda self, e: self.func(
"DATE_DIFF",
- f"'{e.args.get('unit') or 'day'}'",
+ f"'{e.args.get('unit') or 'DAY'}'",
exp.cast(e.expression, "TIMESTAMP"),
exp.cast(e.this, "TIMESTAMP"),
),
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"),
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: _unix_to_time_sql,
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
@@ -410,6 +443,49 @@ class DuckDB(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
+ nano = expression.args.get("nano")
+ if nano is not None:
+ expression.set(
+ "sec", expression.args["sec"] + nano.pop() / exp.Literal.number(1000000000.0)
+ )
+
+ return rename_func("MAKE_TIME")(self, expression)
+
+ def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
+ sec = expression.args["sec"]
+
+ milli = expression.args.get("milli")
+ if milli is not None:
+ sec += milli.pop() / exp.Literal.number(1000.0)
+
+ nano = expression.args.get("nano")
+ if nano is not None:
+ sec += nano.pop() / exp.Literal.number(1000000000.0)
+
+ if milli or nano:
+ expression.set("sec", sec)
+
+ return rename_func("MAKE_TIMESTAMP")(self, expression)
+
+ def tablesample_sql(
+ self,
+ expression: exp.TableSample,
+ sep: str = " AS ",
+ tablesample_keyword: t.Optional[str] = None,
+ ) -> str:
+ if not isinstance(expression.parent, exp.Select):
+ # This sample clause only applies to a single source, not the entire resulting relation
+ tablesample_keyword = "TABLESAMPLE"
+
+ return super().tablesample_sql(
+ expression, sep=sep, tablesample_keyword=tablesample_keyword
+ )
+
+ def getpath_sql(self, expression: exp.GetPath) -> str:
+ expression = prepend_dollar_to_path(expression)
+ return f"{self.sql(expression, 'this')} -> {self.sql(expression, 'expression')}"
+
def interval_sql(self, expression: exp.Interval) -> str:
multiplier: t.Optional[int] = None
unit = expression.text("unit").lower()
@@ -420,11 +496,14 @@ class DuckDB(Dialect):
multiplier = 90
if multiplier:
- return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
+ return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('DAY')))})"
return super().interval_sql(expression)
- def tablesample_sql(
- self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS "
- ) -> str:
- return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep)
+ def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
+ if isinstance(expression.parent, exp.UserDefinedFunction):
+ return self.sql(expression, "this")
+ return super().columndef_sql(expression, sep)
+
+ def placeholder_sql(self, expression: exp.Placeholder) -> str:
+ return f"${expression.name}" if expression.name else "?"
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 65c85bb..dffa41e 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -418,13 +418,13 @@ class Hive(Dialect):
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
- TABLESAMPLE_SIZE_IS_PERCENT = True
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
INDEX_ON = "ON TABLE"
EXTRACT_ALLOWS_QUOTES = False
NVL2_SUPPORTED = False
+ LAST_DAY_SUPPORTS_DATE_PART = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Insert,
@@ -523,7 +523,6 @@ class Hive(Dialect):
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
- exp.LastDateOfMonth: rename_func("LAST_DAY"),
exp.National: lambda self, e: self.national_sql(e, prefix=""),
exp.ClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
exp.NonClusteredColumnConstraint: lambda self, e: f"({self.expressions(e, 'this', indent=False)})",
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 5fe3d82..21a9657 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
isnull_to_is_null,
- json_keyvalue_comma_sql,
locate_to_strposition,
max_or_greatest,
min_or_least,
@@ -21,6 +20,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
no_trycast_sql,
parse_date_delta_with_interval,
+ path_to_jsonpath,
rename_func,
strposition_to_locate_sql,
)
@@ -37,21 +37,21 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex
def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
expr = self.sql(expression, "this")
- unit = expression.text("unit")
+ unit = expression.text("unit").upper()
- if unit == "day":
+ if unit == "DAY":
return f"DATE({expr})"
- if unit == "week":
+ if unit == "WEEK":
concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
date_format = "%Y %u %w"
- elif unit == "month":
+ elif unit == "MONTH":
concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
date_format = "%Y %c %e"
- elif unit == "quarter":
+ elif unit == "QUARTER":
concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
date_format = "%Y %c %e"
- elif unit == "year":
+ elif unit == "YEAR":
concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e"
else:
@@ -292,9 +292,15 @@ class MySQL(Dialect):
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
+ "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
"ISNULL": isnull_to_is_null,
"LOCATE": locate_to_strposition,
+ "MAKETIME": exp.TimeFromParts.from_arg_list,
+ "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"MONTHNAME": lambda args: exp.TimeToStr(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
format=exp.Literal.string("%B"),
@@ -308,11 +314,6 @@ class MySQL(Dialect):
)
+ 1
),
- "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
- "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
- "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
- "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
- "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"WEEK": lambda args: exp.Week(
this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1)
),
@@ -441,6 +442,7 @@ class MySQL(Dialect):
}
LOG_DEFAULTS_TO_LN = True
+ STRING_ALIASES = True
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
@@ -620,13 +622,15 @@ class MySQL(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
- NULL_ORDERING_SUPPORTED = False
+ NULL_ORDERING_SUPPORTED = None
JOIN_HINTS = False
TABLE_HINTS = True
DUPLICATE_KEY_UPDATE_WITH_SET = False
QUERY_HINT_SEP = " "
VALUES_AS_TABLE = False
NVL2_SUPPORTED = False
+ LAST_DAY_SUPPORTS_DATE_PART = False
+ JSON_KEY_VALUE_PAIR_SEP = ","
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -642,15 +646,16 @@ class MySQL(Dialect):
exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
+ exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
- exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Month: _remove_ts_or_ds_to_date(),
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}",
+ exp.ParseJSON: lambda self, e: self.sql(e, "this"),
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
[
@@ -665,6 +670,7 @@ class MySQL(Dialect):
exp.StrToTime: _str_to_date_sql,
exp.Stuff: rename_func("INSERT"),
exp.TableSample: no_tablesample_sql,
+ exp.TimeFromParts: rename_func("MAKETIME"),
exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 51dbd53..6ad3718 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -53,6 +53,7 @@ def to_char(args: t.List) -> exp.TimeToStr | exp.ToChar:
class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
LOCKING_READS_SUPPORTED = True
+ TABLESAMPLE_SIZE_IS_PERCENT = True
# See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -81,6 +82,7 @@ class Oracle(Dialect):
"WW": "%W", # Week of year (1-53)
"YY": "%y", # 15
"YYYY": "%Y", # 2015
+ "FF6": "%f", # only 6 digits are supported in python formats
}
class Parser(parser.Parser):
@@ -91,6 +93,8 @@ class Oracle(Dialect):
**parser.Parser.FUNCTIONS,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TO_CHAR": to_char,
+ "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "oracle"),
+ "TO_DATE": format_time_lambda(exp.StrToDate, "oracle"),
}
FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
@@ -107,6 +111,11 @@ class Oracle(Dialect):
"XMLTABLE": _parse_xml_table,
}
+ QUERY_MODIFIER_PARSERS = {
+ **parser.Parser.QUERY_MODIFIER_PARSERS,
+ TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()),
+ }
+
TYPE_LITERAL_PARSERS = {
exp.DataType.Type.DATE: lambda self, this, _: self.expression(
exp.DateStrToDate, this=this
@@ -153,8 +162,10 @@ class Oracle(Dialect):
COLUMN_JOIN_MARKS_SUPPORTED = True
DATA_TYPE_SPECIFIERS_ALLOWED = True
ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False
-
LIMIT_FETCH = "FETCH"
+ TABLESAMPLE_KEYWORDS = "SAMPLE"
+ LAST_DAY_SUPPORTS_DATE_PART = False
+ SUPPORTS_SELECT_INTO = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -186,6 +197,7 @@ class Oracle(Dialect):
]
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToDate: lambda self, e: f"TO_DATE({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
exp.Table: lambda self, e: self.table_sql(e, sep=" "),
@@ -201,6 +213,10 @@ class Oracle(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str:
+ this = expression.this
+ return self.func("CURRENT_TIMESTAMP", this) if this else "CURRENT_TIMESTAMP"
+
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
@@ -233,8 +249,10 @@ class Oracle(Dialect):
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NVARCHAR2": TokenType.NVARCHAR,
+ "ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY,
"SAMPLE": TokenType.TABLE_SAMPLE,
"START": TokenType.BEGIN,
+ "SYSDATE": TokenType.CURRENT_TIMESTAMP,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
}
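
Illustration (mine, output unverified) of the SYSDATE tokenizer mapping combined with the new currenttimestamp_sql override above:

import sqlglot

print(sqlglot.transpile("SELECT SYSDATE FROM DUAL", read="oracle", write="oracle")[0])
# expected: SELECT CURRENT_TIMESTAMP FROM DUAL
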
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index e274877..1ca0a78 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -13,11 +13,12 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
format_time_lambda,
max_or_greatest,
+ merge_without_target_sql,
min_or_least,
+ no_last_day_sql,
no_map_from_entries_sql,
no_paren_current_date_sql,
no_pivot_sql,
- no_tablesample_sql,
no_trycast_sql,
parse_timestamp_trunc,
rename_func,
@@ -27,7 +28,6 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
trim_sql,
ts_or_ds_add_cast,
- ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
@@ -188,36 +188,6 @@ def _to_timestamp(args: t.List) -> exp.Expression:
return format_time_lambda(exp.StrToTime, "postgres")(args)
-def _merge_sql(self: Postgres.Generator, expression: exp.Merge) -> str:
- def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
- """Remove table refs from columns in when statements."""
- if isinstance(expression, exp.Merge):
- alias = expression.this.args.get("alias")
-
- normalize = (
- lambda identifier: self.dialect.normalize_identifier(identifier).name
- if identifier
- else None
- )
-
- targets = {normalize(expression.this.this)}
-
- if alias:
- targets.add(normalize(alias.this))
-
- for when in expression.expressions:
- when.transform(
- lambda node: exp.column(node.this)
- if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
- else node,
- copy=False,
- )
-
- return expression
-
- return transforms.preprocess([_remove_target_from_merge])(self, expression)
-
-
class Postgres(Dialect):
INDEX_OFFSET = 1
TYPED_DIVISION = True
@@ -316,6 +286,8 @@ class Postgres(Dialect):
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
"GENERATE_SERIES": _generate_series,
+ "MAKE_TIME": exp.TimeFromParts.from_arg_list,
+ "MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
"NOW": exp.CurrentTimestamp.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"TO_TIMESTAMP": _to_timestamp,
@@ -387,12 +359,18 @@ class Postgres(Dialect):
class Generator(generator.Generator):
SINGLE_STRING_INTERVAL = True
+ RENAME_TABLE_WITH_DB = False
LOCKING_READS_SUPPORTED = True
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
NVL2_SUPPORTED = False
PARAMETER_TOKEN = "$"
+ TABLESAMPLE_SIZE_IS_ROWS = False
+ TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+ SUPPORTS_SELECT_INTO = True
+ # https://www.postgresql.org/docs/current/sql-createtable.html
+ SUPPORTS_UNLOGGED_TABLES = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -430,12 +408,13 @@ class Postgres(Dialect):
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
+ exp.LastDay: no_last_day_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Max: max_or_greatest,
exp.MapFromEntries: no_map_from_entries_sql,
exp.Min: min_or_least,
- exp.Merge: _merge_sql,
+ exp.Merge: merge_without_target_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.PercentileCont: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
@@ -458,16 +437,16 @@ class Postgres(Dialect):
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,
+ exp.TimeFromParts: rename_func("MAKE_TIME"),
+ exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.TableSample: no_tablesample_sql,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: _date_add_sql("+"),
exp.TsOrDsDiff: _date_diff_sql,
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.VariancePop: rename_func("VAR_POP"),
exp.Variance: rename_func("VAR_SAMP"),
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 360ab65..9b421e7 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_safe_divide_sql,
no_timestamp_sql,
+ path_to_jsonpath,
regexp_extract_sql,
rename_func,
right_to_substring_sql,
@@ -99,14 +100,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
expression = ts_or_ds_add_cast(expression)
- unit = exp.Literal.string(expression.text("unit") or "day")
+ unit = exp.Literal.string(expression.text("unit") or "DAY")
return self.func("DATE_ADD", unit, expression.expression, expression.this)
def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
this = exp.cast(expression.this, "TIMESTAMP")
expr = exp.cast(expression.expression, "TIMESTAMP")
- unit = exp.Literal.string(expression.text("unit") or "day")
+ unit = exp.Literal.string(expression.text("unit") or "DAY")
return self.func("DATE_DIFF", unit, expr, this)
@@ -138,13 +139,6 @@ def _from_unixtime(args: t.List) -> exp.Expression:
return exp.UnixToTime.from_arg_list(args)
-def _parse_element_at(args: t.List) -> exp.Bracket:
- this = seq_get(args, 0)
- index = seq_get(args, 1)
- assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression)
- return exp.Bracket(this=this, expressions=[index], offset=1, safe=True)
-
-
def _unnest_sequence(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Table):
if isinstance(expression.this, exp.GenerateSeries):
@@ -175,15 +169,8 @@ def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str
timestamp = self.sql(expression, "this")
if scale in (None, exp.UnixToTime.SECONDS):
return rename_func("FROM_UNIXTIME")(self, expression)
- if scale == exp.UnixToTime.MILLIS:
- return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000)"
- if scale == exp.UnixToTime.MICROS:
- return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000)"
- if scale == exp.UnixToTime.NANOS:
- return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / 1000000000)"
- self.unsupported(f"Unsupported scale for timestamp: {scale}.")
- return ""
+ return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))"
def _to_int(expression: exp.Expression) -> exp.Expression:
@@ -215,6 +202,7 @@ class Presto(Dialect):
STRICT_STRING_CONCAT = True
SUPPORTS_SEMI_ANTI_JOIN = False
TYPED_DIVISION = True
+ TABLESAMPLE_SIZE_IS_PERCENT = True
# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
@@ -258,7 +246,9 @@ class Presto(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
- "ELEMENT_AT": _parse_element_at,
+ "ELEMENT_AT": lambda args: exp.Bracket(
+ this=seq_get(args, 0), expressions=[seq_get(args, 1)], offset=1, safe=True
+ ),
"FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"FROM_UTF8": lambda args: exp.Decode(
@@ -344,20 +334,20 @@ class Presto(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD",
- exp.Literal.string(e.text("unit") or "day"),
+ exp.Literal.string(e.text("unit") or "DAY"),
_to_int(
e.expression,
),
e.this,
),
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+ "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
- exp.Literal.string(e.text("unit") or "day"),
+ exp.Literal.string(e.text("unit") or "DAY"),
_to_int(e.expression * -1),
e.this,
),
@@ -366,6 +356,7 @@ class Presto(Dialect):
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
+ exp.GetPath: path_to_jsonpath(),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
@@ -376,6 +367,7 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Last: _first_last_sql,
+ exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this),
exp.Lateral: _explode_to_unnest_sql,
exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
@@ -446,7 +438,7 @@ class Presto(Dialect):
return super().bracket_sql(expression)
def struct_sql(self, expression: exp.Struct) -> str:
- if any(isinstance(arg, self.KEY_VALUE_DEFINITONS) for arg in expression.expressions):
+ if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions):
self.unsupported("Struct with key-value definitions is unsupported.")
return self.function_fallback_sql(expression)
@@ -454,8 +446,8 @@ class Presto(Dialect):
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
- if expression.this and unit.lower().startswith("week"):
- return f"({expression.this.name} * INTERVAL '7' day)"
+ if expression.this and unit.startswith("WEEK"):
+ return f"({expression.this.name} * INTERVAL '7' DAY)"
return super().interval_sql(expression)
def transaction_sql(self, expression: exp.Transaction) -> str:
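
A parse-level sketch (mine) of the ELEMENT_AT rewrite above: the call now becomes a one-based, "safe" Bracket node directly in the FUNCTIONS lambda rather than going through a separate helper.

import sqlglot
from sqlglot import exp

bracket = sqlglot.parse_one("SELECT ELEMENT_AT(arr, 1)", read="presto").find(exp.Bracket)
print(bracket.args.get("offset"), bracket.args.get("safe"))  # expected: 1 True
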
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 7382e7c..7194d81 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -9,8 +9,8 @@ from sqlglot.dialects.dialect import (
concat_ws_to_dpipe_sql,
date_delta_sql,
generatedasidentitycolumnconstraint_sql,
+ no_tablesample_sql,
rename_func,
- ts_or_ds_to_date_sql,
)
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
@@ -123,6 +123,27 @@ class Redshift(Postgres):
self._retreat(index)
return None
+ def _parse_query_modifiers(
+ self, this: t.Optional[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
+ this = super()._parse_query_modifiers(this)
+
+ if this:
+ refs = set()
+
+ for i, join in enumerate(this.args.get("joins", [])):
+ refs.add(
+ (
+ this.args["from"] if i == 0 else this.args["joins"][i - 1]
+ ).alias_or_name.lower()
+ )
+ table = join.this
+
+ if isinstance(table, exp.Table):
+ if table.parts[0].name.lower() in refs:
+ table.replace(table.to_column())
+ return this
+
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []
@@ -144,11 +165,11 @@ class Redshift(Postgres):
class Generator(Postgres.Generator):
LOCKING_READS_SUPPORTED = False
- RENAME_TABLE_WITH_DB = False
QUERY_HINTS = False
VALUES_AS_TABLE = False
TZ_TO_WITH_TIME_ZONE = True
NVL2_SUPPORTED = True
+ LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@@ -184,9 +205,9 @@ class Redshift(Postgres):
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
+ exp.TableSample: no_tablesample_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"),
}
# Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots
@@ -198,6 +219,9 @@ class Redshift(Postgres):
# Redshift supports ANY_VALUE(..)
TRANSFORMS.pop(exp.AnyValue)
+ # Redshift supports LAST_DAY(..)
+ TRANSFORMS.pop(exp.LastDay)
+
RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
def with_properties(self, properties: exp.Properties) -> str:
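A stand-alone sketch of the alias-tracking idea behind the new Redshift _parse_query_modifiers override (the helper below is hypothetical; the real code walks exp.Table and exp.Join nodes): a joined name whose first part matches an earlier FROM or JOIN alias is treated as a correlated column reference rather than a table.

def classify_join_targets(from_alias, joins):
    # joins is a list of (alias, first_name_part) pairs, in join order.
    refs = {from_alias.lower()}
    kinds = []
    for alias, first_part in joins:
        kinds.append("column" if first_part.lower() in refs else "table")
        refs.add(alias.lower())
    return kinds

# FROM t1 JOIN t1.c AS c JOIN t2 -> t1.c is a column reference, t2 stays a table.
assert classify_join_targets("t1", [("c", "t1"), ("t2", "t2")]) == ["column", "table"]
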
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 8925181..a8e4a42 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -19,7 +19,6 @@ from sqlglot.dialects.dialect import (
rename_func,
timestamptrunc_sql,
timestrtotime_sql,
- ts_or_ds_to_date_sql,
var_map_sql,
)
from sqlglot.expressions import Literal
@@ -40,21 +39,7 @@ def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime,
if second_arg.is_string:
# case: <string_expr> [ , <format> ]
return format_time_lambda(exp.StrToTime, "snowflake")(args)
-
- # case: <numeric_expr> [ , <scale> ]
- if second_arg.name not in ["0", "3", "9"]:
- raise ValueError(
- f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
- )
-
- if second_arg.name == "0":
- timescale = exp.UnixToTime.SECONDS
- elif second_arg.name == "3":
- timescale = exp.UnixToTime.MILLIS
- elif second_arg.name == "9":
- timescale = exp.UnixToTime.NANOS
-
- return exp.UnixToTime(this=first_arg, scale=timescale)
+ return exp.UnixToTime(this=first_arg, scale=second_arg)
from sqlglot.optimizer.simplify import simplify_literals
@@ -91,23 +76,9 @@ def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
def _parse_datediff(args: t.List) -> exp.DateDiff:
- return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))
-
-
-def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str:
- scale = expression.args.get("scale")
- timestamp = self.sql(expression, "this")
- if scale in (None, exp.UnixToTime.SECONDS):
- return f"TO_TIMESTAMP({timestamp})"
- if scale == exp.UnixToTime.MILLIS:
- return f"TO_TIMESTAMP({timestamp}, 3)"
- if scale == exp.UnixToTime.MICROS:
- return f"TO_TIMESTAMP({timestamp} / 1000, 3)"
- if scale == exp.UnixToTime.NANOS:
- return f"TO_TIMESTAMP({timestamp}, 9)"
-
- self.unsupported(f"Unsupported scale for timestamp: {scale}.")
- return ""
+ return exp.DateDiff(
+ this=seq_get(args, 2), expression=seq_get(args, 1), unit=_map_date_part(seq_get(args, 0))
+ )
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
@@ -120,14 +91,15 @@ def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]:
self._match(TokenType.COMMA)
expression = self._parse_bitwise()
-
+ this = _map_date_part(this)
name = this.name.upper()
+
if name.startswith("EPOCH"):
- if name.startswith("EPOCH_MILLISECOND"):
+ if name == "EPOCH_MILLISECOND":
scale = 10**3
- elif name.startswith("EPOCH_MICROSECOND"):
+ elif name == "EPOCH_MICROSECOND":
scale = 10**6
- elif name.startswith("EPOCH_NANOSECOND"):
+ elif name == "EPOCH_NANOSECOND":
scale = 10**9
else:
scale = None
@@ -204,6 +176,159 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser]
return _parse
+DATE_PART_MAPPING = {
+ "Y": "YEAR",
+ "YY": "YEAR",
+ "YYY": "YEAR",
+ "YYYY": "YEAR",
+ "YR": "YEAR",
+ "YEARS": "YEAR",
+ "YRS": "YEAR",
+ "MM": "MONTH",
+ "MON": "MONTH",
+ "MONS": "MONTH",
+ "MONTHS": "MONTH",
+ "D": "DAY",
+ "DD": "DAY",
+ "DAYS": "DAY",
+ "DAYOFMONTH": "DAY",
+ "WEEKDAY": "DAYOFWEEK",
+ "DOW": "DAYOFWEEK",
+ "DW": "DAYOFWEEK",
+ "WEEKDAY_ISO": "DAYOFWEEKISO",
+ "DOW_ISO": "DAYOFWEEKISO",
+ "DW_ISO": "DAYOFWEEKISO",
+ "YEARDAY": "DAYOFYEAR",
+ "DOY": "DAYOFYEAR",
+ "DY": "DAYOFYEAR",
+ "W": "WEEK",
+ "WK": "WEEK",
+ "WEEKOFYEAR": "WEEK",
+ "WOY": "WEEK",
+ "WY": "WEEK",
+ "WEEK_ISO": "WEEKISO",
+ "WEEKOFYEARISO": "WEEKISO",
+ "WEEKOFYEAR_ISO": "WEEKISO",
+ "Q": "QUARTER",
+ "QTR": "QUARTER",
+ "QTRS": "QUARTER",
+ "QUARTERS": "QUARTER",
+ "H": "HOUR",
+ "HH": "HOUR",
+ "HR": "HOUR",
+ "HOURS": "HOUR",
+ "HRS": "HOUR",
+ "M": "MINUTE",
+ "MI": "MINUTE",
+ "MIN": "MINUTE",
+ "MINUTES": "MINUTE",
+ "MINS": "MINUTE",
+ "S": "SECOND",
+ "SEC": "SECOND",
+ "SECONDS": "SECOND",
+ "SECS": "SECOND",
+ "MS": "MILLISECOND",
+ "MSEC": "MILLISECOND",
+ "MILLISECONDS": "MILLISECOND",
+ "US": "MICROSECOND",
+ "USEC": "MICROSECOND",
+ "MICROSECONDS": "MICROSECOND",
+ "NS": "NANOSECOND",
+ "NSEC": "NANOSECOND",
+ "NANOSEC": "NANOSECOND",
+ "NSECOND": "NANOSECOND",
+ "NSECONDS": "NANOSECOND",
+ "NANOSECS": "NANOSECOND",
+ "NSECONDS": "NANOSECOND",
+ "EPOCH": "EPOCH_SECOND",
+ "EPOCH_SECONDS": "EPOCH_SECOND",
+ "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
+ "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
+ "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
+ "TZH": "TIMEZONE_HOUR",
+ "TZM": "TIMEZONE_MINUTE",
+}
+
+
+@t.overload
+def _map_date_part(part: exp.Expression) -> exp.Var:
+ pass
+
+
+@t.overload
+def _map_date_part(part: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ pass
+
+
+def _map_date_part(part):
+ mapped = DATE_PART_MAPPING.get(part.name.upper()) if part else None
+ return exp.var(mapped) if mapped else part
+
+
+def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
+ trunc = date_trunc_to_time(args)
+ trunc.set("unit", _map_date_part(trunc.args["unit"]))
+ return trunc
+
+
+def _parse_colon_get_path(
+ self: parser.Parser, this: t.Optional[exp.Expression]
+) -> t.Optional[exp.Expression]:
+ while True:
+ path = self._parse_bitwise()
+
+ # The cast :: operator has a lower precedence than the extraction operator :, so
+ # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
+ if isinstance(path, exp.Cast):
+ target_type = path.to
+ path = path.this
+ else:
+ target_type = None
+
+ if isinstance(path, exp.Expression):
+ path = exp.Literal.string(path.sql(dialect="snowflake"))
+
+ # The extraction operator : is left-associative
+ this = self.expression(exp.GetPath, this=this, expression=path)
+
+ if target_type:
+ this = exp.cast(this, target_type)
+
+ if not self._match(TokenType.COLON):
+ break
+
+ if self._match_set(self.RANGE_PARSERS):
+ this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this
+
+ return this
+
+
+def _parse_timestamp_from_parts(args: t.List) -> exp.Func:
+ if len(args) == 2:
+ # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
+ # so we parse this into Anonymous for now instead of introducing complexity
+ return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)
+
+ return exp.TimestampFromParts.from_arg_list(args)
+
+
+def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
+ """
+ Snowflake doesn't allow columns referenced in UNPIVOT to be qualified,
+ so we need to unqualify them.
+
+ Example:
+ >>> from sqlglot import parse_one
+ >>> expr = parse_one("SELECT * FROM m_sales UNPIVOT(sales FOR month IN (m_sales.jan, feb, mar, april))")
+ >>> print(_unqualify_unpivot_columns(expr).sql(dialect="snowflake"))
+ SELECT * FROM m_sales UNPIVOT(sales FOR month IN (jan, feb, mar, april))
+ """
+ if isinstance(expression, exp.Pivot) and expression.unpivot:
+ expression = transforms.unqualify_columns(expression)
+
+ return expression
+
+
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -211,6 +336,8 @@ class Snowflake(Dialect):
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False
+ PREFER_CTE_ALIAS_COLUMN = True
+ TABLESAMPLE_SIZE_IS_PERCENT = True
TIME_MAPPING = {
"YYYY": "%Y",
@@ -276,14 +403,19 @@ class Snowflake(Dialect):
"BIT_XOR": binary_from_function(exp.BitwiseXor),
"BOOLXOR": binary_from_function(exp.Xor),
"CONVERT_TIMEZONE": _parse_convert_timezone,
- "DATE_TRUNC": date_trunc_to_time,
+ "DATE_TRUNC": _date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
- this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=_map_date_part(seq_get(args, 0)),
),
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
"FLATTEN": exp.Explode.from_arg_list,
"IFF": exp.If.from_arg_list,
+ "LAST_DAY": lambda args: exp.LastDay(
+ this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
+ ),
"LISTAGG": exp.GroupConcat.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": _parse_object_construct,
@@ -293,6 +425,8 @@ class Snowflake(Dialect):
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TIMEDIFF": _parse_datediff,
"TIMESTAMPDIFF": _parse_datediff,
+ "TIMESTAMPFROMPARTS": _parse_timestamp_from_parts,
+ "TIMESTAMP_FROM_PARTS": _parse_timestamp_from_parts,
"TO_TIMESTAMP": _parse_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _zeroifnull_to_if,
@@ -301,22 +435,17 @@ class Snowflake(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
+ "OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(),
}
FUNCTION_PARSERS.pop("TRIM")
- COLUMN_OPERATORS = {
- **parser.Parser.COLUMN_OPERATORS,
- TokenType.COLON: lambda self, this, path: self.expression(
- exp.Bracket, this=this, expressions=[path]
- ),
- }
-
TIMESTAMPS = parser.Parser.TIMESTAMPS - {TokenType.TIME}
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
+ TokenType.COLON: _parse_colon_get_path,
}
ALTER_PARSERS = {
@@ -344,6 +473,7 @@ class Snowflake(Dialect):
SHOW_PARSERS = {
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
+ "COLUMNS": _show_parser("COLUMNS"),
}
STAGED_FILE_SINGLE_TOKENS = {
@@ -351,8 +481,18 @@ class Snowflake(Dialect):
TokenType.MOD,
TokenType.SLASH,
}
+
FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]
+ def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
+ if is_map:
+ # Keys are strings in Snowflake's objects, see also:
+ # - https://docs.snowflake.com/en/sql-reference/data-types-semistructured
+ # - https://docs.snowflake.com/en/sql-reference/functions/object_construct
+ return self._parse_slice(self._parse_string())
+
+ return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))
+
def _parse_lateral(self) -> t.Optional[exp.Lateral]:
lateral = super()._parse_lateral()
if not lateral:
@@ -440,6 +580,8 @@ class Snowflake(Dialect):
scope = None
scope_kind = None
+ like = self._parse_string() if self._match(TokenType.LIKE) else None
+
if self._match(TokenType.IN):
if self._match_text_seq("ACCOUNT"):
scope_kind = "ACCOUNT"
@@ -451,7 +593,9 @@ class Snowflake(Dialect):
scope_kind = "TABLE"
scope = self._parse_table()
- return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind)
+ return self.expression(
+ exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind
+ )
def _parse_alter_table_swap(self) -> exp.SwapTable:
self._match_text_seq("WITH")
@@ -489,8 +633,12 @@ class Snowflake(Dialect):
"MINUS": TokenType.EXCEPT,
"NCHAR VARYING": TokenType.VARCHAR,
"PUT": TokenType.COMMAND,
+ "REMOVE": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
+ "RM": TokenType.COMMAND,
"SAMPLE": TokenType.TABLE_SAMPLE,
+ "SQL_DOUBLE": TokenType.DOUBLE,
+ "SQL_VARCHAR": TokenType.VARCHAR,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@@ -518,6 +666,8 @@ class Snowflake(Dialect):
SUPPORTS_TABLE_COPY = False
COLLATE_IS_FUNC = True
LIMIT_ONLY_LITERALS = True
+ JSON_KEY_VALUE_PAIR_SEP = ","
+ INSERT_OVERWRITE = " OVERWRITE INTO"
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -545,6 +695,8 @@ class Snowflake(Dialect):
),
exp.GroupConcat: rename_func("LISTAGG"),
exp.If: if_sql(name="IFF", false_value="NULL"),
+ exp.JSONExtract: lambda self, e: f"{self.sql(e, 'this')}[{self.sql(e, 'expression')}]",
+ exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions),
exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.LogicalOr: rename_func("BOOLOR_AGG"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
@@ -557,6 +709,7 @@ class Snowflake(Dialect):
exp.PercentileDisc: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
),
+ exp.Pivot: transforms.preprocess([_unqualify_unpivot_columns]),
exp.RegexpILike: _regexpilike_sql,
exp.Rand: rename_func("RANDOM"),
exp.Select: transforms.preprocess(
@@ -578,6 +731,9 @@ class Snowflake(Dialect):
*(arg for expression in e.expressions for arg in expression.flatten()),
),
exp.Stuff: rename_func("INSERT"),
+ exp.TimestampDiff: lambda self, e: self.func(
+ "TIMESTAMPDIFF", e.unit, e.expression, e.this
+ ),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func(
@@ -589,8 +745,7 @@ class Snowflake(Dialect):
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"),
- exp.UnixToTime: _unix_to_time_sql,
+ exp.UnixToTime: rename_func("TO_TIMESTAMP"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.Xor: rename_func("BOOLXOR"),
@@ -612,6 +767,14 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
+ milli = expression.args.get("milli")
+ if milli is not None:
+ milli_to_nano = milli.pop() * exp.Literal.number(1000000)
+ expression.set("nano", milli_to_nano)
+
+ return rename_func("TIMESTAMP_FROM_PARTS")(self, expression)
+
def trycast_sql(self, expression: exp.TryCast) -> str:
value = expression.this
@@ -657,6 +820,9 @@ class Snowflake(Dialect):
return f"{explode}{alias}"
def show_sql(self, expression: exp.Show) -> str:
+ like = self.sql(expression, "like")
+ like = f" LIKE {like}" if like else ""
+
scope = self.sql(expression, "scope")
scope = f" {scope}" if scope else ""
@@ -664,7 +830,7 @@ class Snowflake(Dialect):
if scope_kind:
scope_kind = f" IN {scope_kind}"
- return f"SHOW {expression.name}{scope_kind}{scope}"
+ return f"SHOW {expression.name}{like}{scope_kind}{scope}"
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to
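For reference, a trimmed stand-alone version of the date-part normalization added above (only a small subset of DATE_PART_MAPPING is reproduced, and the fallback is simplified): Snowflake's many unit abbreviations collapse to canonical uppercase names, while unknown parts pass through, mirroring _map_date_part.

DATE_PART_ABBREVIATIONS = {
    "YY": "YEAR", "MM": "MONTH", "DD": "DAY", "WK": "WEEK",
    "QTR": "QUARTER", "HH": "HOUR", "MI": "MINUTE", "SEC": "SECOND",
    "EPOCH": "EPOCH_SECOND",
}

def map_date_part(part):
    part = part.upper()
    return DATE_PART_ABBREVIATIONS.get(part, part)

assert map_date_part("yy") == "YEAR"
assert map_date_part("dayofweek") == "DAYOFWEEK"  # unmapped parts pass through
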
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index aa09f53..e27ba18 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -48,11 +48,8 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str
return f"TIMESTAMP_MILLIS({timestamp})"
if scale == exp.UnixToTime.MICROS:
return f"TIMESTAMP_MICROS({timestamp})"
- if scale == exp.UnixToTime.NANOS:
- return f"TIMESTAMP_SECONDS({timestamp} / 1000000000)"
- self.unsupported(f"Unsupported scale for timestamp: {scale}.")
- return ""
+ return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))"
def _unalias_pivot(expression: exp.Expression) -> exp.Expression:
@@ -93,12 +90,7 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q1'))
"""
if isinstance(expression, exp.Pivot):
- expression.args["field"].transform(
- lambda node: exp.column(node.output_name, quoted=node.this.quoted)
- if isinstance(node, exp.Column)
- else node,
- copy=False,
- )
+ expression.set("field", transforms.unqualify_columns(expression.args["field"]))
return expression
@@ -234,7 +226,7 @@ class Spark2(Hive):
def struct_sql(self, expression: exp.Struct) -> str:
args = []
for arg in expression.expressions:
- if isinstance(arg, self.KEY_VALUE_DEFINITONS):
+ if isinstance(arg, self.KEY_VALUE_DEFINITIONS):
if isinstance(arg, exp.Bracket):
args.append(exp.alias_(arg.this, arg.expressions[0].name))
else:
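A small illustration of the generalized Spark fallback above, assuming the numeric scale constants introduced later in this commit (0 = seconds, 3 = millis, 6 = micros, 9 = nanos): scales without a dedicated TIMESTAMP_* function are divided down to seconds.

def spark_unix_to_time_sql(timestamp, scale=None):
    if scale in (None, 0):
        return f"TIMESTAMP_SECONDS({timestamp})"
    if scale == 3:
        return f"TIMESTAMP_MILLIS({timestamp})"
    if scale == 6:
        return f"TIMESTAMP_MICROS({timestamp})"
    return f"TIMESTAMP_SECONDS({timestamp} / POW(10, {scale}))"

assert spark_unix_to_time_sql("col", 9) == "TIMESTAMP_SECONDS(col / POW(10, 9))"
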
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 9bac51c..244a96e 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -78,6 +78,7 @@ class SQLite(Dialect):
**parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
+ STRING_ALIASES = True
class Generator(generator.Generator):
JOIN_HINTS = False
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 0ccc567..6dbad15 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -175,6 +175,8 @@ class Teradata(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
+ TABLESAMPLE_KEYWORDS = "SAMPLE"
+ LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -214,7 +216,10 @@ class Teradata(Dialect):
return self.cast_sql(expression, safe_prefix="TRY")
def tablesample_sql(
- self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
+ self,
+ expression: exp.TableSample,
+ sep: str = " AS ",
+ tablesample_keyword: t.Optional[str] = None,
) -> str:
return f"{self.sql(expression, 'this')} SAMPLE {self.expressions(expression)}"
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index 3682ac7..eddb70a 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from sqlglot import exp
+from sqlglot.dialects.dialect import merge_without_target_sql
from sqlglot.dialects.presto import Presto
@@ -11,6 +12,7 @@ class Trino(Presto):
TRANSFORMS = {
**Presto.Generator.TRANSFORMS,
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
+ exp.Merge: merge_without_target_sql,
}
class Tokenizer(Presto.Tokenizer):
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 165a703..b9c347c 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -14,9 +14,10 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
parse_date_delta,
+ path_to_jsonpath,
rename_func,
timestrtotime_sql,
- ts_or_ds_to_date_sql,
+ trim_sql,
)
from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
@@ -105,18 +106,17 @@ def _parse_format(args: t.List) -> exp.Expression:
return exp.TimeToStr(this=this, format=fmt, culture=culture)
-def _parse_eomonth(args: t.List) -> exp.Expression:
- date = seq_get(args, 0)
+def _parse_eomonth(args: t.List) -> exp.LastDay:
+ date = exp.TsOrDsToDate(this=seq_get(args, 0))
month_lag = seq_get(args, 1)
- unit = DATE_DELTA_INTERVAL.get("month")
if month_lag is None:
- return exp.LastDateOfMonth(this=date)
+ this: exp.Expression = date
+ else:
+ unit = DATE_DELTA_INTERVAL.get("month")
+ this = exp.DateAdd(this=date, expression=month_lag, unit=unit and exp.var(unit))
- # Remove month lag argument in parser as its compared with the number of arguments of the resulting class
- args.remove(month_lag)
-
- return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
+ return exp.LastDay(this=this)
def _parse_hashbytes(args: t.List) -> exp.Expression:
@@ -137,26 +137,27 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
return exp.func("HASHBYTES", *args)
-DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"}
+DATEPART_ONLY_FORMATS = {"DW", "HOUR", "QUARTER"}
def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
- fmt = (
- expression.args["format"]
- if isinstance(expression, exp.NumberToStr)
- else exp.Literal.string(
- format_time(
- expression.text("format"),
- t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING),
- )
- )
- )
+ fmt = expression.args["format"]
- # There is no format for "quarter"
- if fmt.name.lower() in DATEPART_ONLY_FORMATS:
- return self.func("DATEPART", fmt.name, expression.this)
+ if not isinstance(expression, exp.NumberToStr):
+ if fmt.is_string:
+ mapped_fmt = format_time(fmt.name, TSQL.INVERSE_TIME_MAPPING)
- return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))
+ name = (mapped_fmt or "").upper()
+ if name in DATEPART_ONLY_FORMATS:
+ return self.func("DATEPART", name, expression.this)
+
+ fmt_sql = self.sql(exp.Literal.string(mapped_fmt))
+ else:
+ fmt_sql = self.format_time(expression) or self.sql(fmt)
+ else:
+ fmt_sql = self.sql(fmt)
+
+ return self.func("FORMAT", expression.this, fmt_sql, expression.args.get("culture"))
def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
@@ -239,6 +240,30 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
return expression
+# https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax
+def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts:
+ return exp.TimestampFromParts(
+ year=seq_get(args, 0),
+ month=seq_get(args, 1),
+ day=seq_get(args, 2),
+ hour=seq_get(args, 3),
+ min=seq_get(args, 4),
+ sec=seq_get(args, 5),
+ milli=seq_get(args, 6),
+ )
+
+
+# https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax
+def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
+ return exp.TimeFromParts(
+ hour=seq_get(args, 0),
+ min=seq_get(args, 1),
+ sec=seq_get(args, 2),
+ fractions=seq_get(args, 3),
+ precision=seq_get(args, 4),
+ )
+
+
class TSQL(Dialect):
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
@@ -352,7 +377,7 @@ class TSQL(Dialect):
}
class Tokenizer(tokens.Tokenizer):
- IDENTIFIERS = ['"', ("[", "]")]
+ IDENTIFIERS = [("[", "]"), '"']
QUOTES = ["'", '"']
HEX_STRINGS = [("0x", ""), ("0X", "")]
VAR_SINGLE_TOKENS = {"@", "$", "#"}
@@ -362,6 +387,7 @@ class TSQL(Dialect):
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"DECLARE": TokenType.COMMAND,
+ "EXEC": TokenType.COMMAND,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"NTEXT": TokenType.TEXT,
@@ -397,6 +423,7 @@ class TSQL(Dialect):
"DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
+ "DATETIMEFROMPARTS": _parse_datetimefromparts,
"EOMONTH": _parse_eomonth,
"FORMAT": _parse_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
@@ -411,6 +438,7 @@ class TSQL(Dialect):
"SUSER_NAME": exp.CurrentUser.from_arg_list,
"SUSER_SNAME": exp.CurrentUser.from_arg_list,
"SYSTEM_USER": exp.CurrentUser.from_arg_list,
+ "TIMEFROMPARTS": _parse_timefromparts,
}
JOIN_HINTS = {
@@ -440,6 +468,7 @@ class TSQL(Dialect):
LOG_DEFAULTS_TO_LN = True
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
+ STRING_ALIASES = True
def _parse_projections(self) -> t.List[exp.Expression]:
"""
@@ -630,8 +659,10 @@ class TSQL(Dialect):
COMPUTED_COLUMN_WITH_TYPE = False
CTE_RECURSIVE_KEYWORD_REQUIRED = False
ENSURE_BOOLS = True
- NULL_ORDERING_SUPPORTED = False
+ NULL_ORDERING_SUPPORTED = None
SUPPORTS_SINGLE_ARG_CONCAT = False
+ TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+ SUPPORTS_SELECT_INTO = True
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
@@ -667,13 +698,16 @@ class TSQL(Dialect):
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
+ exp.GetPath: path_to_jsonpath("JSON_VALUE"),
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
+ exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
exp.Length: rename_func("LEN"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
exp.NumberToStr: _format_sql,
+ exp.ParseJSON: lambda self, e: self.sql(e, "this"),
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
@@ -689,9 +723,9 @@ class TSQL(Dialect):
exp.TemporaryProperty: lambda self, e: "",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: _format_sql,
+ exp.Trim: trim_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"),
}
TRANSFORMS.pop(exp.ReturnsProperty)
@@ -701,6 +735,46 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def lateral_op(self, expression: exp.Lateral) -> str:
+ cross_apply = expression.args.get("cross_apply")
+ if cross_apply is True:
+ return "CROSS APPLY"
+ if cross_apply is False:
+ return "OUTER APPLY"
+
+ # TODO: perhaps we can check if the parent is a Join and transpile it appropriately
+ self.unsupported("LATERAL clause is not supported.")
+ return "LATERAL"
+
+ def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
+ nano = expression.args.get("nano")
+ if nano is not None:
+ nano.pop()
+ self.unsupported("Specifying nanoseconds is not supported in TIMEFROMPARTS.")
+
+ if expression.args.get("fractions") is None:
+ expression.set("fractions", exp.Literal.number(0))
+ if expression.args.get("precision") is None:
+ expression.set("precision", exp.Literal.number(0))
+
+ return rename_func("TIMEFROMPARTS")(self, expression)
+
+ def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
+ zone = expression.args.get("zone")
+ if zone is not None:
+ zone.pop()
+ self.unsupported("Time zone is not supported in DATETIMEFROMPARTS.")
+
+ nano = expression.args.get("nano")
+ if nano is not None:
+ nano.pop()
+ self.unsupported("Specifying nanoseconds is not supported in DATETIMEFROMPARTS.")
+
+ if expression.args.get("milli") is None:
+ expression.set("milli", exp.Literal.number(0))
+
+ return rename_func("DATETIMEFROMPARTS")(self, expression)
+
def set_operation(self, expression: exp.Union, op: str) -> str:
limit = expression.args.get("limit")
if limit:
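A rough sketch of the EOMONTH mapping above (the SQL strings here are purely illustrative; the real parser builds exp.LastDay and exp.DateAdd nodes): with no month lag the argument maps straight to LAST_DAY, otherwise the date is first shifted by that many months.

def eomonth_to_last_day(date_sql, month_lag=None):
    if month_lag is None:
        return f"LAST_DAY({date_sql})"
    return f"LAST_DAY(DATE_ADD({date_sql}, {month_lag}, MONTH))"

assert eomonth_to_last_day("d") == "LAST_DAY(d)"
assert eomonth_to_last_day("d", 1) == "LAST_DAY(DATE_ADD(d, 1, MONTH))"
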
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index b79a551..6c01edc 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -132,11 +132,10 @@ def ordered(this, desc, nulls_first):
@null_if_any
def interval(this, unit):
- unit = unit.lower()
- plural = unit + "s"
+ plural = unit + "S"
if plural in Generator.TIME_PART_SINGULARS:
unit = plural
- return datetime.timedelta(**{unit: float(this)})
+ return datetime.timedelta(**{unit.lower(): float(this)})
@null_if_any("this", "expression")
@@ -176,6 +175,7 @@ ENV = {
"DOT": null_if_any(lambda e, this: e[this]),
"EQ": null_if_any(lambda this, e: this == e),
"EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
+ "GETPATH": null_if_any(lambda this, e: this.get(e)),
"GT": null_if_any(lambda this, e: this > e),
"GTE": null_if_any(lambda this, e: this >= e),
"IF": lambda predicate, true, false: true if predicate else false,
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index ea2255d..ddad8f8 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -16,6 +16,7 @@ import datetime
import math
import numbers
import re
+import textwrap
import typing as t
from collections import deque
from copy import deepcopy
@@ -35,6 +36,8 @@ from sqlglot.helper import (
from sqlglot.tokens import Token
if t.TYPE_CHECKING:
+ from typing_extensions import Literal as Lit
+
from sqlglot.dialects.dialect import DialectType
@@ -242,6 +245,9 @@ class Expression(metaclass=_Expression):
def is_type(self, *dtypes) -> bool:
return self.type is not None and self.type.is_type(*dtypes)
+ def is_leaf(self) -> bool:
+ return not any(isinstance(v, (Expression, list)) for v in self.args.values())
+
@property
def meta(self) -> t.Dict[str, t.Any]:
if self._meta is None:
@@ -497,7 +503,14 @@ class Expression(metaclass=_Expression):
return self.sql()
def __repr__(self) -> str:
- return self._to_s()
+ return _to_s(self)
+
+ def to_s(self) -> str:
+ """
+ Same as __repr__, but includes additional information which can be useful
+ for debugging, like empty or missing args and the AST nodes' object IDs.
+ """
+ return _to_s(self, verbose=True)
def sql(self, dialect: DialectType = None, **opts) -> str:
"""
@@ -514,30 +527,6 @@ class Expression(metaclass=_Expression):
return Dialect.get_or_raise(dialect).generate(self, **opts)
- def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
- indent = "" if not level else "\n"
- indent += "".join([" "] * level)
- left = f"({self.key.upper()} "
-
- args: t.Dict[str, t.Any] = {
- k: ", ".join(
- v._to_s(hide_missing=hide_missing, level=level + 1)
- if hasattr(v, "_to_s")
- else str(v)
- for v in ensure_list(vs)
- if v is not None
- )
- for k, vs in self.args.items()
- }
- args["comments"] = self.comments
- args["type"] = self.type
- args = {k: v for k, v in args.items() if v or not hide_missing}
-
- right = ", ".join(f"{k}: {v}" for k, v in args.items())
- right += ")"
-
- return indent + left + right
-
def transform(self, fun, *args, copy=True, **kwargs):
"""
Recursively visits all tree nodes (excluding already transformed ones)
@@ -580,8 +569,9 @@ class Expression(metaclass=_Expression):
For example::
>>> tree = Select().select("x").from_("tbl")
- >>> tree.find(Column).replace(Column(this="y"))
- (COLUMN this: y)
+ >>> tree.find(Column).replace(column("y"))
+ Column(
+ this=Identifier(this=y, quoted=False))
>>> tree.sql()
'SELECT y FROM tbl'
@@ -831,6 +821,9 @@ class Expression(metaclass=_Expression):
div.args["safe"] = safe
return div
+ def desc(self, nulls_first: bool = False) -> Ordered:
+ return Ordered(this=self.copy(), desc=True, nulls_first=nulls_first)
+
def __lt__(self, other: t.Any) -> LT:
return self._binop(LT, other)
@@ -1109,7 +1102,7 @@ class Clone(Expression):
class Describe(Expression):
- arg_types = {"this": True, "kind": False, "expressions": False}
+ arg_types = {"this": True, "extended": False, "kind": False, "expressions": False}
class Kill(Expression):
@@ -1124,6 +1117,10 @@ class Set(Expression):
arg_types = {"expressions": False, "unset": False, "tag": False}
+class Heredoc(Expression):
+ arg_types = {"this": True, "tag": False}
+
+
class SetItem(Expression):
arg_types = {
"this": False,
@@ -1937,7 +1934,13 @@ class Join(Expression):
class Lateral(UDTF):
- arg_types = {"this": True, "view": False, "outer": False, "alias": False}
+ arg_types = {
+ "this": True,
+ "view": False,
+ "outer": False,
+ "alias": False,
+ "cross_apply": False, # True -> CROSS APPLY, False -> OUTER APPLY
+ }
class MatchRecognize(Expression):
@@ -1964,7 +1967,12 @@ class Offset(Expression):
class Order(Expression):
- arg_types = {"this": False, "expressions": True, "interpolate": False}
+ arg_types = {
+ "this": False,
+ "expressions": True,
+ "interpolate": False,
+ "siblings": False,
+ }
# https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier
@@ -2002,6 +2010,11 @@ class AutoIncrementProperty(Property):
arg_types = {"this": True}
+# https://docs.aws.amazon.com/prescriptive-guidance/latest/materialized-views-redshift/refreshing-materialized-views.html
+class AutoRefreshProperty(Property):
+ arg_types = {"this": True}
+
+
class BlockCompressionProperty(Property):
arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True}
@@ -2259,6 +2272,10 @@ class SortKeyProperty(Property):
arg_types = {"this": True, "compound": False}
+class SqlReadWriteProperty(Property):
+ arg_types = {"this": True}
+
+
class SqlSecurityProperty(Property):
arg_types = {"definer": True}
@@ -2543,7 +2560,6 @@ class Table(Expression):
"version": False,
"format": False,
"pattern": False,
- "index": False,
"ordinality": False,
"when": False,
}
@@ -2585,6 +2601,14 @@ class Table(Expression):
return parts
+ def to_column(self, copy: bool = True) -> Alias | Column | Dot:
+ parts = self.parts
+ col = column(*reversed(parts[0:4]), fields=parts[4:], copy=copy) # type: ignore
+ alias = self.args.get("alias")
+ if alias:
+ col = alias_(col, alias.this, copy=copy)
+ return col
+
class Union(Subqueryable):
arg_types = {
@@ -2694,6 +2718,14 @@ class Unnest(UDTF):
"offset": False,
}
+ @property
+ def selects(self) -> t.List[Expression]:
+ columns = super().selects
+ offset = self.args.get("offset")
+ if offset:
+ columns = columns + [to_identifier("offset") if offset is True else offset]
+ return columns
+
class Update(Expression):
arg_types = {
@@ -3368,7 +3400,7 @@ class Select(Subqueryable):
return Create(
this=table_expression,
- kind="table",
+ kind="TABLE",
expression=instance,
properties=properties_expression,
)
@@ -3488,7 +3520,6 @@ class TableSample(Expression):
"rows": False,
"size": False,
"seed": False,
- "kind": False,
}
@@ -3517,6 +3548,10 @@ class Pivot(Expression):
"include_nulls": False,
}
+ @property
+ def unpivot(self) -> bool:
+ return bool(self.args.get("unpivot"))
+
class Window(Condition):
arg_types = {
@@ -3604,6 +3639,7 @@ class DataType(Expression):
BOOLEAN = auto()
CHAR = auto()
DATE = auto()
+ DATE32 = auto()
DATEMULTIRANGE = auto()
DATERANGE = auto()
DATETIME = auto()
@@ -3631,6 +3667,8 @@ class DataType(Expression):
INTERVAL = auto()
IPADDRESS = auto()
IPPREFIX = auto()
+ IPV4 = auto()
+ IPV6 = auto()
JSON = auto()
JSONB = auto()
LONGBLOB = auto()
@@ -3729,6 +3767,7 @@ class DataType(Expression):
Type.TIMESTAMP_MS,
Type.TIMESTAMP_NS,
Type.DATE,
+ Type.DATE32,
Type.DATETIME,
Type.DATETIME64,
}
@@ -4100,6 +4139,12 @@ class Alias(Expression):
return self.alias
+# BigQuery requires the UNPIVOT column list aliases to be either strings or ints, but
+# other dialects require identifiers. This enables us to transpile between them easily.
+class PivotAlias(Alias):
+ pass
+
+
class Aliases(Expression):
arg_types = {"this": True, "expressions": True}
@@ -4108,6 +4153,11 @@ class Aliases(Expression):
return self.expressions
+# https://docs.aws.amazon.com/redshift/latest/dg/query-super.html
+class AtIndex(Expression):
+ arg_types = {"this": True, "expression": True}
+
+
class AtTimeZone(Expression):
arg_types = {"this": True, "zone": True}
@@ -4154,16 +4204,16 @@ class TimeUnit(Expression):
arg_types = {"unit": False}
UNABBREVIATED_UNIT_NAME = {
- "d": "day",
- "h": "hour",
- "m": "minute",
- "ms": "millisecond",
- "ns": "nanosecond",
- "q": "quarter",
- "s": "second",
- "us": "microsecond",
- "w": "week",
- "y": "year",
+ "D": "DAY",
+ "H": "HOUR",
+ "M": "MINUTE",
+ "MS": "MILLISECOND",
+ "NS": "NANOSECOND",
+ "Q": "QUARTER",
+ "S": "SECOND",
+ "US": "MICROSECOND",
+ "W": "WEEK",
+ "Y": "YEAR",
}
VAR_LIKE = (Column, Literal, Var)
@@ -4171,9 +4221,11 @@ class TimeUnit(Expression):
def __init__(self, **args):
unit = args.get("unit")
if isinstance(unit, self.VAR_LIKE):
- args["unit"] = Var(this=self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name)
+ args["unit"] = Var(
+ this=(self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name).upper()
+ )
elif isinstance(unit, Week):
- unit.set("this", Var(this=unit.this.name))
+ unit.set("this", Var(this=unit.this.name.upper()))
super().__init__(**args)
@@ -4301,6 +4353,20 @@ class Anonymous(Func):
is_var_len_args = True
+class AnonymousAggFunc(AggFunc):
+ arg_types = {"this": True, "expressions": False}
+ is_var_len_args = True
+
+
+# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/combinators
+class CombinedAggFunc(AnonymousAggFunc):
+ arg_types = {"this": True, "expressions": False, "parts": True}
+
+
+class CombinedParameterizedAgg(ParameterizedAgg):
+ arg_types = {"this": True, "expressions": True, "params": True, "parts": True}
+
+
# https://docs.snowflake.com/en/sql-reference/functions/hll
# https://docs.aws.amazon.com/redshift/latest/dg/r_HLL_function.html
class Hll(AggFunc):
@@ -4381,7 +4447,7 @@ class ArraySort(Func):
class ArraySum(Func):
- pass
+ arg_types = {"this": True, "expression": False}
class ArrayUnionAgg(AggFunc):
@@ -4498,7 +4564,7 @@ class Count(AggFunc):
class CountIf(AggFunc):
- pass
+ _sql_names = ["COUNT_IF", "COUNTIF"]
class CurrentDate(Func):
@@ -4537,6 +4603,17 @@ class DateDiff(Func, TimeUnit):
class DateTrunc(Func):
arg_types = {"unit": True, "this": True, "zone": False}
+ def __init__(self, **args):
+ unit = args.get("unit")
+ if isinstance(unit, TimeUnit.VAR_LIKE):
+ args["unit"] = Literal.string(
+ (TimeUnit.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name).upper()
+ )
+ elif isinstance(unit, Week):
+ unit.set("this", Literal.string(unit.this.name.upper()))
+
+ super().__init__(**args)
+
@property
def unit(self) -> Expression:
return self.args["unit"]
@@ -4582,8 +4659,9 @@ class MonthsBetween(Func):
arg_types = {"this": True, "expression": True, "roundoff": False}
-class LastDateOfMonth(Func):
- pass
+class LastDay(Func, TimeUnit):
+ _sql_names = ["LAST_DAY", "LAST_DAY_OF_MONTH"]
+ arg_types = {"this": True, "unit": False}
class Extract(Func):
@@ -4627,10 +4705,22 @@ class TimeTrunc(Func, TimeUnit):
class DateFromParts(Func):
- _sql_names = ["DATEFROMPARTS"]
+ _sql_names = ["DATE_FROM_PARTS", "DATEFROMPARTS"]
arg_types = {"year": True, "month": True, "day": True}
+class TimeFromParts(Func):
+ _sql_names = ["TIME_FROM_PARTS", "TIMEFROMPARTS"]
+ arg_types = {
+ "hour": True,
+ "min": True,
+ "sec": True,
+ "nano": False,
+ "fractions": False,
+ "precision": False,
+ }
+
+
class DateStrToDate(Func):
pass
@@ -4754,6 +4844,16 @@ class JSONObject(Func):
}
+class JSONObjectAgg(AggFunc):
+ arg_types = {
+ "expressions": False,
+ "null_handling": False,
+ "unique_keys": False,
+ "return_type": False,
+ "encoding": False,
+ }
+
+
# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_ARRAY.html
class JSONArray(Func):
arg_types = {
@@ -4841,6 +4941,15 @@ class ParseJSON(Func):
is_var_len_args = True
+# https://docs.snowflake.com/en/sql-reference/functions/get_path
+class GetPath(Func):
+ arg_types = {"this": True, "expression": True}
+
+ @property
+ def output_name(self) -> str:
+ return self.expression.output_name
+
+
class Least(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@@ -5026,7 +5135,7 @@ class RegexpReplace(Func):
arg_types = {
"this": True,
"expression": True,
- "replacement": True,
+ "replacement": False,
"position": False,
"occurrence": False,
"parameters": False,
@@ -5052,8 +5161,10 @@ class Repeat(Func):
arg_types = {"this": True, "times": True}
+# https://learn.microsoft.com/en-us/sql/t-sql/functions/round-transact-sql?view=sql-server-ver16
+# In T-SQL, a non-zero third argument means truncation instead of rounding
class Round(Func):
- arg_types = {"this": True, "decimals": False}
+ arg_types = {"this": True, "decimals": False, "truncate": False}
class RowNumber(Func):
@@ -5228,6 +5339,10 @@ class TsOrDsToDate(Func):
arg_types = {"this": True, "format": False}
+class TsOrDsToTime(Func):
+ pass
+
+
class TsOrDiToDi(Func):
pass
@@ -5236,6 +5351,11 @@ class Unhex(Func):
pass
+# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#unix_date
+class UnixDate(Func):
+ pass
+
+
class UnixToStr(Func):
arg_types = {"this": True, "format": False}
@@ -5245,10 +5365,16 @@ class UnixToStr(Func):
class UnixToTime(Func):
arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False}
- SECONDS = Literal.string("seconds")
- MILLIS = Literal.string("millis")
- MICROS = Literal.string("micros")
- NANOS = Literal.string("nanos")
+ SECONDS = Literal.number(0)
+ DECIS = Literal.number(1)
+ CENTIS = Literal.number(2)
+ MILLIS = Literal.number(3)
+ DECIMILLIS = Literal.number(4)
+ CENTIMILLIS = Literal.number(5)
+ MICROS = Literal.number(6)
+ DECIMICROS = Literal.number(7)
+ CENTIMICROS = Literal.number(8)
+ NANOS = Literal.number(9)
class UnixToTimeStr(Func):
@@ -5256,8 +5382,7 @@ class UnixToTimeStr(Func):
class TimestampFromParts(Func):
- """Constructs a timestamp given its constituent parts."""
-
+ _sql_names = ["TIMESTAMP_FROM_PARTS", "TIMESTAMPFROMPARTS"]
arg_types = {
"year": True,
"month": True,
@@ -5265,6 +5390,9 @@ class TimestampFromParts(Func):
"hour": True,
"min": True,
"sec": True,
+ "nano": False,
+ "zone": False,
+ "milli": False,
}
@@ -5358,9 +5486,9 @@ def maybe_parse(
Example:
>>> maybe_parse("1")
- (LITERAL this: 1, is_string: False)
+ Literal(this=1, is_string=False)
>>> maybe_parse(to_identifier("x"))
- (IDENTIFIER this: x, quoted: False)
+ Identifier(this=x, quoted=False)
Args:
sql_or_expression: the SQL code string or an expression
@@ -5407,6 +5535,39 @@ def maybe_copy(instance, copy=True):
return instance.copy() if copy and instance else instance
+def _to_s(node: t.Any, verbose: bool = False, level: int = 0) -> str:
+ """Generate a textual representation of an Expression tree"""
+ indent = "\n" + (" " * (level + 1))
+ delim = f",{indent}"
+
+ if isinstance(node, Expression):
+ args = {k: v for k, v in node.args.items() if (v is not None and v != []) or verbose}
+
+ if (node.type or verbose) and not isinstance(node, DataType):
+ args["_type"] = node.type
+ if node.comments or verbose:
+ args["_comments"] = node.comments
+
+ if verbose:
+ args["_id"] = id(node)
+
+ # Inline leaves for a more compact representation
+ if node.is_leaf():
+ indent = ""
+ delim = ", "
+
+ items = delim.join([f"{k}={_to_s(v, verbose, level + 1)}" for k, v in args.items()])
+ return f"{node.__class__.__name__}({indent}{items})"
+
+ if isinstance(node, list):
+ items = delim.join(_to_s(i, verbose, level + 1) for i in node)
+ items = f"{indent}{items}" if items else ""
+ return f"[{items}]"
+
+ # Indent multiline strings to match the current level
+ return indent.join(textwrap.dedent(str(node).strip("\n")).splitlines())
+
+
def _is_wrong_expression(expression, into):
return isinstance(expression, Expression) and not isinstance(expression, into)
@@ -5816,7 +5977,7 @@ def delete(
def insert(
expression: ExpOrStr,
into: ExpOrStr,
- columns: t.Optional[t.Sequence[ExpOrStr]] = None,
+ columns: t.Optional[t.Sequence[str | Identifier]] = None,
overwrite: t.Optional[bool] = None,
returning: t.Optional[ExpOrStr] = None,
dialect: DialectType = None,
@@ -5847,15 +6008,7 @@ def insert(
this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts)
if columns:
- this = _apply_list_builder(
- *columns,
- instance=Schema(this=this),
- arg="expressions",
- into=Identifier,
- copy=False,
- dialect=dialect,
- **opts,
- )
+ this = Schema(this=this, expressions=[to_identifier(c, copy=copy) for c in columns])
insert = Insert(this=this, expression=expr, overwrite=overwrite)
@@ -6073,7 +6226,7 @@ def to_interval(interval: str | Literal) -> Interval:
return Interval(
this=Literal.string(interval_parts.group(1)),
- unit=Var(this=interval_parts.group(2)),
+ unit=Var(this=interval_parts.group(2).upper()),
)
@@ -6219,13 +6372,44 @@ def subquery(
return Select().from_(expression, dialect=dialect, **opts)
+@t.overload
+def column(
+ col: str | Identifier,
+ table: t.Optional[str | Identifier] = None,
+ db: t.Optional[str | Identifier] = None,
+ catalog: t.Optional[str | Identifier] = None,
+ *,
+ fields: t.Collection[t.Union[str, Identifier]],
+ quoted: t.Optional[bool] = None,
+ copy: bool = True,
+) -> Dot:
+ pass
+
+
+@t.overload
def column(
col: str | Identifier,
table: t.Optional[str | Identifier] = None,
db: t.Optional[str | Identifier] = None,
catalog: t.Optional[str | Identifier] = None,
+ *,
+ fields: Lit[None] = None,
quoted: t.Optional[bool] = None,
+ copy: bool = True,
) -> Column:
+ pass
+
+
+def column(
+ col,
+ table=None,
+ db=None,
+ catalog=None,
+ *,
+ fields=None,
+ quoted=None,
+ copy=True,
+):
"""
Build a Column.
@@ -6234,18 +6418,24 @@ def column(
table: Table name.
db: Database name.
catalog: Catalog name.
+ fields: Additional fields using dots.
quoted: Whether to force quotes on the column's identifiers.
+ copy: Whether or not to copy identifiers if passed in.
Returns:
The new Column instance.
"""
- return Column(
- this=to_identifier(col, quoted=quoted),
- table=to_identifier(table, quoted=quoted),
- db=to_identifier(db, quoted=quoted),
- catalog=to_identifier(catalog, quoted=quoted),
+ this = Column(
+ this=to_identifier(col, quoted=quoted, copy=copy),
+ table=to_identifier(table, quoted=quoted, copy=copy),
+ db=to_identifier(db, quoted=quoted, copy=copy),
+ catalog=to_identifier(catalog, quoted=quoted, copy=copy),
)
+ if fields:
+ this = Dot.build((this, *(to_identifier(field, copy=copy) for field in fields)))
+ return this
+
def cast(expression: ExpOrStr, to: DATA_TYPE, **opts) -> Cast:
"""Cast an expression to a data type.
@@ -6333,10 +6523,10 @@ def var(name: t.Optional[ExpOrStr]) -> Var:
Example:
>>> repr(var('x'))
- '(VAR this: x)'
+ 'Var(this=x)'
>>> repr(var(column('x', table='y')))
- '(VAR this: x)'
+ 'Var(this=x)'
Args:
name: The name of the var or an expression who's name will become the var.
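For context on the UnixToTime change above: the scale constants are now plain powers of ten (seconds = 0, millis = 3, micros = 6, nanos = 9), so converting a raw epoch value back to seconds is a single division. A minimal illustration:

SECONDS, MILLIS, MICROS, NANOS = 0, 3, 6, 9

def epoch_to_seconds(value, scale=SECONDS):
    return value / 10**scale

assert epoch_to_seconds(1_700_000_000_000, MILLIS) == 1_700_000_000.0
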
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index b0e83d2..977185f 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -68,6 +68,7 @@ class Generator:
exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})",
exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})",
exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}",
+ exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}",
exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS",
exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
@@ -96,6 +97,7 @@ class Generator:
exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
+ exp.SqlReadWriteProperty: lambda self, e: e.name,
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.StabilityProperty: lambda self, e: e.name,
exp.TemporaryProperty: lambda self, e: f"TEMPORARY",
@@ -110,7 +112,8 @@ class Generator:
}
# Whether or not null ordering is supported in order by
- NULL_ORDERING_SUPPORTED = True
+ # True: Full Support, None: No support, False: No support in window specifications
+ NULL_ORDERING_SUPPORTED: t.Optional[bool] = True
# Whether or not locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported
LOCKING_READS_SUPPORTED = False
@@ -133,12 +136,6 @@ class Generator:
# Whether or not the plural form of date parts like day (i.e. "days") is supported in INTERVALs
INTERVAL_ALLOWS_PLURAL_FORM = True
- # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
- TABLESAMPLE_WITH_METHOD = True
-
- # Whether or not to treat the number in TABLESAMPLE (50) as a percentage
- TABLESAMPLE_SIZE_IS_PERCENT = False
-
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
@@ -219,6 +216,18 @@ class Generator:
# Whether or not parentheses are required around the table sample's expression
TABLESAMPLE_REQUIRES_PARENS = True
+ # Whether or not a table sample clause's size needs to be followed by the ROWS keyword
+ TABLESAMPLE_SIZE_IS_ROWS = True
+
+ # The keyword(s) to use when generating a sample clause
+ TABLESAMPLE_KEYWORDS = "TABLESAMPLE"
+
+ # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
+ TABLESAMPLE_WITH_METHOD = True
+
+ # The keyword to use when specifying the seed of a sample clause
+ TABLESAMPLE_SEED_KEYWORD = "SEED"
+
# Whether or not COLLATE is a function instead of a binary operator
COLLATE_IS_FUNC = False
@@ -234,6 +243,27 @@ class Generator:
# Whether or not CONCAT requires >1 arguments
SUPPORTS_SINGLE_ARG_CONCAT = True
+ # Whether or not LAST_DAY function supports a date part argument
+ LAST_DAY_SUPPORTS_DATE_PART = True
+
+ # Whether or not named columns are allowed in table aliases
+ SUPPORTS_TABLE_ALIAS_COLUMNS = True
+
+ # Whether or not UNPIVOT aliases are Identifiers (False means they're Literals)
+ UNPIVOT_ALIASES_ARE_IDENTIFIERS = True
+
+ # What delimiter to use for separating JSON key/value pairs
+ JSON_KEY_VALUE_PAIR_SEP = ":"
+
+ # INSERT OVERWRITE TABLE x override
+ INSERT_OVERWRITE = " OVERWRITE TABLE"
+
+ # Whether or not the SELECT .. INTO syntax is used instead of CTAS
+ SUPPORTS_SELECT_INTO = False
+
+ # Whether or not UNLOGGED tables can be created
+ SUPPORTS_UNLOGGED_TABLES = False
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -252,15 +282,15 @@ class Generator:
}
TIME_PART_SINGULARS = {
- "microseconds": "microsecond",
- "seconds": "second",
- "minutes": "minute",
- "hours": "hour",
- "days": "day",
- "weeks": "week",
- "months": "month",
- "quarters": "quarter",
- "years": "year",
+ "MICROSECONDS": "MICROSECOND",
+ "SECONDS": "SECOND",
+ "MINUTES": "MINUTE",
+ "HOURS": "HOUR",
+ "DAYS": "DAY",
+ "WEEKS": "WEEK",
+ "MONTHS": "MONTH",
+ "QUARTERS": "QUARTER",
+ "YEARS": "YEAR",
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
@@ -272,6 +302,7 @@ class Generator:
PROPERTIES_LOCATION = {
exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE,
exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA,
exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME,
exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA,
exp.ChecksumProperty: exp.Properties.Location.POST_NAME,
@@ -323,6 +354,7 @@ class Generator:
exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA,
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
@@ -370,7 +402,7 @@ class Generator:
# Expressions that need to have all CTEs under them bubbled up to them
EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set()
- KEY_VALUE_DEFINITONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice)
+ KEY_VALUE_DEFINITIONS = (exp.Bracket, exp.EQ, exp.PropertyEQ, exp.Slice)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
@@ -775,7 +807,7 @@ class Generator:
return self.sql(expression, "this")
def create_sql(self, expression: exp.Create) -> str:
- kind = self.sql(expression, "kind").upper()
+ kind = self.sql(expression, "kind")
properties = expression.args.get("properties")
properties_locs = self.locate_properties(properties) if properties else defaultdict()
@@ -868,7 +900,12 @@ class Generator:
return f"{shallow}{keyword} {this}"
def describe_sql(self, expression: exp.Describe) -> str:
- return f"DESCRIBE {self.sql(expression, 'this')}"
+ extended = " EXTENDED" if expression.args.get("extended") else ""
+ return f"DESCRIBE{extended} {self.sql(expression, 'this')}"
+
+ def heredoc_sql(self, expression: exp.Heredoc) -> str:
+ tag = self.sql(expression, "tag")
+ return f"${tag}${self.sql(expression, 'this')}${tag}$"
def prepend_ctes(self, expression: exp.Expression, sql: str) -> str:
with_ = self.sql(expression, "with")
@@ -895,6 +932,10 @@ class Generator:
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
+ if columns and not self.SUPPORTS_TABLE_ALIAS_COLUMNS:
+ columns = ""
+ self.unsupported("Named columns are not supported in table alias.")
+
if not alias and not self.dialect.UNNEST_COLUMN_ONLY:
alias = "_t"
@@ -1027,7 +1068,7 @@ class Generator:
def fetch_sql(self, expression: exp.Fetch) -> str:
direction = expression.args.get("direction")
- direction = f" {direction.upper()}" if direction else ""
+ direction = f" {direction}" if direction else ""
count = expression.args.get("count")
count = f" {count}" if count else ""
if expression.args.get("percent"):
@@ -1318,7 +1359,7 @@ class Generator:
if isinstance(expression.this, exp.Directory):
this = " OVERWRITE" if overwrite else " INTO"
else:
- this = " OVERWRITE TABLE" if overwrite else " INTO"
+ this = self.INSERT_OVERWRITE if overwrite else " INTO"
alternative = expression.args.get("alternative")
alternative = f" OR {alternative}" if alternative else ""
@@ -1365,10 +1406,10 @@ class Generator:
return f"KILL{kind}{this}"
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
- return expression.name.upper()
+ return expression.name
def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str:
- return expression.name.upper()
+ return expression.name
def onconflict_sql(self, expression: exp.OnConflict) -> str:
conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
@@ -1445,9 +1486,6 @@ class Generator:
pattern = f", PATTERN => {pattern}" if pattern else ""
file_format = f" (FILE_FORMAT => {file_format}{pattern})"
- index = self.sql(expression, "index")
- index = f" AT {index}" if index else ""
-
ordinality = expression.args.get("ordinality") or ""
if ordinality:
ordinality = f" WITH ORDINALITY{alias}"
@@ -1457,10 +1495,13 @@ class Generator:
if when:
table = f"{table} {when}"
- return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}"
+ return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}"
def tablesample_sql(
- self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
+ self,
+ expression: exp.TableSample,
+ sep: str = " AS ",
+ tablesample_keyword: t.Optional[str] = None,
) -> str:
if self.dialect.ALIAS_POST_TABLESAMPLE and expression.this and expression.this.alias:
table = expression.this.copy()
@@ -1472,30 +1513,30 @@ class Generator:
alias = ""
method = self.sql(expression, "method")
- method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else ""
+ method = f"{method} " if method and self.TABLESAMPLE_WITH_METHOD else ""
numerator = self.sql(expression, "bucket_numerator")
denominator = self.sql(expression, "bucket_denominator")
field = self.sql(expression, "bucket_field")
field = f" ON {field}" if field else ""
bucket = f"BUCKET {numerator} OUT OF {denominator}{field}" if numerator else ""
- percent = self.sql(expression, "percent")
- percent = f"{percent} PERCENT" if percent else ""
- rows = self.sql(expression, "rows")
- rows = f"{rows} ROWS" if rows else ""
+ seed = self.sql(expression, "seed")
+ seed = f" {self.TABLESAMPLE_SEED_KEYWORD} ({seed})" if seed else ""
size = self.sql(expression, "size")
- if size and self.TABLESAMPLE_SIZE_IS_PERCENT:
- size = f"{size} PERCENT"
+ if size and self.TABLESAMPLE_SIZE_IS_ROWS:
+ size = f"{size} ROWS"
- seed = self.sql(expression, "seed")
- seed = f" {seed_prefix} ({seed})" if seed else ""
- kind = expression.args.get("kind", "TABLESAMPLE")
+ percent = self.sql(expression, "percent")
+ if percent and not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT:
+ percent = f"{percent} PERCENT"
- expr = f"{bucket}{percent}{rows}{size}"
+ expr = f"{bucket}{percent}{size}"
if self.TABLESAMPLE_REQUIRES_PARENS:
expr = f"({expr})"
- return f"{this} {kind} {method}{expr}{seed}{alias}"
+ return (
+ f"{this} {tablesample_keyword or self.TABLESAMPLE_KEYWORDS} {method}{expr}{seed}{alias}"
+ )
def pivot_sql(self, expression: exp.Pivot) -> str:
expressions = self.expressions(expression, flat=True)
@@ -1513,8 +1554,7 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
- unpivot = expression.args.get("unpivot")
- direction = "UNPIVOT" if unpivot else "PIVOT"
+ direction = "UNPIVOT" if expression.unpivot else "PIVOT"
field = self.sql(expression, "field")
include_nulls = expression.args.get("include_nulls")
if include_nulls is not None:
@@ -1675,7 +1715,8 @@ class Generator:
if not on_sql and using:
on_sql = csv(*(self.sql(column) for column in using))
- this_sql = self.sql(expression, "this")
+ this = expression.this
+ this_sql = self.sql(this)
if on_sql:
on_sql = self.indent(on_sql, skip_first=True)
@@ -1685,6 +1726,9 @@ class Generator:
else:
on_sql = f"{space}ON {on_sql}"
elif not op_sql:
+ if isinstance(this, exp.Lateral) and this.args.get("cross_apply") is not None:
+ return f" {this_sql}"
+
return f", {this_sql}"
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
@@ -1695,6 +1739,19 @@ class Generator:
args = f"({args})" if len(args.split(",")) > 1 else args
return f"{args} {arrow_sep} {self.sql(expression, 'this')}"
+ def lateral_op(self, expression: exp.Lateral) -> str:
+ cross_apply = expression.args.get("cross_apply")
+
+ # https://www.mssqltips.com/sqlservertip/1958/sql-server-cross-apply-and-outer-apply/
+ if cross_apply is True:
+ op = "INNER JOIN "
+ elif cross_apply is False:
+ op = "LEFT JOIN "
+ else:
+ op = ""
+
+ return f"{op}LATERAL"
+
def lateral_sql(self, expression: exp.Lateral) -> str:
this = self.sql(expression, "this")
@@ -1708,7 +1765,7 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
- return f"LATERAL {this}{alias}"
+ return f"{self.lateral_op(expression)} {this}{alias}"
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
this = self.sql(expression, "this")
@@ -1805,7 +1862,8 @@ class Generator:
def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else this
- order = self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore
+ siblings = "SIBLINGS " if expression.args.get("siblings") else ""
+ order = self.op_expressions(f"{this}ORDER {siblings}BY", expression, flat=this or flat) # type: ignore
interpolated_values = [
f"{self.sql(named_expression, 'alias')} AS {self.sql(named_expression, 'this')}"
for named_expression in expression.args.get("interpolate") or []
@@ -1860,9 +1918,21 @@ class Generator:
# If the NULLS FIRST/LAST clause is unsupported, we add another sort key to simulate it
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
- null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
- this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
- nulls_sort_change = ""
+ window = expression.find_ancestor(exp.Window, exp.Select)
+ if isinstance(window, exp.Window) and window.args.get("spec"):
+ self.unsupported(
+ f"'{nulls_sort_change.strip()}' translation not supported in window functions"
+ )
+ nulls_sort_change = ""
+ elif self.NULL_ORDERING_SUPPORTED is None:
+ if expression.this.is_int:
+ self.unsupported(
+ f"'{nulls_sort_change.strip()}' translation not supported with positional ordering"
+ )
+ else:
+ null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else ""
+ this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}"
+ nulls_sort_change = ""
with_fill = self.sql(expression, "with_fill")
with_fill = f" {with_fill}" if with_fill else ""
@@ -1961,10 +2031,14 @@ class Generator:
return [locks, self.sql(expression, "sample")]
def select_sql(self, expression: exp.Select) -> str:
+ into = expression.args.get("into")
+ if not self.SUPPORTS_SELECT_INTO and into:
+ into.pop()
+
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
- kind = self.sql(expression, "kind").upper()
+ kind = self.sql(expression, "kind")
limit = expression.args.get("limit")
top = (
self.limit_sql(limit, top=True)
@@ -2005,7 +2079,19 @@ class Generator:
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
- return self.prepend_ctes(expression, sql)
+
+ sql = self.prepend_ctes(expression, sql)
+
+ if not self.SUPPORTS_SELECT_INTO and into:
+ if into.args.get("temporary"):
+ table_kind = " TEMPORARY"
+ elif self.SUPPORTS_UNLOGGED_TABLES and into.args.get("unlogged"):
+ table_kind = " UNLOGGED"
+ else:
+ table_kind = ""
+ sql = f"CREATE{table_kind} TABLE {self.sql(into.this)} AS {sql}"
+
+ return sql
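A minimal sketch of the SELECT INTO fallback above, transpiling into a target whose generator does not enable SUPPORTS_SELECT_INTO (DuckDB here, as an assumption; output shown is approximate):

import sqlglot

# The INTO clause is popped from the SELECT and the query is wrapped in a CTAS.
print(sqlglot.transpile("SELECT * INTO t2 FROM t1", read="tsql", write="duckdb")[0])
# e.g. CREATE TABLE t2 AS SELECT * FROM t1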
def schema_sql(self, expression: exp.Schema) -> str:
this = self.sql(expression, "this")
@@ -2266,29 +2352,35 @@ class Generator:
return f"{self.func('MATCH', *expression.expressions)} AGAINST({self.sql(expression, 'this')}{modifier})"
def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
- return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
+ return f"{self.sql(expression, 'this')}{self.JSON_KEY_VALUE_PAIR_SEP} {self.sql(expression, 'expression')}"
def formatjson_sql(self, expression: exp.FormatJson) -> str:
return f"{self.sql(expression, 'this')} FORMAT JSON"
- def jsonobject_sql(self, expression: exp.JSONObject) -> str:
+ def jsonobject_sql(self, expression: exp.JSONObject | exp.JSONObjectAgg) -> str:
null_handling = expression.args.get("null_handling")
null_handling = f" {null_handling}" if null_handling else ""
+
unique_keys = expression.args.get("unique_keys")
if unique_keys is not None:
unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS"
else:
unique_keys = ""
+
return_type = self.sql(expression, "return_type")
return_type = f" RETURNING {return_type}" if return_type else ""
encoding = self.sql(expression, "encoding")
encoding = f" ENCODING {encoding}" if encoding else ""
+
return self.func(
- "JSON_OBJECT",
+ "JSON_OBJECT" if isinstance(expression, exp.JSONObject) else "JSON_OBJECTAGG",
*expression.expressions,
suffix=f"{null_handling}{unique_keys}{return_type}{encoding})",
)
+ def jsonobjectagg_sql(self, expression: exp.JSONObjectAgg) -> str:
+ return self.jsonobject_sql(expression)
+
def jsonarray_sql(self, expression: exp.JSONArray) -> str:
null_handling = expression.args.get("null_handling")
null_handling = f" {null_handling}" if null_handling else ""
@@ -2385,7 +2477,7 @@ class Generator:
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
if not self.INTERVAL_ALLOWS_PLURAL_FORM:
- unit = self.TIME_PART_SINGULARS.get(unit.lower(), unit)
+ unit = self.TIME_PART_SINGULARS.get(unit, unit)
unit = f" {unit}" if unit else ""
if self.SINGLE_STRING_INTERVAL:
@@ -2436,9 +2528,25 @@ class Generator:
alias = f" AS {alias}" if alias else ""
return f"{self.sql(expression, 'this')}{alias}"
+ def pivotalias_sql(self, expression: exp.PivotAlias) -> str:
+ alias = expression.args["alias"]
+ identifier_alias = isinstance(alias, exp.Identifier)
+
+ if identifier_alias and not self.UNPIVOT_ALIASES_ARE_IDENTIFIERS:
+ alias.replace(exp.Literal.string(alias.output_name))
+ elif not identifier_alias and self.UNPIVOT_ALIASES_ARE_IDENTIFIERS:
+ alias.replace(exp.to_identifier(alias.output_name))
+
+ return self.alias_sql(expression)
+
def aliases_sql(self, expression: exp.Aliases) -> str:
return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"
+    def atindex_sql(self, expression: exp.AtIndex) -> str:
+ this = self.sql(expression, "this")
+ index = self.sql(expression, "expression")
+ return f"{this} AT {index}"
+
def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
this = self.sql(expression, "this")
zone = self.sql(expression, "zone")
@@ -2500,7 +2608,7 @@ class Generator:
return self.binary(expression, "COLLATE")
def command_sql(self, expression: exp.Command) -> str:
- return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
+ return f"{self.sql(expression, 'this')} {expression.text('expression').strip()}"
def comment_sql(self, expression: exp.Comment) -> str:
this = self.sql(expression, "this")
@@ -3102,6 +3210,47 @@ class Generator:
cond_for_null = arg.is_(exp.null())
return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.Array(expressions=[arg])))
+ def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str:
+ this = expression.this
+ if isinstance(this, exp.TsOrDsToTime) or this.is_type(exp.DataType.Type.TIME):
+ return self.sql(this)
+
+ return self.sql(exp.cast(this, "time"))
+
+ def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str:
+ this = expression.this
+ time_format = self.format_time(expression)
+
+ if time_format and time_format not in (self.dialect.TIME_FORMAT, self.dialect.DATE_FORMAT):
+ return self.sql(
+ exp.cast(exp.StrToTime(this=this, format=expression.args["format"]), "date")
+ )
+
+ if isinstance(this, exp.TsOrDsToDate) or this.is_type(exp.DataType.Type.DATE):
+ return self.sql(this)
+
+ return self.sql(exp.cast(this, "date"))
+
+ def unixdate_sql(self, expression: exp.UnixDate) -> str:
+ return self.sql(
+ exp.func(
+ "DATEDIFF",
+ expression.this,
+ exp.cast(exp.Literal.string("1970-01-01"), "date"),
+ "day",
+ )
+ )
+
+ def lastday_sql(self, expression: exp.LastDay) -> str:
+ if self.LAST_DAY_SUPPORTS_DATE_PART:
+ return self.function_fallback_sql(expression)
+
+ unit = expression.text("unit")
+ if unit and unit != "MONTH":
+ self.unsupported("Date parts are not supported in LAST_DAY.")
+
+ return self.func("LAST_DAY", expression.this)
+
def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index abcc10f..09bf201 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -129,13 +129,10 @@ def lineage(
if isinstance(column, int)
else next(
(select for select in scope.expression.selects if select.alias_or_name == column),
- exp.Star() if scope.expression.is_star else None,
+ exp.Star() if scope.expression.is_star else scope.expression,
)
)
- if not select:
- raise ValueError(f"Could not find {column} in {scope.expression}")
-
if isinstance(scope.expression, exp.Union):
upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)
@@ -194,6 +191,8 @@ def lineage(
# if the select is a star add all scope sources as downstreams
if select.is_star:
for source in scope.sources.values():
+ if isinstance(source, Scope):
+ source = source.expression
node.downstream.append(Node(name=select.sql(), source=source, expression=source))
# Find all columns that went into creating this one to list their lineage nodes.
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 7b990f1..d0168d5 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -195,6 +195,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.StrPosition,
exp.TsOrDiToDi,
},
+ exp.DataType.Type.JSON: {
+ exp.ParseJSON,
+ },
exp.DataType.Type.TIMESTAMP: {
exp.CurrentTime,
exp.CurrentTimestamp,
@@ -275,6 +278,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
+ exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
}
NESTED_TYPES = {
@@ -477,7 +481,12 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
@t.no_type_check
def _annotate_by_args(
- self, expression: E, *args: str, promote: bool = False, array: bool = False
+ self,
+ expression: E,
+ *args: str,
+ promote: bool = False,
+ array: bool = False,
+ struct: bool = False,
) -> E:
self._annotate_args(expression)
@@ -506,6 +515,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
),
)
+ if struct:
+ expressions = [
+ expr.type
+ if not expr.args.get("alias")
+ else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
+ for expr in expressions
+ ]
+
+ self._set_type(
+ expression,
+ exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
+ )
+
return expression
def _annotate_timeunit(
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 10ff13a..12c3b89 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -30,13 +30,18 @@ def pushdown_predicates(expression, dialect=None):
where = select.args.get("where")
if where:
selected_sources = scope.selected_sources
+ join_index = {
+ join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or [])
+ }
+
# a right join can only push down to itself and not the source FROM table
for k, (node, source) in selected_sources.items():
parent = node.find_ancestor(exp.Join, exp.From)
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
- pushdown(where.this, selected_sources, scope_ref_count, dialect)
+
+ pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index)
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
@@ -53,7 +58,7 @@ def pushdown_predicates(expression, dialect=None):
return expression
-def pushdown(condition, sources, scope_ref_count, dialect):
+def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
if not condition:
return
@@ -67,21 +72,28 @@ def pushdown(condition, sources, scope_ref_count, dialect):
)
if cnf_like:
- pushdown_cnf(predicates, sources, scope_ref_count)
+ pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index)
else:
pushdown_dnf(predicates, sources, scope_ref_count)
-def pushdown_cnf(predicates, scope, scope_ref_count):
+def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
"""
If the predicates are in CNF like form, we can simply replace each block in the parent.
"""
+ join_index = join_index or {}
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
if isinstance(node, exp.Join):
- predicate.replace(exp.true())
- node.on(predicate, copy=False)
- break
+ name = node.alias_or_name
+ predicate_tables = exp.column_table_names(predicate, name)
+
+ # Don't push the predicate if it references tables that appear in later joins
+ this_index = join_index[name]
+ if all(join_index.get(table, -1) < this_index for table in predicate_tables):
+ predicate.replace(exp.true())
+ node.on(predicate, copy=False)
+ break
if isinstance(node, exp.Select):
predicate.replace(exp.true())
inner_predicate = replace_aliases(node, predicate)
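A short illustration of the join-order guard above, using hypothetical tables a, b and c:

# SELECT * FROM a JOIN b ON b.id = a.id JOIN c ON c.id = a.id WHERE b.x = c.y
#
# join_index is {"b": 0, "c": 1}.  For the JOIN b node the predicate b.x = c.y
# references c, whose join index (1) is not below b's own position (0), so it is
# left in the WHERE clause -- c is not yet in scope at b's ON condition.  A
# predicate such as b.x = a.x only references the FROM source (index -1) and is
# folded into b's ON clause as before.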
@@ -112,9 +124,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
conditions = {}
- # for every pushdown table, find all related conditions in all predicates
- # combine them with ORS
- # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
+ # pushdown all predicates to their respective nodes
for table in sorted(pushdown_tables):
for predicate in predicates:
nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
@@ -122,23 +132,9 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
if table not in nodes:
continue
- predicate_condition = None
-
- for column in predicate.find_all(exp.Column):
- if column.table == table:
- condition = column.find_ancestor(exp.Condition)
- predicate_condition = (
- exp.and_(predicate_condition, condition)
- if predicate_condition
- else condition
- )
-
- if predicate_condition:
- conditions[table] = (
- exp.or_(conditions[table], predicate_condition)
- if table in conditions
- else predicate_condition
- )
+ conditions[table] = (
+ exp.or_(conditions[table], predicate) if table in conditions else predicate
+ )
for name, node in nodes.items():
if name not in conditions:
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 4bc3bd2..e3aaebc 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -43,9 +43,8 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
parent_selections = referenced_columns.get(scope, {SELECT_ALL})
alias_count = source_column_alias_count.get(scope, 0)
- if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
- # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
- # we select from a pivoted source in the parent scope.
+    # We can't remove columns from SELECT DISTINCT or UNION DISTINCT queries.
+ if scope.expression.args.get("distinct"):
parent_selections = {SELECT_ALL}
if isinstance(scope.expression, exp.Union):
@@ -78,7 +77,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
# Push the selected columns down to the next scope
for name, (node, source) in scope.selected_sources.items():
if isinstance(source, Scope):
- columns = selects.get(name) or set()
+ columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
referenced_columns[source].update(columns)
column_aliases = node.alias_column_names
diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py
index 5fdbde8..8d83b47 100644
--- a/sqlglot/optimizer/qualify.py
+++ b/sqlglot/optimizer/qualify.py
@@ -3,10 +3,11 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
-from sqlglot.dialects.dialect import DialectType
+from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import (
+ pushdown_cte_alias_columns as pushdown_cte_alias_columns_func,
qualify_columns as qualify_columns_func,
quote_identifiers as quote_identifiers_func,
validate_qualify_columns as validate_qualify_columns_func,
@@ -22,6 +23,7 @@ def qualify(
catalog: t.Optional[str] = None,
schema: t.Optional[dict | Schema] = None,
expand_alias_refs: bool = True,
+ expand_stars: bool = True,
infer_schema: t.Optional[bool] = None,
isolate_tables: bool = False,
qualify_columns: bool = True,
@@ -47,6 +49,9 @@ def qualify(
catalog: Default catalog name for tables.
schema: Schema to infer column names and types.
expand_alias_refs: Whether or not to expand references to aliases.
+ expand_stars: Whether or not to expand star queries. This is a necessary step
+ for most of the optimizer's rules to work; do not set to False unless you
+ know what you're doing!
infer_schema: Whether or not to infer the schema if missing.
isolate_tables: Whether or not to isolate table selects.
qualify_columns: Whether or not to qualify columns.
@@ -66,9 +71,16 @@ def qualify(
if isolate_tables:
expression = isolate_table_selects(expression, schema=schema)
+ if Dialect.get_or_raise(dialect).PREFER_CTE_ALIAS_COLUMN:
+ expression = pushdown_cte_alias_columns_func(expression)
+
if qualify_columns:
expression = qualify_columns_func(
- expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema
+ expression,
+ schema,
+ expand_alias_refs=expand_alias_refs,
+ expand_stars=expand_stars,
+ infer_schema=infer_schema,
)
if quote_identifiers:
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 742cdf5..a6397ae 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -17,6 +17,7 @@ def qualify_columns(
expression: exp.Expression,
schema: t.Dict | Schema,
expand_alias_refs: bool = True,
+ expand_stars: bool = True,
infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
"""
@@ -33,10 +34,16 @@ def qualify_columns(
expression: Expression to qualify.
schema: Database schema.
expand_alias_refs: Whether or not to expand references to aliases.
+ expand_stars: Whether or not to expand star queries. This is a necessary step
+ for most of the optimizer's rules to work; do not set to False unless you
+ know what you're doing!
infer_schema: Whether or not to infer the schema if missing.
Returns:
The qualified expression.
+
+ Notes:
+ - Currently only handles a single PIVOT or UNPIVOT operator
"""
schema = ensure_schema(schema)
infer_schema = schema.empty if infer_schema is None else infer_schema
@@ -57,7 +64,8 @@ def qualify_columns(
_expand_alias_refs(scope, resolver)
if not isinstance(scope.expression, exp.UDTF):
- _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
+ if expand_stars:
+ _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
qualify_outputs(scope)
_expand_group_by(scope)
@@ -68,21 +76,41 @@ def qualify_columns(
def validate_qualify_columns(expression: E) -> E:
"""Raise an `OptimizeError` if any columns aren't qualified"""
- unqualified_columns = []
+ all_unqualified_columns = []
for scope in traverse_scope(expression):
if isinstance(scope.expression, exp.Select):
- unqualified_columns.extend(scope.unqualified_columns)
+ unqualified_columns = scope.unqualified_columns
+
if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
column = scope.external_columns[0]
- raise OptimizeError(
- f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
- )
+ for_table = f" for table: '{column.table}'" if column.table else ""
+ raise OptimizeError(f"Column '{column}' could not be resolved{for_table}")
+
+ if unqualified_columns and scope.pivots and scope.pivots[0].unpivot:
+ # New columns produced by the UNPIVOT can't be qualified, but there may be columns
+ # under the UNPIVOT's IN clause that can and should be qualified. We recompute
+ # this list here to ensure those in the former category will be excluded.
+ unpivot_columns = set(_unpivot_columns(scope.pivots[0]))
+ unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns]
+
+ all_unqualified_columns.extend(unqualified_columns)
+
+ if all_unqualified_columns:
+ raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}")
- if unqualified_columns:
- raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
return expression
+def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]:
+ name_column = []
+ field = unpivot.args.get("field")
+ if isinstance(field, exp.In) and isinstance(field.this, exp.Column):
+ name_column.append(field.this)
+
+ value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column))
+ return itertools.chain(name_column, value_columns)
+
+
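For orientation, what _unpivot_columns yields for a typical (hypothetical) UNPIVOT:

# ... UNPIVOT(sales FOR quarter IN (q1, q2, q3, q4)) ...
#
#   field.this          -> quarter   (the generated "name" column)
#   unpivot.expressions -> [sales]   (the generated "value" column(s))
#
# These are the columns created by the UNPIVOT itself, so they are dropped from
# the unqualified-column check, while q1..q4 under IN (...) must still resolve
# against the underlying source.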
def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
"""
Remove table column aliases.
@@ -216,6 +244,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
replace_columns(expression.args.get("group"), literal_index=True)
replace_columns(expression.args.get("having"), resolve_table=True)
replace_columns(expression.args.get("qualify"), resolve_table=True)
+
scope.clear_cache()
@@ -353,18 +382,25 @@ def _expand_stars(
replace_columns: t.Dict[int, t.Dict[str, str]] = {}
coalesced_columns = set()
- # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
- pivot_columns = None
pivot_output_columns = None
- pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
+ pivot_exclude_columns = None
- has_pivoted_source = pivot and not pivot.args.get("unpivot")
- if pivot and has_pivoted_source:
- pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column))
+ pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
+ if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names:
+ if pivot.unpivot:
+ pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)]
+
+ field = pivot.args.get("field")
+ if isinstance(field, exp.In):
+ pivot_exclude_columns = {
+ c.output_name for e in field.expressions for c in e.find_all(exp.Column)
+ }
+ else:
+ pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column))
- pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
- if not pivot_output_columns:
- pivot_output_columns = [col.alias_or_name for col in pivot.expressions]
+ pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])]
+ if not pivot_output_columns:
+ pivot_output_columns = [c.alias_or_name for c in pivot.expressions]
for expression in scope.expression.selects:
if isinstance(expression, exp.Star):
@@ -384,47 +420,54 @@ def _expand_stars(
raise OptimizeError(f"Unknown table: {table}")
columns = resolver.get_source_columns(table, only_visible=True)
+ columns = columns or scope.outer_column_list
if pseudocolumns:
columns = [name for name in columns if name.upper() not in pseudocolumns]
- if columns and "*" not in columns:
- table_id = id(table)
- columns_to_exclude = except_columns.get(table_id) or set()
+ if not columns or "*" in columns:
+ return
+
+ table_id = id(table)
+ columns_to_exclude = except_columns.get(table_id) or set()
- if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
- implicit_columns = [col for col in columns if col not in pivot_columns]
+ if pivot:
+ if pivot_output_columns and pivot_exclude_columns:
+ pivot_columns = [c for c in columns if c not in pivot_exclude_columns]
+ pivot_columns.extend(pivot_output_columns)
+ else:
+ pivot_columns = pivot.alias_column_names
+
+ if pivot_columns:
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
- for name in implicit_columns + pivot_output_columns
+ for name in pivot_columns
if name not in columns_to_exclude
)
continue
- for name in columns:
- if name in using_column_tables and table in using_column_tables[name]:
- if name in coalesced_columns:
- continue
-
- coalesced_columns.add(name)
- tables = using_column_tables[name]
- coalesce = [exp.column(name, table=table) for table in tables]
-
- new_selections.append(
- alias(
- exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
- alias=name,
- copy=False,
- )
- )
- elif name not in columns_to_exclude:
- alias_ = replace_columns.get(table_id, {}).get(name, name)
- column = exp.column(name, table=table)
- new_selections.append(
- alias(column, alias_, copy=False) if alias_ != name else column
+ for name in columns:
+ if name in using_column_tables and table in using_column_tables[name]:
+ if name in coalesced_columns:
+ continue
+
+ coalesced_columns.add(name)
+ tables = using_column_tables[name]
+ coalesce = [exp.column(name, table=table) for table in tables]
+
+ new_selections.append(
+ alias(
+ exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
+ alias=name,
+ copy=False,
)
- else:
- return
+ )
+ elif name not in columns_to_exclude:
+ alias_ = replace_columns.get(table_id, {}).get(name, name)
+ column = exp.column(name, table=table)
+ new_selections.append(
+ alias(column, alias_, copy=False) if alias_ != name else column
+ )
# Ensures we don't overwrite the initial selections with an empty list
if new_selections:
@@ -472,6 +515,9 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
):
+ if selection is None:
+ break
+
if isinstance(selection, exp.Subquery):
if not selection.output_name:
selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
@@ -495,6 +541,38 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
)
+def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
+ """
+    Pushes down the CTE alias columns into the projection.
+
+ This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
+ >>> pushdown_cte_alias_columns(expression).sql()
+ 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'
+
+ Args:
+ expression: Expression to pushdown.
+
+ Returns:
+ The expression with the CTE aliases pushed down into the projection.
+ """
+ for cte in expression.find_all(exp.CTE):
+ if cte.alias_column_names:
+ new_expressions = []
+ for _alias, projection in zip(cte.alias_column_names, cte.this.expressions):
+ if isinstance(projection, exp.Alias):
+ projection.set("alias", _alias)
+ else:
+ projection = alias(projection, alias=_alias)
+ new_expressions.append(projection)
+ cte.this.set("expressions", new_expressions)
+
+ return expression
+
+
class Resolver:
"""
Helper for resolving columns.
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 57ecabe..e0fe641 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -72,11 +72,15 @@ def qualify_tables(
if not source.args.get("catalog") and source.args.get("db"):
source.set("catalog", catalog)
+        pivots = source.args.get("pivots")
if not source.alias:
+ # Don't add the pivot's alias to the pivoted table, use the table's name instead
+ if pivots and pivots[0].alias == name:
+ name = source.name
+
# Mutates the source by attaching an alias to it
alias(source, name or source.name or next_alias_name(), copy=False, table=True)
- pivots = source.args.get("pivots")
if pivots and not pivots[0].alias:
pivots[0].set(
"alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index d34857d..a3f08d5 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -539,11 +539,23 @@ def _traverse_union(scope):
# The last scope to be yield should be the top most scope
left = None
- for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
+ for left in _traverse_scope(
+ scope.branch(
+ scope.expression.left,
+ outer_column_list=scope.outer_column_list,
+ scope_type=ScopeType.UNION,
+ )
+ ):
yield left
right = None
- for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
+ for right in _traverse_scope(
+ scope.branch(
+ scope.expression.right,
+ outer_column_list=scope.outer_column_list,
+ scope_type=ScopeType.UNION,
+ )
+ ):
yield right
scope.union_scopes = [left, right]
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index f53023c..25d4e75 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -100,6 +100,7 @@ def simplify(
node = simplify_parens(node)
node = simplify_datetrunc(node, dialect)
node = sort_comparison(node)
+ node = simplify_startswith(node)
if root:
expression.replace(node)
@@ -776,6 +777,26 @@ def simplify_conditionals(expression):
return expression
+def simplify_startswith(expression: exp.Expression) -> exp.Expression:
+ """
+ Reduces a prefix check to either TRUE or FALSE if both the string and the
+ prefix are statically known.
+
+ Example:
+ >>> from sqlglot import parse_one
+ >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
+ 'TRUE'
+ """
+ if (
+ isinstance(expression, exp.StartsWith)
+ and expression.this.is_string
+ and expression.expression.is_string
+ ):
+ return exp.convert(expression.name.startswith(expression.expression.name))
+
+ return expression
+
+
DateRange = t.Tuple[datetime.date, datetime.date]
@@ -1160,7 +1181,7 @@ def gen(expression: t.Any) -> str:
GEN_MAP = {
exp.Add: lambda e: _binary(e, "+"),
exp.And: lambda e: _binary(e, "AND"),
- exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
+ exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}",
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 311c43d..790ee0d 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -12,6 +12,8 @@ from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import TrieResult, in_trie, new_trie
if t.TYPE_CHECKING:
+ from typing_extensions import Literal
+
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
@@ -193,6 +195,7 @@ class Parser(metaclass=_Parser):
TokenType.DATETIME,
TokenType.DATETIME64,
TokenType.DATE,
+ TokenType.DATE32,
TokenType.INT4RANGE,
TokenType.INT4MULTIRANGE,
TokenType.INT8RANGE,
@@ -232,6 +235,8 @@ class Parser(metaclass=_Parser):
TokenType.INET,
TokenType.IPADDRESS,
TokenType.IPPREFIX,
+ TokenType.IPV4,
+ TokenType.IPV6,
TokenType.UNKNOWN,
TokenType.NULL,
*ENUM_TYPE_TOKENS,
@@ -669,6 +674,7 @@ class Parser(metaclass=_Parser):
PROPERTY_PARSERS: t.Dict[str, t.Callable] = {
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
+ "AUTO": lambda self: self._parse_auto_property(),
"AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
"CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs),
@@ -680,6 +686,7 @@ class Parser(metaclass=_Parser):
exp.CollateProperty, **kwargs
),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
+ "CONTAINS": lambda self: self._parse_contains_property(),
"COPY": lambda self: self._parse_copy_property(),
"DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs),
"DEFINER": lambda self: self._parse_definer(),
@@ -710,6 +717,7 @@ class Parser(metaclass=_Parser):
"LOG": lambda self, **kwargs: self._parse_log(**kwargs),
"MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty),
"MERGEBLOCKRATIO": lambda self, **kwargs: self._parse_mergeblockratio(**kwargs),
+ "MODIFIES": lambda self: self._parse_modifies_property(),
"MULTISET": lambda self: self.expression(exp.SetProperty, multi=True),
"NO": lambda self: self._parse_no_property(),
"ON": lambda self: self._parse_on_property(),
@@ -721,6 +729,7 @@ class Parser(metaclass=_Parser):
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
"PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True),
"RANGE": lambda self: self._parse_dict_range(this="RANGE"),
+ "READS": lambda self: self._parse_reads_property(),
"REMOTE": lambda self: self._parse_remote_with_connection(),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
@@ -841,6 +850,7 @@ class Parser(metaclass=_Parser):
"DECODE": lambda self: self._parse_decode(),
"EXTRACT": lambda self: self._parse_extract(),
"JSON_OBJECT": lambda self: self._parse_json_object(),
+ "JSON_OBJECTAGG": lambda self: self._parse_json_object(agg=True),
"JSON_TABLE": lambda self: self._parse_json_table(),
"MATCH": lambda self: self._parse_match_against(),
"OPENJSON": lambda self: self._parse_open_json(),
@@ -925,6 +935,8 @@ class Parser(metaclass=_Parser):
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}
WINDOW_SIDES = {"FOLLOWING", "PRECEDING"}
+ JSON_KEY_VALUE_SEPARATOR_TOKENS = {TokenType.COLON, TokenType.COMMA, TokenType.IS}
+
FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT}
ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}
@@ -954,6 +966,9 @@ class Parser(metaclass=_Parser):
# Whether the TRIM function expects the characters to trim as its first argument
TRIM_PATTERN_FIRST = False
+    # Whether or not string aliases are supported, e.g. `SELECT COUNT(*) 'count'`
+ STRING_ALIASES = False
+
# Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand)
MODIFIERS_ATTACHED_TO_UNION = True
UNION_MODIFIERS = {"order", "limit", "offset"}
@@ -1193,7 +1208,9 @@ class Parser(metaclass=_Parser):
self._advance(index - self._index)
def _parse_command(self) -> exp.Command:
- return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
+ return self.expression(
+ exp.Command, this=self._prev.text.upper(), expression=self._parse_string()
+ )
def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:
start = self._prev
@@ -1353,26 +1370,27 @@ class Parser(metaclass=_Parser):
# exp.Properties.Location.POST_SCHEMA ("schema" here is the UDF's type signature)
extend_props(self._parse_properties())
- self._match(TokenType.ALIAS)
-
- if self._match(TokenType.COMMAND):
- expression = self._parse_as_command(self._prev)
- else:
- begin = self._match(TokenType.BEGIN)
- return_ = self._match_text_seq("RETURN")
+ expression = self._match(TokenType.ALIAS) and self._parse_heredoc()
- if self._match(TokenType.STRING, advance=False):
- # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property
- # # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement
- expression = self._parse_string()
- extend_props(self._parse_properties())
+ if not expression:
+ if self._match(TokenType.COMMAND):
+ expression = self._parse_as_command(self._prev)
else:
- expression = self._parse_statement()
+ begin = self._match(TokenType.BEGIN)
+ return_ = self._match_text_seq("RETURN")
+
+ if self._match(TokenType.STRING, advance=False):
+ # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property
+                    # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement
+ expression = self._parse_string()
+ extend_props(self._parse_properties())
+ else:
+ expression = self._parse_statement()
- end = self._match_text_seq("END")
+ end = self._match_text_seq("END")
- if return_:
- expression = self.expression(exp.Return, this=expression)
+ if return_:
+ expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index(index=self._parse_id_var())
elif create_token.token_type in self.DB_CREATABLES:
@@ -1426,7 +1444,7 @@ class Parser(metaclass=_Parser):
exp.Create,
comments=comments,
this=this,
- kind=create_token.text,
+ kind=create_token.text.upper(),
replace=replace,
unique=unique,
expression=expression,
@@ -1849,9 +1867,21 @@ class Parser(metaclass=_Parser):
return self.expression(exp.WithDataProperty, no=no, statistics=statistics)
- def _parse_no_property(self) -> t.Optional[exp.NoPrimaryIndexProperty]:
+ def _parse_contains_property(self) -> t.Optional[exp.SqlReadWriteProperty]:
+ if self._match_text_seq("SQL"):
+ return self.expression(exp.SqlReadWriteProperty, this="CONTAINS SQL")
+ return None
+
+ def _parse_modifies_property(self) -> t.Optional[exp.SqlReadWriteProperty]:
+ if self._match_text_seq("SQL", "DATA"):
+ return self.expression(exp.SqlReadWriteProperty, this="MODIFIES SQL DATA")
+ return None
+
+ def _parse_no_property(self) -> t.Optional[exp.Expression]:
if self._match_text_seq("PRIMARY", "INDEX"):
return exp.NoPrimaryIndexProperty()
+ if self._match_text_seq("SQL"):
+ return self.expression(exp.SqlReadWriteProperty, this="NO SQL")
return None
def _parse_on_property(self) -> t.Optional[exp.Expression]:
@@ -1861,6 +1891,11 @@ class Parser(metaclass=_Parser):
return exp.OnCommitProperty(delete=True)
return self.expression(exp.OnProperty, this=self._parse_schema(self._parse_id_var()))
+ def _parse_reads_property(self) -> t.Optional[exp.SqlReadWriteProperty]:
+ if self._match_text_seq("SQL", "DATA"):
+ return self.expression(exp.SqlReadWriteProperty, this="READS SQL DATA")
+ return None
+
def _parse_distkey(self) -> exp.DistKeyProperty:
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
@@ -1920,10 +1955,13 @@ class Parser(metaclass=_Parser):
def _parse_describe(self) -> exp.Describe:
kind = self._match_set(self.CREATABLES) and self._prev.text
+ extended = self._match_text_seq("EXTENDED")
this = self._parse_table(schema=True)
properties = self._parse_properties()
expressions = properties.expressions if properties else None
- return self.expression(exp.Describe, this=this, kind=kind, expressions=expressions)
+ return self.expression(
+ exp.Describe, this=this, extended=extended, kind=kind, expressions=expressions
+ )
def _parse_insert(self) -> exp.Insert:
comments = ensure_list(self._prev_comments)
@@ -2164,13 +2202,13 @@ class Parser(metaclass=_Parser):
def _parse_value(self) -> exp.Tuple:
if self._match(TokenType.L_PAREN):
- expressions = self._parse_csv(self._parse_conjunction)
+ expressions = self._parse_csv(self._parse_expression)
self._match_r_paren()
return self.expression(exp.Tuple, expressions=expressions)
# In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
# https://prestodb.io/docs/current/sql/values.html
- return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
+ return self.expression(exp.Tuple, expressions=[self._parse_expression()])
def _parse_projections(self) -> t.List[exp.Expression]:
return self._parse_expressions()
@@ -2212,7 +2250,7 @@ class Parser(metaclass=_Parser):
kind = (
self._match(TokenType.ALIAS)
and self._match_texts(("STRUCT", "VALUE"))
- and self._prev.text
+ and self._prev.text.upper()
)
if distinct:
@@ -2261,7 +2299,7 @@ class Parser(metaclass=_Parser):
if table
else self._parse_select(nested=True, parse_set_operation=False)
)
- this = self._parse_set_operations(self._parse_query_modifiers(this))
+ this = self._parse_query_modifiers(self._parse_set_operations(this))
self._match_r_paren()
@@ -2304,7 +2342,7 @@ class Parser(metaclass=_Parser):
)
def _parse_cte(self) -> exp.CTE:
- alias = self._parse_table_alias()
+ alias = self._parse_table_alias(self.ID_VAR_TOKENS)
if not alias or not alias.this:
self.raise_error("Expected CTE to have alias")
@@ -2490,13 +2528,14 @@ class Parser(metaclass=_Parser):
)
def _parse_lateral(self) -> t.Optional[exp.Lateral]:
- outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY)
+ if not cross_apply and self._match_pair(TokenType.OUTER, TokenType.APPLY):
+ cross_apply = False
- if outer_apply or cross_apply:
+ if cross_apply is not None:
this = self._parse_select(table=True)
view = None
- outer = not cross_apply
+ outer = None
elif self._match(TokenType.LATERAL):
this = self._parse_select(table=True)
view = self._match(TokenType.VIEW)
@@ -2529,7 +2568,14 @@ class Parser(metaclass=_Parser):
else:
table_alias = self._parse_table_alias()
- return self.expression(exp.Lateral, this=this, view=view, outer=outer, alias=table_alias)
+ return self.expression(
+ exp.Lateral,
+ this=this,
+ view=view,
+ outer=outer,
+ alias=table_alias,
+ cross_apply=cross_apply,
+ )
def _parse_join_parts(
self,
@@ -2563,9 +2609,6 @@ class Parser(metaclass=_Parser):
if not skip_join_token and not join and not outer_apply and not cross_apply:
return None
- if outer_apply:
- side = Token(TokenType.LEFT, "LEFT")
-
kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)}
if method:
@@ -2755,8 +2798,10 @@ class Parser(metaclass=_Parser):
if alias:
this.set("alias", alias)
- if self._match_text_seq("AT"):
- this.set("index", self._parse_id_var())
+ if isinstance(this, exp.Table) and self._match_text_seq("AT"):
+ return self.expression(
+ exp.AtIndex, this=this.to_column(copy=False), expression=self._parse_id_var()
+ )
this.set("hints", self._parse_table_hints())
@@ -2865,15 +2910,10 @@ class Parser(metaclass=_Parser):
bucket_denominator = None
bucket_field = None
percent = None
- rows = None
size = None
seed = None
- kind = (
- self._prev.text if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE"
- )
- method = self._parse_var(tokens=(TokenType.ROW,))
-
+ method = self._parse_var(tokens=(TokenType.ROW,), upper=True)
matched_l_paren = self._match(TokenType.L_PAREN)
if self.TABLESAMPLE_CSV:
@@ -2895,16 +2935,16 @@ class Parser(metaclass=_Parser):
bucket_field = self._parse_field()
elif self._match_set((TokenType.PERCENT, TokenType.MOD)):
percent = num
- elif self._match(TokenType.ROWS):
- rows = num
- elif num:
+ elif self._match(TokenType.ROWS) or not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT:
size = num
+ else:
+ percent = num
if matched_l_paren:
self._match_r_paren()
if self._match(TokenType.L_PAREN):
- method = self._parse_var()
+ method = self._parse_var(upper=True)
seed = self._match(TokenType.COMMA) and self._parse_number()
self._match_r_paren()
elif self._match_texts(("SEED", "REPEATABLE")):
@@ -2918,10 +2958,8 @@ class Parser(metaclass=_Parser):
bucket_denominator=bucket_denominator,
bucket_field=bucket_field,
percent=percent,
- rows=rows,
size=size,
seed=seed,
- kind=kind,
)
def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]:
@@ -2946,6 +2984,27 @@ class Parser(metaclass=_Parser):
exp.Pivot, this=this, expressions=expressions, using=using, group=group
)
+ def _parse_pivot_in(self) -> exp.In:
+ def _parse_aliased_expression() -> t.Optional[exp.Expression]:
+ this = self._parse_conjunction()
+
+ self._match(TokenType.ALIAS)
+ alias = self._parse_field()
+ if alias:
+ return self.expression(exp.PivotAlias, this=this, alias=alias)
+
+ return this
+
+ value = self._parse_column()
+
+ if not self._match_pair(TokenType.IN, TokenType.L_PAREN):
+ self.raise_error("Expecting IN (")
+
+ aliased_expressions = self._parse_csv(_parse_aliased_expression)
+
+ self._match_r_paren()
+ return self.expression(exp.In, this=value, expressions=aliased_expressions)
+
def _parse_pivot(self) -> t.Optional[exp.Pivot]:
index = self._index
include_nulls = None
@@ -2964,7 +3023,6 @@ class Parser(metaclass=_Parser):
return None
expressions = []
- field = None
if not self._match(TokenType.L_PAREN):
self._retreat(index)
@@ -2981,12 +3039,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.FOR):
self.raise_error("Expecting FOR")
- value = self._parse_column()
-
- if not self._match(TokenType.IN):
- self.raise_error("Expecting IN")
-
- field = self._parse_in(value, alias=True)
+ field = self._parse_pivot_in()
self._match_r_paren()
@@ -3132,14 +3185,19 @@ class Parser(metaclass=_Parser):
def _parse_order(
self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
) -> t.Optional[exp.Expression]:
+ siblings = None
if not skip_order_token and not self._match(TokenType.ORDER_BY):
- return this
+ if not self._match(TokenType.ORDER_SIBLINGS_BY):
+ return this
+
+ siblings = True
return self.expression(
exp.Order,
this=this,
expressions=self._parse_csv(self._parse_ordered),
interpolate=self._parse_interpolate(),
+ siblings=siblings,
)
def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]:
@@ -3213,7 +3271,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
- direction = self._prev.text if direction else "FIRST"
+ direction = self._prev.text.upper() if direction else "FIRST"
count = self._parse_field(tokens=self.FETCH_TOKENS)
percent = self._match(TokenType.PERCENT)
@@ -3398,10 +3456,10 @@ class Parser(metaclass=_Parser):
return this
return self.expression(exp.Escape, this=this, expression=self._parse_string())
- def _parse_interval(self) -> t.Optional[exp.Interval]:
+ def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Interval]:
index = self._index
- if not self._match(TokenType.INTERVAL):
+ if not self._match(TokenType.INTERVAL) and match_interval:
return None
if self._match(TokenType.STRING, advance=False):
@@ -3409,11 +3467,19 @@ class Parser(metaclass=_Parser):
else:
this = self._parse_term()
- if not this:
+ if not this or (
+ isinstance(this, exp.Column)
+ and not this.table
+ and not this.this.quoted
+ and this.name.upper() == "IS"
+ ):
self._retreat(index)
return None
- unit = self._parse_function() or self._parse_var(any_token=True)
+ unit = self._parse_function() or (
+ not self._match(TokenType.ALIAS, advance=False)
+ and self._parse_var(any_token=True, upper=True)
+ )
# Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse
# each INTERVAL expression into this canonical form so it's easy to transpile
@@ -3429,7 +3495,7 @@ class Parser(metaclass=_Parser):
self._retreat(self._index - 1)
this = exp.Literal.string(parts[0])
- unit = self.expression(exp.Var, this=parts[1])
+ unit = self.expression(exp.Var, this=parts[1].upper())
return self.expression(exp.Interval, this=this, unit=unit)
@@ -3489,6 +3555,12 @@ class Parser(metaclass=_Parser):
def _parse_type(self, parse_interval: bool = True) -> t.Optional[exp.Expression]:
interval = parse_interval and self._parse_interval()
if interval:
+ # Convert INTERVAL 'val_1' unit_1 ... 'val_n' unit_n into a sum of intervals
+ while self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False):
+ interval = self.expression( # type: ignore
+ exp.Add, this=interval, expression=self._parse_interval(match_interval=False)
+ )
+
return interval
index = self._index
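A minimal sketch of the canonicalisation above: consecutive value/unit pairs are folded into a sum of single-unit intervals (output shown is approximate):

import sqlglot

print(sqlglot.parse_one("SELECT INTERVAL '1' YEAR '2' MONTH").sql())
# e.g. SELECT INTERVAL '1' YEAR + INTERVAL '2' MONTH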
@@ -3552,10 +3624,10 @@ class Parser(metaclass=_Parser):
type_token = self._prev.token_type
if type_token == TokenType.PSEUDO_TYPE:
- return self.expression(exp.PseudoType, this=self._prev.text)
+ return self.expression(exp.PseudoType, this=self._prev.text.upper())
if type_token == TokenType.OBJECT_IDENTIFIER:
- return self.expression(exp.ObjectIdentifier, this=self._prev.text)
+ return self.expression(exp.ObjectIdentifier, this=self._prev.text.upper())
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token in self.STRUCT_TYPE_TOKENS
@@ -3587,7 +3659,7 @@ class Parser(metaclass=_Parser):
if nested and self._match(TokenType.LT):
if is_struct:
- expressions = self._parse_csv(self._parse_struct_types)
+ expressions = self._parse_csv(lambda: self._parse_struct_types(type_required=True))
else:
expressions = self._parse_csv(
lambda: self._parse_types(
@@ -3662,10 +3734,19 @@ class Parser(metaclass=_Parser):
return this
- def _parse_struct_types(self) -> t.Optional[exp.Expression]:
+ def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]:
+ index = self._index
this = self._parse_type(parse_interval=False) or self._parse_id_var()
self._match(TokenType.COLON)
- return self._parse_column_def(this)
+ column_def = self._parse_column_def(this)
+
+ if type_required and (
+ (isinstance(this, exp.Column) and this.this is column_def) or this is column_def
+ ):
+ self._retreat(index)
+ return self._parse_types()
+
+ return column_def
def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_text_seq("AT", "TIME", "ZONE"):
@@ -4025,6 +4106,12 @@ class Parser(metaclass=_Parser):
return exp.AutoIncrementColumnConstraint()
+ def _parse_auto_property(self) -> t.Optional[exp.AutoRefreshProperty]:
+ if not self._match_text_seq("REFRESH"):
+ self._retreat(self._index - 1)
+ return None
+ return self.expression(exp.AutoRefreshProperty, this=self._parse_var(upper=True))
+
def _parse_compress(self) -> exp.CompressColumnConstraint:
if self._match(TokenType.L_PAREN, advance=False):
return self.expression(
@@ -4230,8 +4317,10 @@ class Parser(metaclass=_Parser):
def _parse_primary_key_part(self) -> t.Optional[exp.Expression]:
return self._parse_field()
- def _parse_period_for_system_time(self) -> exp.PeriodForSystemTimeConstraint:
- self._match(TokenType.TIMESTAMP_SNAPSHOT)
+ def _parse_period_for_system_time(self) -> t.Optional[exp.PeriodForSystemTimeConstraint]:
+ if not self._match(TokenType.TIMESTAMP_SNAPSHOT):
+ self._retreat(self._index - 1)
+ return None
id_vars = self._parse_wrapped_id_vars()
return self.expression(
@@ -4257,22 +4346,17 @@ class Parser(metaclass=_Parser):
options = self._parse_key_constraint_options()
return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
+ def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
+ return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))
+
def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
return this
bracket_kind = self._prev.token_type
-
- if self._match(TokenType.COLON):
- expressions: t.List[exp.Expression] = [
- self.expression(exp.Slice, expression=self._parse_conjunction())
- ]
- else:
- expressions = self._parse_csv(
- lambda: self._parse_slice(
- self._parse_alias(self._parse_conjunction(), explicit=True)
- )
- )
+ expressions = self._parse_csv(
+ lambda: self._parse_bracket_key_value(is_map=bracket_kind == TokenType.L_BRACE)
+ )
if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
self.raise_error("Expected ]")
@@ -4313,7 +4397,10 @@ class Parser(metaclass=_Parser):
default = self._parse_conjunction()
if not self._match(TokenType.END):
- self.raise_error("Expected END after CASE", self._prev)
+ if isinstance(default, exp.Interval) and default.this.sql().upper() == "END":
+ default = exp.column("interval")
+ else:
+ self.raise_error("Expected END after CASE", self._prev)
return self._parse_window(
self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default)
@@ -4514,7 +4601,7 @@ class Parser(metaclass=_Parser):
def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]:
self._match_text_seq("KEY")
key = self._parse_column()
- self._match_set((TokenType.COLON, TokenType.COMMA))
+ self._match_set(self.JSON_KEY_VALUE_SEPARATOR_TOKENS)
self._match_text_seq("VALUE")
value = self._parse_bitwise()
@@ -4536,7 +4623,15 @@ class Parser(metaclass=_Parser):
return None
- def _parse_json_object(self) -> exp.JSONObject:
+ @t.overload
+ def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
+ ...
+
+ @t.overload
+ def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
+ ...
+
+ def _parse_json_object(self, agg=False):
star = self._parse_star()
expressions = (
[star]
@@ -4559,7 +4654,7 @@ class Parser(metaclass=_Parser):
encoding = self._match_text_seq("ENCODING") and self._parse_var()
return self.expression(
- exp.JSONObject,
+ exp.JSONObjectAgg if agg else exp.JSONObject,
expressions=expressions,
null_handling=null_handling,
unique_keys=unique_keys,
@@ -4873,10 +4968,17 @@ class Parser(metaclass=_Parser):
self._match_r_paren(aliases)
return aliases
- alias = self._parse_id_var(any_token)
+ alias = self._parse_id_var(any_token) or (
+ self.STRING_ALIASES and self._parse_string_as_identifier()
+ )
if alias:
- return self.expression(exp.Alias, comments=comments, this=this, alias=alias)
+ this = self.expression(exp.Alias, comments=comments, this=this, alias=alias)
+
+ # Moves the comment next to the alias in `expr /* comment */ AS alias`
+ if not this.comments and this.this.comments:
+ this.comments = this.this.comments
+ this.this.comments = None
return this
@@ -4915,14 +5017,19 @@ class Parser(metaclass=_Parser):
return self._parse_placeholder()
def _parse_var(
- self, any_token: bool = False, tokens: t.Optional[t.Collection[TokenType]] = None
+ self,
+ any_token: bool = False,
+ tokens: t.Optional[t.Collection[TokenType]] = None,
+ upper: bool = False,
) -> t.Optional[exp.Expression]:
if (
(any_token and self._advance_any())
or self._match(TokenType.VAR)
or (self._match_set(tokens) if tokens else False)
):
- return self.expression(exp.Var, this=self._prev.text)
+ return self.expression(
+ exp.Var, this=self._prev.text.upper() if upper else self._prev.text
+ )
return self._parse_placeholder()
def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]:
@@ -5418,6 +5525,42 @@ class Parser(metaclass=_Parser):
condition=condition,
)
+ def _parse_heredoc(self) -> t.Optional[exp.Heredoc]:
+ if self._match(TokenType.HEREDOC_STRING):
+ return self.expression(exp.Heredoc, this=self._prev.text)
+
+ if not self._match_text_seq("$"):
+ return None
+
+ tags = ["$"]
+ tag_text = None
+
+ if self._is_connected():
+ self._advance()
+ tags.append(self._prev.text.upper())
+ else:
+ self.raise_error("No closing $ found")
+
+ if tags[-1] != "$":
+ if self._is_connected() and self._match_text_seq("$"):
+ tag_text = tags[-1]
+ tags.append("$")
+ else:
+ self.raise_error("No closing $ found")
+
+ heredoc_start = self._curr
+
+ while self._curr:
+ if self._match_text_seq(*tags, advance=False):
+ this = self._find_sql(heredoc_start, self._prev)
+ self._advance(len(tags))
+ return self.expression(exp.Heredoc, this=this, tag=tag_text)
+
+ self._advance()
+
+ self.raise_error(f"No closing {''.join(tags)} found")
+ return None
+
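A sketch of the two dollar-quoted shapes this helper accepts; the runnable part assumes a dialect (Postgres here) whose tokenizer surfaces the body as a single HEREDOC_STRING token:

import sqlglot

# $$ ... $$         -> exp.Heredoc(this="...")                (untagged)
# $tag$ ... $tag$   -> exp.Heredoc(this="...", tag="TAG")     (tag text is uppercased)
expr = sqlglot.parse_one(
    "CREATE FUNCTION add(a INT, b INT) RETURNS INT AS $$ SELECT a + b $$",
    read="postgres",
)
print(expr.find(sqlglot.exp.Heredoc))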
def _find_parser(
self, parsers: t.Dict[str, t.Callable], trie: t.Dict
) -> t.Optional[t.Callable]:
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index 54c08dd..8acd89f 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -215,12 +215,13 @@ class MappingSchema(AbstractMappingSchema, Schema):
normalize: bool = True,
) -> None:
self.dialect = dialect
- self.visible = visible or {}
+ self.visible = {} if visible is None else visible
self.normalize = normalize
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
self._depth = 0
+ schema = {} if schema is None else schema
- super().__init__(self._normalize(schema or {}))
+ super().__init__(self._normalize(schema) if self.normalize else schema)
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index de9d4c4..d8fb98b 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -147,6 +147,7 @@ class TokenType(AutoName):
DATETIME = auto()
DATETIME64 = auto()
DATE = auto()
+ DATE32 = auto()
INT4RANGE = auto()
INT4MULTIRANGE = auto()
INT8RANGE = auto()
@@ -182,6 +183,8 @@ class TokenType(AutoName):
INET = auto()
IPADDRESS = auto()
IPPREFIX = auto()
+ IPV4 = auto()
+ IPV6 = auto()
ENUM = auto()
ENUM8 = auto()
ENUM16 = auto()
@@ -296,6 +299,7 @@ class TokenType(AutoName):
ON = auto()
OPERATOR = auto()
ORDER_BY = auto()
+ ORDER_SIBLINGS_BY = auto()
ORDERED = auto()
ORDINALITY = auto()
OUTER = auto()
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 03acc2b..0da65b5 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -255,7 +255,7 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp
if not arrays:
if expression.args.get("from"):
- expression.join(series, copy=False)
+ expression.join(series, copy=False, join_type="CROSS")
else:
expression.from_(series, copy=False)