summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-10-26 17:21:54 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-10-26 17:21:54 +0000
commitc03ba18c491e52cc85d8aae1825dd9e0b4f75e32 (patch)
treef76d58b50900be4bfd2dc15f0ec38d1a70d8417b /sqlglot
parentReleasing debian version 18.13.0-1. (diff)
downloadsqlglot-c03ba18c491e52cc85d8aae1825dd9e0b4f75e32.tar.xz
sqlglot-c03ba18c491e52cc85d8aae1825dd9e0b4f75e32.zip
Merging upstream version 18.17.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/__main__.py6
-rw-r--r--sqlglot/dataframe/sql/functions.py6
-rw-r--r--sqlglot/dialects/bigquery.py13
-rw-r--r--sqlglot/dialects/clickhouse.py31
-rw-r--r--sqlglot/dialects/dialect.py56
-rw-r--r--sqlglot/dialects/duckdb.py20
-rw-r--r--sqlglot/dialects/hive.py10
-rw-r--r--sqlglot/dialects/mysql.py13
-rw-r--r--sqlglot/dialects/postgres.py3
-rw-r--r--sqlglot/dialects/presto.py23
-rw-r--r--sqlglot/dialects/redshift.py3
-rw-r--r--sqlglot/dialects/snowflake.py37
-rw-r--r--sqlglot/dialects/teradata.py5
-rw-r--r--sqlglot/dialects/tsql.py11
-rw-r--r--sqlglot/expressions.py80
-rw-r--r--sqlglot/generator.py55
-rw-r--r--sqlglot/lineage.py36
-rw-r--r--sqlglot/optimizer/annotate_types.py25
-rw-r--r--sqlglot/optimizer/canonicalize.py6
-rw-r--r--sqlglot/optimizer/simplify.py35
-rw-r--r--sqlglot/parser.py66
-rw-r--r--sqlglot/time.py603
-rw-r--r--sqlglot/tokens.py4
-rw-r--r--sqlglot/transforms.py17
24 files changed, 1058 insertions, 106 deletions
diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py
index f3433d3..4a2820b 100644
--- a/sqlglot/__main__.py
+++ b/sqlglot/__main__.py
@@ -58,6 +58,12 @@ parser.add_argument(
default="IMMEDIATE",
help="IGNORE, WARN, RAISE, IMMEDIATE (default)",
)
+parser.add_argument(
+ "--version",
+ action="version",
+ version=sqlglot.__version__,
+ help="Display the SQLGlot version",
+)
args = parser.parse_args()
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index d98feee..a424ea4 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -84,11 +84,11 @@ def min(col: ColumnOrName) -> Column:
def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "MAX_BY", ord)
+ return Column.invoke_expression_over_column(col, expression.ArgMax, expression=ord)
def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "MIN_BY", ord)
+ return Column.invoke_expression_over_column(col, expression.ArgMin, expression=ord)
def count(col: ColumnOrName) -> Column:
@@ -1113,7 +1113,7 @@ def reverse(col: ColumnOrName) -> Column:
def flatten(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "FLATTEN")
+ return Column.invoke_expression_over_column(col, expression.Flatten)
def map_keys(col: ColumnOrName) -> Column:
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 7f69dd9..51baba2 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
+ arg_max_or_min_no_count,
binary_from_function,
date_add_interval_sql,
datestrtodate_sql,
@@ -434,8 +435,13 @@ class BigQuery(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
+ exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
+ exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
+ exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
+ if e.args.get("default")
+ else f"COLLATE {self.sql(e, 'this')}",
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
@@ -632,6 +638,13 @@ class BigQuery(Dialect):
"within",
}
+ def eq_sql(self, expression: exp.EQ) -> str:
+ # Operands of = cannot be NULL in BigQuery
+ if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null):
+ return "NULL"
+
+ return self.binary(expression, "=")
+
def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
parent = expression.parent
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index e9d9326..30f728c 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ arg_max_or_min_no_count,
inline_array_sql,
no_pivot_sql,
rename_func,
@@ -373,8 +374,11 @@ class ClickHouse(Dialect):
exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.AnyValue: rename_func("any"),
exp.ApproxDistinct: rename_func("uniq"),
+ exp.ArgMax: arg_max_or_min_no_count("argMax"),
+ exp.ArgMin: arg_max_or_min_no_count("argMin"),
exp.Array: inline_array_sql,
exp.CastToStrType: rename_func("CAST"),
+ exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
@@ -418,6 +422,33 @@ class ClickHouse(Dialect):
"NAMED COLLECTION",
}
+ def _any_to_has(
+ self,
+ expression: exp.EQ | exp.NEQ,
+ default: t.Callable[[t.Any], str],
+ prefix: str = "",
+ ) -> str:
+ if isinstance(expression.left, exp.Any):
+ arr = expression.left
+ this = expression.right
+ elif isinstance(expression.right, exp.Any):
+ arr = expression.right
+ this = expression.left
+ else:
+ return default(expression)
+ return prefix + self.func("has", arr.this.unnest(), this)
+
+ def eq_sql(self, expression: exp.EQ) -> str:
+ return self._any_to_has(expression, super().eq_sql)
+
+ def neq_sql(self, expression: exp.NEQ) -> str:
+ return self._any_to_has(expression, super().neq_sql, "NOT ")
+
+ def regexpilike_sql(self, expression: exp.RegexpILike) -> str:
+ # Manually add a flag to make the search case-insensitive
+ regex = self.func("CONCAT", "'(?i)'", expression.expression)
+ return f"match({self.format_args(expression.this, regex)})"
+
def datatype_sql(self, expression: exp.DataType) -> str:
# String is the standard ClickHouse type, every other variant is just an alias.
# Additionally, any supplied length parameter will be ignored.
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index bd839af..739e8d7 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -10,7 +10,7 @@ from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
-from sqlglot.time import format_time
+from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
@@ -595,6 +595,19 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
)
+def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
+ if not expression.expression:
+ return self.sql(exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP))
+ if expression.text("expression").lower() in TIMEZONES:
+ return self.sql(
+ exp.AtTimeZone(
+ this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
+ zone=expression.expression,
+ )
+ )
+ return self.function_fallback_sql(expression)
+
+
def locate_to_strposition(args: t.List) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
@@ -691,9 +704,13 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:
_dialect = Dialect.get_or_raise(dialect)
time_format = self.format_time(expression)
if time_format and time_format not in (_dialect.TIME_FORMAT, _dialect.DATE_FORMAT):
- return self.sql(exp.cast(str_to_time_sql(self, expression), "date"))
-
- return self.sql(exp.cast(self.sql(expression, "this"), "date"))
+ return self.sql(
+ exp.cast(
+ exp.StrToTime(this=expression.this, format=expression.args["format"]),
+ "date",
+ )
+ )
+ return self.sql(exp.cast(expression.this, "date"))
return _ts_or_ds_to_date_sql
@@ -725,7 +742,9 @@ def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
- bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
+ bad_args = list(
+ filter(expression.args.get, ("position", "occurrence", "parameters", "modifiers"))
+ )
if bad_args:
self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")
@@ -756,15 +775,6 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
return names
-def simplify_literal(expression: E) -> E:
- if not isinstance(expression.expression, exp.Literal):
- from sqlglot.optimizer.simplify import simplify
-
- simplify(expression.expression)
-
- return expression
-
-
def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -804,3 +814,21 @@ def move_insert_cte_sql(self: Generator, expression: exp.Insert) -> str:
expression = expression.copy()
expression.set("with", expression.expression.args["with"].pop())
return self.insert_sql(expression)
+
+
+def generatedasidentitycolumnconstraint_sql(
+ self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
+) -> str:
+ start = self.sql(expression, "start") or "1"
+ increment = self.sql(expression, "increment") or "1"
+ return f"IDENTITY({start}, {increment})"
+
+
+def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
+ def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
+ if expression.args.get("count"):
+ self.unsupported(f"Only two arguments are supported in function {name}.")
+
+ return self.func(name, expression.this, expression.expression)
+
+ return _arg_max_or_min_sql
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 5b94bcb..287e03a 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
+ arg_max_or_min_no_count,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
binary_from_function,
@@ -18,9 +19,9 @@ from sqlglot.dialects.dialect import (
no_comment_column_constraint_sql,
no_properties_sql,
no_safe_divide_sql,
+ no_timestamp_sql,
pivot_column_names,
regexp_extract_sql,
- regexp_replace_sql,
rename_func,
str_position_sql,
str_to_time_sql,
@@ -172,6 +173,12 @@ class DuckDB(Dialect):
this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
),
"REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
+ "REGEXP_REPLACE": lambda args: exp.RegexpReplace(
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
+ replacement=seq_get(args, 2),
+ modifiers=seq_get(args, 3),
+ ),
"STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
"STRING_SPLIT": exp.Split.from_arg_list,
"STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
@@ -243,6 +250,8 @@ class DuckDB(Dialect):
if e.expressions and e.expressions[0].find(exp.Select)
else inline_array_sql(self, e),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
+ exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
+ exp.ArgMin: arg_max_or_min_no_count("ARG_MIN"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
exp.BitwiseXor: rename_func("XOR"),
@@ -287,7 +296,13 @@ class DuckDB(Dialect):
exp.PercentileDisc: rename_func("QUANTILE_DISC"),
exp.Properties: no_properties_sql,
exp.RegexpExtract: regexp_extract_sql,
- exp.RegexpReplace: regexp_replace_sql,
+ exp.RegexpReplace: lambda self, e: self.func(
+ "REGEXP_REPLACE",
+ e.this,
+ e.expression,
+ e.args.get("replacement"),
+ e.args.get("modifiers"),
+ ),
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
exp.SafeDivide: no_safe_divide_sql,
@@ -298,6 +313,7 @@ class DuckDB(Dialect):
exp.StrToTime: str_to_time_sql,
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
+ exp.Timestamp: no_timestamp_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 3f925a7..7bff553 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
+ arg_max_or_min_no_count,
create_with_partitions_sql,
format_time_lambda,
if_sql,
@@ -106,11 +107,16 @@ def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str:
sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})"
return f"({sec_diff}){factor}" if factor else sec_diff
- sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
+ months_between = unit in DIFF_MONTH_SWITCH
+ sql_func = "MONTHS_BETWEEN" if months_between else "DATEDIFF"
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
multiplier_sql = f" / {multiplier}" if multiplier > 1 else ""
diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})"
+ if months_between:
+ # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part
+ diff_sql = f"CAST({diff_sql} AS INT)"
+
return f"{diff_sql}{multiplier_sql}"
@@ -426,6 +432,8 @@ class Hive(Dialect):
exp.Property: _property_sql,
exp.AnyValue: rename_func("FIRST"),
exp.ApproxDistinct: approx_count_distinct_sql,
+ exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
+ exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArraySize: rename_func("SIZE"),
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 59a0a2a..2185a85 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -21,7 +21,6 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
parse_date_delta_with_interval,
rename_func,
- simplify_literal,
strposition_to_locate_sql,
)
from sqlglot.helper import seq_get
@@ -689,6 +688,8 @@ class MySQL(Dialect):
LIMIT_FETCH = "LIMIT"
+ LIMIT_ONLY_LITERALS = True
+
# MySQL doesn't support many datatypes in cast.
# https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#function_cast
CAST_MAPPING = {
@@ -712,16 +713,6 @@ class MySQL(Dialect):
result = f"{result} UNSIGNED"
return result
- def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
- # MySQL requires simple literal values for its LIMIT clause.
- expression = simplify_literal(expression.copy())
- return super().limit_sql(expression, top=top)
-
- def offset_sql(self, expression: exp.Offset) -> str:
- # MySQL requires simple literal values for its OFFSET clause.
- expression = simplify_literal(expression.copy())
- return super().offset_sql(expression)
-
def xor_sql(self, expression: exp.Xor) -> str:
if expression.expressions:
return self.expressions(expression, sep=" XOR ")
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index c435309..086b278 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -20,7 +20,6 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
parse_timestamp_trunc,
rename_func,
- simplify_literal,
str_position_sql,
struct_extract_sql,
timestamptrunc_sql,
@@ -49,7 +48,7 @@ def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | ex
this = self.sql(expression, "this")
unit = expression.args.get("unit")
- expression = simplify_literal(expression).expression
+ expression = self._simplify_unless_literal(expression.expression)
if not isinstance(expression, exp.Literal):
self.unsupported("Cannot add non literal")
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 88525a2..aac368c 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import (
no_ilike_sql,
no_pivot_sql,
no_safe_divide_sql,
+ no_timestamp_sql,
regexp_extract_sql,
rename_func,
right_to_substring_sql,
@@ -69,9 +70,10 @@ def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str:
if expression.parent:
for schema in expression.parent.find_all(exp.Schema):
- if isinstance(schema.parent, exp.Property):
+ column_defs = schema.find_all(exp.ColumnDef)
+ if column_defs and isinstance(schema.parent, exp.Property):
expression = expression.copy()
- expression.expressions.extend(schema.expressions)
+ expression.expressions.extend(column_defs)
return self.schema_sql(expression)
@@ -252,6 +254,7 @@ class Presto(Dialect):
TZ_TO_WITH_TIME_ZONE = True
NVL2_SUPPORTED = False
STRUCT_DELIMITER = ("(", ")")
+ LIMIT_ONLY_LITERALS = True
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@@ -277,6 +280,8 @@ class Presto(Dialect):
exp.AnyValue: rename_func("ARBITRARY"),
exp.ApproxDistinct: _approx_distinct_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
+ exp.ArgMax: rename_func("MAX_BY"),
+ exp.ArgMin: rename_func("MIN_BY"),
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
@@ -348,6 +353,7 @@ class Presto(Dialect):
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
+ exp.Timestamp: no_timestamp_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
@@ -367,7 +373,6 @@ class Presto(Dialect):
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
- exp.Timestamp: transforms.preprocess([transforms.timestamp_to_cast]),
exp.Xor: bool_xor_sql,
}
@@ -418,3 +423,15 @@ class Presto(Dialect):
self.sql(expression, "offset"),
self.sql(limit),
]
+
+ def create_sql(self, expression: exp.Create) -> str:
+ """
+ Presto doesn't support CREATE VIEW with an expression list (e.g. in `CREATE VIEW x (cola)`,
+ the `(cola)` part is the expression list), so we need to remove it
+ """
+ kind = expression.args["kind"]
+ schema = expression.this
+ if kind == "VIEW" and schema.expressions:
+ expression = expression.copy()
+ expression.this.set("expressions", None)
+ return super().create_sql(expression)
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 04e78a5..df70aa7 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -6,6 +6,7 @@ from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
concat_to_dpipe_sql,
concat_ws_to_dpipe_sql,
+ generatedasidentitycolumnconstraint_sql,
rename_func,
ts_or_ds_to_date_sql,
)
@@ -171,8 +172,10 @@ class Redshift(Postgres):
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.FromBase: rename_func("STRTOL"),
+ exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
+ exp.ParseJSON: rename_func("JSON_PARSE"),
exp.SafeConcat: concat_to_dpipe_sql,
exp.Select: transforms.preprocess(
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index fc3e0fa..07be65b 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -262,6 +262,7 @@ class Snowflake(Dialect):
),
"DATEDIFF": _parse_datediff,
"DIV0": _div0_to_if,
+ "FLATTEN": exp.Explode.from_arg_list,
"IFF": exp.If.from_arg_list,
"LISTAGG": exp.GroupConcat.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
@@ -308,6 +309,7 @@ class Snowflake(Dialect):
expressions=self._parse_csv(self._parse_id_var),
unset=True,
),
+ "SWAP": lambda self: self._parse_alter_table_swap(),
}
STATEMENT_PARSERS = {
@@ -325,6 +327,22 @@ class Snowflake(Dialect):
TokenType.MOD,
TokenType.SLASH,
}
+ FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"]
+
+ def _parse_lateral(self) -> t.Optional[exp.Lateral]:
+ lateral = super()._parse_lateral()
+ if not lateral:
+ return lateral
+
+ if isinstance(lateral.this, exp.Explode):
+ table_alias = lateral.args.get("alias")
+ columns = [exp.to_identifier(col) for col in self.FLATTEN_COLUMNS]
+ if table_alias and not table_alias.args.get("columns"):
+ table_alias.set("columns", columns)
+ elif not table_alias:
+ exp.alias_(lateral, "_flattened", table=columns, copy=False)
+
+ return lateral
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
# https://docs.snowflake.com/en/user-guide/querying-stage
@@ -389,6 +407,10 @@ class Snowflake(Dialect):
return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind)
+ def _parse_alter_table_swap(self) -> exp.SwapTable:
+ self._match_text_seq("WITH")
+ return self.expression(exp.SwapTable, this=self._parse_table(schema=True))
+
class Tokenizer(tokens.Tokenizer):
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
@@ -438,6 +460,8 @@ class Snowflake(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.ArgMax: rename_func("MAX_BY"),
+ exp.ArgMin: rename_func("MIN_BY"),
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
@@ -451,7 +475,10 @@ class Snowflake(Dialect):
),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
+ exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.DayOfYear: rename_func("DAYOFYEAR"),
+ exp.Explode: rename_func("FLATTEN"),
exp.Extract: rename_func("DATE_PART"),
exp.GenerateSeries: lambda self, e: self.func(
"ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step")
@@ -520,6 +547,12 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def log_sql(self, expression: exp.Log) -> str:
+ if not expression.expression:
+ return self.func("LN", expression.this)
+
+ return super().log_sql(expression)
+
def unnest_sql(self, expression: exp.Unnest) -> str:
selects = ["value"]
unnest_alias = expression.args.get("alias")
@@ -596,3 +629,7 @@ class Snowflake(Dialect):
increment = expression.args.get("increment")
increment = f" INCREMENT {increment}" if increment else ""
return f"AUTOINCREMENT{start}{increment}"
+
+ def swaptable_sql(self, expression: exp.SwapTable) -> str:
+ this = self.sql(expression, "this")
+ return f"SWAP WITH {this}"
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index b9e925a..152afa6 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least
+from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least, rename_func
from sqlglot.tokens import TokenType
@@ -150,6 +150,7 @@ class Teradata(Dialect):
return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
class Generator(generator.Generator):
+ LIMIT_IS_TOP = True
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
@@ -168,6 +169,8 @@ class Teradata(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.ArgMax: rename_func("MAX_BY"),
+ exp.ArgMin: rename_func("MIN_BY"),
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Select: transforms.preprocess(
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 69adb45..867e4e4 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
any_value_to_max_sql,
+ generatedasidentitycolumnconstraint_sql,
max_or_greatest,
min_or_least,
move_insert_cte_sql,
@@ -603,6 +604,7 @@ class TSQL(Dialect):
exp.DataType.Type.DATETIME: "DATETIME2",
exp.DataType.Type.DOUBLE: "FLOAT",
exp.DataType.Type.INT: "INTEGER",
+ exp.DataType.Type.TEXT: "VARCHAR(MAX)",
exp.DataType.Type.TIMESTAMP: "DATETIME2",
exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET",
exp.DataType.Type.VARIANT: "SQL_VARIANT",
@@ -617,6 +619,7 @@ class TSQL(Dialect):
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.Extract: rename_func("DATEPART"),
+ exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.Insert: move_insert_cte_sql,
@@ -778,11 +781,3 @@ class TSQL(Dialect):
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True, sep=" ")
return f"CONSTRAINT {this} {expressions}"
-
- # https://learn.microsoft.com/en-us/answers/questions/448821/create-table-in-sql-server
- def generatedasidentitycolumnconstraint_sql(
- self, expression: exp.GeneratedAsIdentityColumnConstraint
- ) -> str:
- start = self.sql(expression, "start") or "1"
- increment = self.sql(expression, "increment") or "1"
- return f"IDENTITY({start}, {increment})"
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index b94b1e1..5b012b1 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -23,7 +23,7 @@ from enum import auto
from functools import reduce
from sqlglot._typing import E
-from sqlglot.errors import ParseError
+from sqlglot.errors import ErrorLevel, ParseError
from sqlglot.helper import (
AutoName,
camel_to_snake_case,
@@ -120,14 +120,14 @@ class Expression(metaclass=_Expression):
return hash((self.__class__, self.hashable_args))
@property
- def this(self):
+ def this(self) -> t.Any:
"""
Retrieves the argument with key "this".
"""
return self.args.get("this")
@property
- def expression(self):
+ def expression(self) -> t.Any:
"""
Retrieves the argument with key "expression".
"""
@@ -1235,6 +1235,10 @@ class RenameTable(Expression):
pass
+class SwapTable(Expression):
+ pass
+
+
class Comment(Expression):
arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
@@ -1979,7 +1983,7 @@ class ChecksumProperty(Property):
class CollateProperty(Property):
- arg_types = {"this": True}
+ arg_types = {"this": True, "default": False}
class CopyGrantsProperty(Property):
@@ -2607,11 +2611,11 @@ class Union(Subqueryable):
return self.this.unnest().selects
@property
- def left(self):
+ def left(self) -> Expression:
return self.this
@property
- def right(self):
+ def right(self) -> Expression:
return self.expression
@@ -3700,7 +3704,9 @@ class DataType(Expression):
return DataType(this=DataType.Type.UNKNOWN, **kwargs)
try:
- data_type_exp = parse_one(dtype, read=dialect, into=DataType)
+ data_type_exp = parse_one(
+ dtype, read=dialect, into=DataType, error_level=ErrorLevel.IGNORE
+ )
except ParseError:
if udt:
return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs)
@@ -3804,11 +3810,11 @@ class Binary(Condition):
arg_types = {"this": True, "expression": True}
@property
- def left(self):
+ def left(self) -> Expression:
return self.this
@property
- def right(self):
+ def right(self) -> Expression:
return self.expression
@@ -4063,10 +4069,25 @@ class TimeUnit(Expression):
arg_types = {"unit": False}
+ UNABBREVIATED_UNIT_NAME = {
+ "d": "day",
+ "h": "hour",
+ "m": "minute",
+ "ms": "millisecond",
+ "ns": "nanosecond",
+ "q": "quarter",
+ "s": "second",
+ "us": "microsecond",
+ "w": "week",
+ "y": "year",
+ }
+
+ VAR_LIKE = (Column, Literal, Var)
+
def __init__(self, **args):
unit = args.get("unit")
- if isinstance(unit, (Column, Literal)):
- args["unit"] = Var(this=unit.name)
+ if isinstance(unit, self.VAR_LIKE):
+ args["unit"] = Var(this=self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name)
elif isinstance(unit, Week):
unit.set("this", Var(this=unit.this.name))
@@ -4168,6 +4189,24 @@ class Abs(Func):
pass
+class ArgMax(AggFunc):
+ arg_types = {"this": True, "expression": True, "count": False}
+ _sql_names = ["ARG_MAX", "ARGMAX", "MAX_BY"]
+
+
+class ArgMin(AggFunc):
+ arg_types = {"this": True, "expression": True, "count": False}
+ _sql_names = ["ARG_MIN", "ARGMIN", "MIN_BY"]
+
+
+class ApproxTopK(AggFunc):
+ arg_types = {"this": True, "expression": False, "counters": False}
+
+
+class Flatten(Func):
+ pass
+
+
# https://spark.apache.org/docs/latest/api/sql/index.html#transform
class Transform(Func):
arg_types = {"this": True, "expression": True}
@@ -4540,8 +4579,10 @@ class Exp(Func):
pass
+# https://docs.snowflake.com/en/sql-reference/functions/flatten
class Explode(Func):
- pass
+ arg_types = {"this": True, "expressions": False}
+ is_var_len_args = True
class ExplodeOuter(Explode):
@@ -4698,6 +4739,8 @@ class JSONArrayContains(Binary, Predicate, Func):
class ParseJSON(Func):
# BigQuery, Snowflake have PARSE_JSON, Presto has JSON_PARSE
_sql_names = ["PARSE_JSON", "JSON_PARSE"]
+ arg_types = {"this": True, "expressions": False}
+ is_var_len_args = True
class Least(Func):
@@ -4758,6 +4801,16 @@ class Lower(Func):
class Map(Func):
arg_types = {"keys": False, "values": False}
+ @property
+ def keys(self) -> t.List[Expression]:
+ keys = self.args.get("keys")
+ return keys.expressions if keys else []
+
+ @property
+ def values(self) -> t.List[Expression]:
+ values = self.args.get("values")
+ return values.expressions if values else []
+
class MapFromEntries(Func):
pass
@@ -4870,6 +4923,7 @@ class RegexpReplace(Func):
"position": False,
"occurrence": False,
"parameters": False,
+ "modifiers": False,
}
@@ -4877,7 +4931,7 @@ class RegexpLike(Binary, Func):
arg_types = {"this": True, "expression": True, "flag": False}
-class RegexpILike(Func):
+class RegexpILike(Binary, Func):
arg_types = {"this": True, "expression": True, "flag": False}
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index b7e26bb..0d6778a 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -11,6 +11,9 @@ from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer, TokenType
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
logger = logging.getLogger("sqlglot")
@@ -141,6 +144,9 @@ class Generator:
# Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
+ # Whether or not LIMIT and FETCH clauses accept expressions or only literals
+ LIMIT_ONLY_LITERALS = False
+
# Whether or not a table is allowed to be renamed with a db
RENAME_TABLE_WITH_DB = True
@@ -341,6 +347,12 @@ class Generator:
exp.With,
)
+ # Expressions that should not have their comments generated in maybe_comment
+ EXCLUDE_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = (
+ exp.Binary,
+ exp.Union,
+ )
+
# Expressions that can remain unwrapped when appearing in the context of an INTERVAL
UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = (
exp.Column,
@@ -501,7 +513,7 @@ class Generator:
else None
)
- if not comments or isinstance(expression, exp.Binary):
+ if not comments or isinstance(expression, self.EXCLUDE_COMMENTS):
return sql
comments_sql = " ".join(
@@ -879,6 +891,10 @@ class Generator:
alias = self.sql(expression, "this")
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
+
+ if not alias and not self.UNNEST_COLUMN_ONLY:
+ alias = "_t"
+
return f"{alias}{columns}"
def bitstring_sql(self, expression: exp.BitString) -> str:
@@ -1611,9 +1627,6 @@ class Generator:
def lateral_sql(self, expression: exp.Lateral) -> str:
this = self.sql(expression, "this")
- if isinstance(expression.this, exp.Subquery):
- return f"LATERAL {this}"
-
if expression.args.get("view"):
alias = expression.args["alias"]
columns = self.expressions(alias, key="columns", flat=True)
@@ -1629,18 +1642,19 @@ class Generator:
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
this = self.sql(expression, "this")
args = ", ".join(
- sql
- for sql in (
- self.sql(expression, "offset"),
- self.sql(expression, "expression"),
- )
- if sql
+ self.sql(self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e)
+ for e in (expression.args.get(k) for k in ("offset", "expression"))
+ if e
)
return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args}"
def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
- return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"
+ expression = expression.expression
+ expression = (
+ self._simplify_unless_literal(expression) if self.LIMIT_ONLY_LITERALS else expression
+ )
+ return f"{this}{self.seg('OFFSET')} {self.sql(expression)}"
def setitem_sql(self, expression: exp.SetItem) -> str:
kind = self.sql(expression, "kind")
@@ -1895,12 +1909,13 @@ class Generator:
def schema_sql(self, expression: exp.Schema) -> str:
this = self.sql(expression, "this")
- this = f"{this} " if this else ""
sql = self.schema_columns_sql(expression)
- return f"{this}{sql}"
+ return f"{this} {sql}" if this and sql else this or sql
def schema_columns_sql(self, expression: exp.Schema) -> str:
- return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
+ if expression.expressions:
+ return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
+ return ""
def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
@@ -2708,8 +2723,8 @@ class Generator:
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
return f"{property_name} {self.sql(expression, 'this')}"
- def set_operation(self, expression: exp.Expression, op: str) -> str:
- this = self.sql(expression, "this")
+ def set_operation(self, expression: exp.Union, op: str) -> str:
+ this = self.maybe_comment(self.sql(expression, "this"), comments=expression.comments)
op = self.seg(op)
return self.query_modifiers(
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
@@ -2912,6 +2927,14 @@ class Generator:
parameters = self.sql(expression, "params_struct")
return self.func("PREDICT", model, table, parameters or None)
+ def _simplify_unless_literal(self, expression: E) -> E:
+ if not isinstance(expression, exp.Literal):
+ from sqlglot.optimizer.simplify import simplify
+
+ expression = simplify(expression.copy())
+
+ return expression
+
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 113458f..011a6b8 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -112,17 +112,34 @@ def lineage(
column
if isinstance(column, int)
else next(
- i
- for i, select in enumerate(scope.expression.selects)
- if select.alias_or_name == column
+ (
+ i
+ for i, select in enumerate(scope.expression.selects)
+ if select.alias_or_name == column or select.is_star
+ ),
+ -1, # mypy will not allow a None here, but a negative index should never be returned
)
)
+ if index == -1:
+ raise ValueError(f"Could not find {column} in {scope.expression}")
+
for s in scope.union_scopes:
to_node(index, scope=s, upstream=upstream)
return upstream
+ subquery = select.unalias()
+
+ if isinstance(subquery, exp.Subquery):
+ upstream = upstream or Node(name="SUBQUERY", source=scope.expression, expression=select)
+ scope = t.cast(Scope, build_scope(subquery.unnest()))
+
+ for select in subquery.named_selects:
+ to_node(select, scope=scope, upstream=upstream)
+
+ return upstream
+
if isinstance(scope.expression, exp.Select):
# For better ergonomics in our node labels, replace the full select with
# a version that has only the column we care about.
@@ -142,8 +159,19 @@ def lineage(
if upstream:
upstream.downstream.append(node)
+ # If the select is a star, add all scope sources as downstreams
+ if select.is_star:
+ for source in scope.sources.values():
+ node.downstream.append(Node(name=select.sql(), source=source, expression=source))
+
# Find all columns that went into creating this one to list their lineage nodes.
- for c in set(select.find_all(exp.Column)):
+ source_columns = set(select.find_all(exp.Column))
+
+ # If the source is a UDTF, find columns used in the UDTF to generate the table
+ if isinstance(source, exp.UDTF):
+ source_columns |= set(source.find_all(exp.Column))
+
+ for c in source_columns:
table = c.table
source = scope.sources.get(table)
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 17af6ac..69d4567 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -6,7 +6,7 @@ import typing as t
from sqlglot import exp
from sqlglot._typing import E
-from sqlglot.helper import ensure_list, subclasses
+from sqlglot.helper import ensure_list, seq_get, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
@@ -271,6 +271,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+ exp.Bracket: lambda self, e: self._annotate_bracket(e),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
@@ -287,6 +288,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
+ exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
@@ -524,3 +526,24 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, datatype)
return expression
+
+ def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
+ self._annotate_args(expression)
+
+ bracket_arg = expression.expressions[0]
+ this = expression.this
+
+ if isinstance(bracket_arg, exp.Slice):
+ self._set_type(expression, this.type)
+ elif this.type.is_type(exp.DataType.Type.ARRAY):
+ contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
+ self._set_type(expression, contained_type)
+ elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
+ index = this.keys.index(bracket_arg)
+ value = seq_get(this.values, index)
+ value_type = value.type if value else exp.DataType.Type.UNKNOWN
+ self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
+ else:
+ self._set_type(expression, exp.DataType.Type.UNKNOWN)
+
+ return expression
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index ec3b3af..fc5c348 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -69,7 +69,11 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
_replace_int_predicate(expression.left)
_replace_int_predicate(expression.right)
- elif isinstance(expression, (exp.Where, exp.Having, exp.If)):
+ elif isinstance(expression, (exp.Where, exp.Having)) or (
+ # We can't replace num in CASE x WHEN num ..., because it's not the full predicate
+ isinstance(expression, exp.If)
+ and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
+ ):
_replace_int_predicate(expression.this)
return expression
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 849643c..30de75b 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -70,6 +70,7 @@ def simplify(expression, constant_propagation=False):
node = uniq_sort(node, generate, root)
node = absorb_and_eliminate(node, root)
node = simplify_concat(node)
+ node = simplify_conditionals(node)
if constant_propagation:
node = propagate_constants(node, root)
@@ -477,9 +478,11 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
return expression
if l.__class__ in INVERSE_DATE_OPS:
+ l = t.cast(exp.IntervalOp, l)
a = l.this
b = l.interval()
else:
+ l = t.cast(exp.Binary, l)
a, b = l.left, l.right
if not a_predicate(a) and b_predicate(b):
@@ -695,6 +698,32 @@ def simplify_concat(expression):
return concat_type(expressions=new_args)
+def simplify_conditionals(expression):
+ """Simplifies expressions like IF, CASE if their condition is statically known."""
+ if isinstance(expression, exp.Case):
+ this = expression.this
+ for case in expression.args["ifs"]:
+ cond = case.this
+ if this:
+ # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
+ cond = cond.replace(this.pop().eq(cond))
+
+ if always_true(cond):
+ return case.args["true"]
+
+ if always_false(cond):
+ case.pop()
+ if not expression.args["ifs"]:
+ return expression.args.get("default") or exp.null()
+ elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
+ if always_true(expression.this):
+ return expression.args["true"]
+ if always_false(expression.this):
+ return expression.args.get("false") or exp.null()
+
+ return expression
+
+
DateRange = t.Tuple[datetime.date, datetime.date]
@@ -786,6 +815,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
else:
return expression
+ l = t.cast(exp.DateTrunc, l)
unit = l.unit.name.lower()
date = extract_date(r)
@@ -798,6 +828,7 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
rs = expression.expressions
if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
+ l = t.cast(exp.DateTrunc, l)
unit = l.unit.name.lower()
ranges = []
@@ -852,6 +883,10 @@ def always_true(expression):
)
+def always_false(expression):
+ return is_false(expression) or is_null(expression)
+
+
def is_complement(a, b):
return isinstance(b, exp.Not) and b.this == a
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 8de76ca..b7f91ab 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -313,6 +313,7 @@ class Parser(metaclass=_Parser):
TokenType.UNIQUE,
TokenType.UNPIVOT,
TokenType.UPDATE,
+ TokenType.USE,
TokenType.VOLATILE,
TokenType.WINDOW,
*CREATABLES,
@@ -629,11 +630,14 @@ class Parser(metaclass=_Parser):
"ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty),
"AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty),
"BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(),
- "CHARACTER SET": lambda self: self._parse_character_set(),
+ "CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs),
+ "CHARACTER SET": lambda self, **kwargs: self._parse_character_set(**kwargs),
"CHECKSUM": lambda self: self._parse_checksum(),
"CLUSTER BY": lambda self: self._parse_cluster(),
"CLUSTERED": lambda self: self._parse_clustered_by(),
- "COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty),
+ "COLLATE": lambda self, **kwargs: self._parse_property_assignment(
+ exp.CollateProperty, **kwargs
+ ),
"COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
"COPY": lambda self: self._parse_copy_property(),
"DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs),
@@ -1443,8 +1447,8 @@ class Parser(metaclass=_Parser):
if self._match_texts(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.text.upper()](self)
- if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
- return self._parse_character_set(default=True)
+ if self._match(TokenType.DEFAULT) and self._match_texts(self.PROPERTY_PARSERS):
+ return self.PROPERTY_PARSERS[self._prev.text.upper()](self, default=True)
if self._match_text_seq("COMPOUND", "SORTKEY"):
return self._parse_sortkey(compound=True)
@@ -1480,10 +1484,10 @@ class Parser(metaclass=_Parser):
else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
)
- def _parse_property_assignment(self, exp_class: t.Type[E]) -> E:
+ def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
- return self.expression(exp_class, this=self._parse_field())
+ return self.expression(exp_class, this=self._parse_field(), **kwargs)
def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Properties]:
properties = []
@@ -2426,9 +2430,9 @@ class Parser(metaclass=_Parser):
table_alias: t.Optional[exp.TableAlias] = self.expression(
exp.TableAlias, this=table, columns=columns
)
- elif isinstance(this, exp.Subquery) and this.alias:
- # Ensures parity between the Subquery's and the Lateral's "alias" args
- table_alias = this.args["alias"].copy()
+ elif isinstance(this, (exp.Subquery, exp.Unnest)) and this.alias:
+ # We move the alias from the lateral's child node to the lateral itself
+ table_alias = this.args["alias"].pop()
else:
table_alias = self._parse_table_alias()
@@ -2952,6 +2956,7 @@ class Parser(metaclass=_Parser):
cube = None
totals = None
+ index = self._index
with_ = self._match(TokenType.WITH)
if self._match(TokenType.ROLLUP):
rollup = with_ or self._parse_wrapped_csv(self._parse_column)
@@ -2966,6 +2971,8 @@ class Parser(metaclass=_Parser):
elements["totals"] = True # type: ignore
if not (grouping_sets or rollup or cube or totals):
+ if with_:
+ self._retreat(index)
break
return self.expression(exp.Group, **elements) # type: ignore
@@ -3157,6 +3164,7 @@ class Parser(metaclass=_Parser):
return self.expression(
expression,
+ comments=self._prev.comments,
this=this,
distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
by_name=self._match_text_seq("BY", "NAME"),
@@ -3619,6 +3627,32 @@ class Parser(metaclass=_Parser):
anonymous: bool = False,
optional_parens: bool = True,
) -> t.Optional[exp.Expression]:
+ # This allows us to also parse {fn <function>} syntax (Snowflake, MySQL support this)
+ # See: https://community.snowflake.com/s/article/SQL-Escape-Sequences
+ fn_syntax = False
+ if (
+ self._match(TokenType.L_BRACE, advance=False)
+ and self._next
+ and self._next.text.upper() == "FN"
+ ):
+ self._advance(2)
+ fn_syntax = True
+
+ func = self._parse_function_call(
+ functions=functions, anonymous=anonymous, optional_parens=optional_parens
+ )
+
+ if fn_syntax:
+ self._match(TokenType.R_BRACE)
+
+ return func
+
+ def _parse_function_call(
+ self,
+ functions: t.Optional[t.Dict[str, t.Callable]] = None,
+ anonymous: bool = False,
+ optional_parens: bool = True,
+ ) -> t.Optional[exp.Expression]:
if not self._curr:
return None
@@ -3856,6 +3890,10 @@ class Parser(metaclass=_Parser):
if not identity:
this.set("expression", self._parse_bitwise())
+ elif not this.args.get("start") and self._match(TokenType.NUMBER, advance=False):
+ args = self._parse_csv(self._parse_bitwise)
+ this.set("start", seq_get(args, 0))
+ this.set("increment", seq_get(args, 1))
self._match_r_paren()
@@ -4039,6 +4077,11 @@ class Parser(metaclass=_Parser):
)
)
+ if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
+ self.raise_error("Expected ]")
+ elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE:
+ self.raise_error("Expected }")
+
# https://duckdb.org/docs/sql/data_types/struct.html#creating-structs
if bracket_kind == TokenType.L_BRACE:
this = self.expression(exp.Struct, expressions=expressions)
@@ -4048,11 +4091,6 @@ class Parser(metaclass=_Parser):
expressions = apply_index_offset(this, expressions, -self.INDEX_OFFSET)
this = self.expression(exp.Bracket, this=this, expressions=expressions)
- if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET:
- self.raise_error("Expected ]")
- elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE:
- self.raise_error("Expected }")
-
self._add_comments(this)
return self._parse_bracket(this)
diff --git a/sqlglot/time.py b/sqlglot/time.py
index 5f0848e..c286ec1 100644
--- a/sqlglot/time.py
+++ b/sqlglot/time.py
@@ -54,3 +54,606 @@ def format_time(
chunks.append(chars)
return "".join(mapping.get(chars, chars) for chars in chunks)
+
+
+TIMEZONES = {
+ tz.lower()
+ for tz in (
+ "Africa/Abidjan",
+ "Africa/Accra",
+ "Africa/Addis_Ababa",
+ "Africa/Algiers",
+ "Africa/Asmara",
+ "Africa/Asmera",
+ "Africa/Bamako",
+ "Africa/Bangui",
+ "Africa/Banjul",
+ "Africa/Bissau",
+ "Africa/Blantyre",
+ "Africa/Brazzaville",
+ "Africa/Bujumbura",
+ "Africa/Cairo",
+ "Africa/Casablanca",
+ "Africa/Ceuta",
+ "Africa/Conakry",
+ "Africa/Dakar",
+ "Africa/Dar_es_Salaam",
+ "Africa/Djibouti",
+ "Africa/Douala",
+ "Africa/El_Aaiun",
+ "Africa/Freetown",
+ "Africa/Gaborone",
+ "Africa/Harare",
+ "Africa/Johannesburg",
+ "Africa/Juba",
+ "Africa/Kampala",
+ "Africa/Khartoum",
+ "Africa/Kigali",
+ "Africa/Kinshasa",
+ "Africa/Lagos",
+ "Africa/Libreville",
+ "Africa/Lome",
+ "Africa/Luanda",
+ "Africa/Lubumbashi",
+ "Africa/Lusaka",
+ "Africa/Malabo",
+ "Africa/Maputo",
+ "Africa/Maseru",
+ "Africa/Mbabane",
+ "Africa/Mogadishu",
+ "Africa/Monrovia",
+ "Africa/Nairobi",
+ "Africa/Ndjamena",
+ "Africa/Niamey",
+ "Africa/Nouakchott",
+ "Africa/Ouagadougou",
+ "Africa/Porto-Novo",
+ "Africa/Sao_Tome",
+ "Africa/Timbuktu",
+ "Africa/Tripoli",
+ "Africa/Tunis",
+ "Africa/Windhoek",
+ "America/Adak",
+ "America/Anchorage",
+ "America/Anguilla",
+ "America/Antigua",
+ "America/Araguaina",
+ "America/Argentina/Buenos_Aires",
+ "America/Argentina/Catamarca",
+ "America/Argentina/ComodRivadavia",
+ "America/Argentina/Cordoba",
+ "America/Argentina/Jujuy",
+ "America/Argentina/La_Rioja",
+ "America/Argentina/Mendoza",
+ "America/Argentina/Rio_Gallegos",
+ "America/Argentina/Salta",
+ "America/Argentina/San_Juan",
+ "America/Argentina/San_Luis",
+ "America/Argentina/Tucuman",
+ "America/Argentina/Ushuaia",
+ "America/Aruba",
+ "America/Asuncion",
+ "America/Atikokan",
+ "America/Atka",
+ "America/Bahia",
+ "America/Bahia_Banderas",
+ "America/Barbados",
+ "America/Belem",
+ "America/Belize",
+ "America/Blanc-Sablon",
+ "America/Boa_Vista",
+ "America/Bogota",
+ "America/Boise",
+ "America/Buenos_Aires",
+ "America/Cambridge_Bay",
+ "America/Campo_Grande",
+ "America/Cancun",
+ "America/Caracas",
+ "America/Catamarca",
+ "America/Cayenne",
+ "America/Cayman",
+ "America/Chicago",
+ "America/Chihuahua",
+ "America/Ciudad_Juarez",
+ "America/Coral_Harbour",
+ "America/Cordoba",
+ "America/Costa_Rica",
+ "America/Creston",
+ "America/Cuiaba",
+ "America/Curacao",
+ "America/Danmarkshavn",
+ "America/Dawson",
+ "America/Dawson_Creek",
+ "America/Denver",
+ "America/Detroit",
+ "America/Dominica",
+ "America/Edmonton",
+ "America/Eirunepe",
+ "America/El_Salvador",
+ "America/Ensenada",
+ "America/Fort_Nelson",
+ "America/Fort_Wayne",
+ "America/Fortaleza",
+ "America/Glace_Bay",
+ "America/Godthab",
+ "America/Goose_Bay",
+ "America/Grand_Turk",
+ "America/Grenada",
+ "America/Guadeloupe",
+ "America/Guatemala",
+ "America/Guayaquil",
+ "America/Guyana",
+ "America/Halifax",
+ "America/Havana",
+ "America/Hermosillo",
+ "America/Indiana/Indianapolis",
+ "America/Indiana/Knox",
+ "America/Indiana/Marengo",
+ "America/Indiana/Petersburg",
+ "America/Indiana/Tell_City",
+ "America/Indiana/Vevay",
+ "America/Indiana/Vincennes",
+ "America/Indiana/Winamac",
+ "America/Indianapolis",
+ "America/Inuvik",
+ "America/Iqaluit",
+ "America/Jamaica",
+ "America/Jujuy",
+ "America/Juneau",
+ "America/Kentucky/Louisville",
+ "America/Kentucky/Monticello",
+ "America/Knox_IN",
+ "America/Kralendijk",
+ "America/La_Paz",
+ "America/Lima",
+ "America/Los_Angeles",
+ "America/Louisville",
+ "America/Lower_Princes",
+ "America/Maceio",
+ "America/Managua",
+ "America/Manaus",
+ "America/Marigot",
+ "America/Martinique",
+ "America/Matamoros",
+ "America/Mazatlan",
+ "America/Mendoza",
+ "America/Menominee",
+ "America/Merida",
+ "America/Metlakatla",
+ "America/Mexico_City",
+ "America/Miquelon",
+ "America/Moncton",
+ "America/Monterrey",
+ "America/Montevideo",
+ "America/Montreal",
+ "America/Montserrat",
+ "America/Nassau",
+ "America/New_York",
+ "America/Nipigon",
+ "America/Nome",
+ "America/Noronha",
+ "America/North_Dakota/Beulah",
+ "America/North_Dakota/Center",
+ "America/North_Dakota/New_Salem",
+ "America/Nuuk",
+ "America/Ojinaga",
+ "America/Panama",
+ "America/Pangnirtung",
+ "America/Paramaribo",
+ "America/Phoenix",
+ "America/Port-au-Prince",
+ "America/Port_of_Spain",
+ "America/Porto_Acre",
+ "America/Porto_Velho",
+ "America/Puerto_Rico",
+ "America/Punta_Arenas",
+ "America/Rainy_River",
+ "America/Rankin_Inlet",
+ "America/Recife",
+ "America/Regina",
+ "America/Resolute",
+ "America/Rio_Branco",
+ "America/Rosario",
+ "America/Santa_Isabel",
+ "America/Santarem",
+ "America/Santiago",
+ "America/Santo_Domingo",
+ "America/Sao_Paulo",
+ "America/Scoresbysund",
+ "America/Shiprock",
+ "America/Sitka",
+ "America/St_Barthelemy",
+ "America/St_Johns",
+ "America/St_Kitts",
+ "America/St_Lucia",
+ "America/St_Thomas",
+ "America/St_Vincent",
+ "America/Swift_Current",
+ "America/Tegucigalpa",
+ "America/Thule",
+ "America/Thunder_Bay",
+ "America/Tijuana",
+ "America/Toronto",
+ "America/Tortola",
+ "America/Vancouver",
+ "America/Virgin",
+ "America/Whitehorse",
+ "America/Winnipeg",
+ "America/Yakutat",
+ "America/Yellowknife",
+ "Antarctica/Casey",
+ "Antarctica/Davis",
+ "Antarctica/DumontDUrville",
+ "Antarctica/Macquarie",
+ "Antarctica/Mawson",
+ "Antarctica/McMurdo",
+ "Antarctica/Palmer",
+ "Antarctica/Rothera",
+ "Antarctica/South_Pole",
+ "Antarctica/Syowa",
+ "Antarctica/Troll",
+ "Antarctica/Vostok",
+ "Arctic/Longyearbyen",
+ "Asia/Aden",
+ "Asia/Almaty",
+ "Asia/Amman",
+ "Asia/Anadyr",
+ "Asia/Aqtau",
+ "Asia/Aqtobe",
+ "Asia/Ashgabat",
+ "Asia/Ashkhabad",
+ "Asia/Atyrau",
+ "Asia/Baghdad",
+ "Asia/Bahrain",
+ "Asia/Baku",
+ "Asia/Bangkok",
+ "Asia/Barnaul",
+ "Asia/Beirut",
+ "Asia/Bishkek",
+ "Asia/Brunei",
+ "Asia/Calcutta",
+ "Asia/Chita",
+ "Asia/Choibalsan",
+ "Asia/Chongqing",
+ "Asia/Chungking",
+ "Asia/Colombo",
+ "Asia/Dacca",
+ "Asia/Damascus",
+ "Asia/Dhaka",
+ "Asia/Dili",
+ "Asia/Dubai",
+ "Asia/Dushanbe",
+ "Asia/Famagusta",
+ "Asia/Gaza",
+ "Asia/Harbin",
+ "Asia/Hebron",
+ "Asia/Ho_Chi_Minh",
+ "Asia/Hong_Kong",
+ "Asia/Hovd",
+ "Asia/Irkutsk",
+ "Asia/Istanbul",
+ "Asia/Jakarta",
+ "Asia/Jayapura",
+ "Asia/Jerusalem",
+ "Asia/Kabul",
+ "Asia/Kamchatka",
+ "Asia/Karachi",
+ "Asia/Kashgar",
+ "Asia/Kathmandu",
+ "Asia/Katmandu",
+ "Asia/Khandyga",
+ "Asia/Kolkata",
+ "Asia/Krasnoyarsk",
+ "Asia/Kuala_Lumpur",
+ "Asia/Kuching",
+ "Asia/Kuwait",
+ "Asia/Macao",
+ "Asia/Macau",
+ "Asia/Magadan",
+ "Asia/Makassar",
+ "Asia/Manila",
+ "Asia/Muscat",
+ "Asia/Nicosia",
+ "Asia/Novokuznetsk",
+ "Asia/Novosibirsk",
+ "Asia/Omsk",
+ "Asia/Oral",
+ "Asia/Phnom_Penh",
+ "Asia/Pontianak",
+ "Asia/Pyongyang",
+ "Asia/Qatar",
+ "Asia/Qostanay",
+ "Asia/Qyzylorda",
+ "Asia/Rangoon",
+ "Asia/Riyadh",
+ "Asia/Saigon",
+ "Asia/Sakhalin",
+ "Asia/Samarkand",
+ "Asia/Seoul",
+ "Asia/Shanghai",
+ "Asia/Singapore",
+ "Asia/Srednekolymsk",
+ "Asia/Taipei",
+ "Asia/Tashkent",
+ "Asia/Tbilisi",
+ "Asia/Tehran",
+ "Asia/Tel_Aviv",
+ "Asia/Thimbu",
+ "Asia/Thimphu",
+ "Asia/Tokyo",
+ "Asia/Tomsk",
+ "Asia/Ujung_Pandang",
+ "Asia/Ulaanbaatar",
+ "Asia/Ulan_Bator",
+ "Asia/Urumqi",
+ "Asia/Ust-Nera",
+ "Asia/Vientiane",
+ "Asia/Vladivostok",
+ "Asia/Yakutsk",
+ "Asia/Yangon",
+ "Asia/Yekaterinburg",
+ "Asia/Yerevan",
+ "Atlantic/Azores",
+ "Atlantic/Bermuda",
+ "Atlantic/Canary",
+ "Atlantic/Cape_Verde",
+ "Atlantic/Faeroe",
+ "Atlantic/Faroe",
+ "Atlantic/Jan_Mayen",
+ "Atlantic/Madeira",
+ "Atlantic/Reykjavik",
+ "Atlantic/South_Georgia",
+ "Atlantic/St_Helena",
+ "Atlantic/Stanley",
+ "Australia/ACT",
+ "Australia/Adelaide",
+ "Australia/Brisbane",
+ "Australia/Broken_Hill",
+ "Australia/Canberra",
+ "Australia/Currie",
+ "Australia/Darwin",
+ "Australia/Eucla",
+ "Australia/Hobart",
+ "Australia/LHI",
+ "Australia/Lindeman",
+ "Australia/Lord_Howe",
+ "Australia/Melbourne",
+ "Australia/NSW",
+ "Australia/North",
+ "Australia/Perth",
+ "Australia/Queensland",
+ "Australia/South",
+ "Australia/Sydney",
+ "Australia/Tasmania",
+ "Australia/Victoria",
+ "Australia/West",
+ "Australia/Yancowinna",
+ "Brazil/Acre",
+ "Brazil/DeNoronha",
+ "Brazil/East",
+ "Brazil/West",
+ "CET",
+ "CST6CDT",
+ "Canada/Atlantic",
+ "Canada/Central",
+ "Canada/Eastern",
+ "Canada/Mountain",
+ "Canada/Newfoundland",
+ "Canada/Pacific",
+ "Canada/Saskatchewan",
+ "Canada/Yukon",
+ "Chile/Continental",
+ "Chile/EasterIsland",
+ "Cuba",
+ "EET",
+ "EST",
+ "EST5EDT",
+ "Egypt",
+ "Eire",
+ "Etc/GMT",
+ "Etc/GMT+0",
+ "Etc/GMT+1",
+ "Etc/GMT+10",
+ "Etc/GMT+11",
+ "Etc/GMT+12",
+ "Etc/GMT+2",
+ "Etc/GMT+3",
+ "Etc/GMT+4",
+ "Etc/GMT+5",
+ "Etc/GMT+6",
+ "Etc/GMT+7",
+ "Etc/GMT+8",
+ "Etc/GMT+9",
+ "Etc/GMT-0",
+ "Etc/GMT-1",
+ "Etc/GMT-10",
+ "Etc/GMT-11",
+ "Etc/GMT-12",
+ "Etc/GMT-13",
+ "Etc/GMT-14",
+ "Etc/GMT-2",
+ "Etc/GMT-3",
+ "Etc/GMT-4",
+ "Etc/GMT-5",
+ "Etc/GMT-6",
+ "Etc/GMT-7",
+ "Etc/GMT-8",
+ "Etc/GMT-9",
+ "Etc/GMT0",
+ "Etc/Greenwich",
+ "Etc/UCT",
+ "Etc/UTC",
+ "Etc/Universal",
+ "Etc/Zulu",
+ "Europe/Amsterdam",
+ "Europe/Andorra",
+ "Europe/Astrakhan",
+ "Europe/Athens",
+ "Europe/Belfast",
+ "Europe/Belgrade",
+ "Europe/Berlin",
+ "Europe/Bratislava",
+ "Europe/Brussels",
+ "Europe/Bucharest",
+ "Europe/Budapest",
+ "Europe/Busingen",
+ "Europe/Chisinau",
+ "Europe/Copenhagen",
+ "Europe/Dublin",
+ "Europe/Gibraltar",
+ "Europe/Guernsey",
+ "Europe/Helsinki",
+ "Europe/Isle_of_Man",
+ "Europe/Istanbul",
+ "Europe/Jersey",
+ "Europe/Kaliningrad",
+ "Europe/Kiev",
+ "Europe/Kirov",
+ "Europe/Kyiv",
+ "Europe/Lisbon",
+ "Europe/Ljubljana",
+ "Europe/London",
+ "Europe/Luxembourg",
+ "Europe/Madrid",
+ "Europe/Malta",
+ "Europe/Mariehamn",
+ "Europe/Minsk",
+ "Europe/Monaco",
+ "Europe/Moscow",
+ "Europe/Nicosia",
+ "Europe/Oslo",
+ "Europe/Paris",
+ "Europe/Podgorica",
+ "Europe/Prague",
+ "Europe/Riga",
+ "Europe/Rome",
+ "Europe/Samara",
+ "Europe/San_Marino",
+ "Europe/Sarajevo",
+ "Europe/Saratov",
+ "Europe/Simferopol",
+ "Europe/Skopje",
+ "Europe/Sofia",
+ "Europe/Stockholm",
+ "Europe/Tallinn",
+ "Europe/Tirane",
+ "Europe/Tiraspol",
+ "Europe/Ulyanovsk",
+ "Europe/Uzhgorod",
+ "Europe/Vaduz",
+ "Europe/Vatican",
+ "Europe/Vienna",
+ "Europe/Vilnius",
+ "Europe/Volgograd",
+ "Europe/Warsaw",
+ "Europe/Zagreb",
+ "Europe/Zaporozhye",
+ "Europe/Zurich",
+ "GB",
+ "GB-Eire",
+ "GMT",
+ "GMT+0",
+ "GMT-0",
+ "GMT0",
+ "Greenwich",
+ "HST",
+ "Hongkong",
+ "Iceland",
+ "Indian/Antananarivo",
+ "Indian/Chagos",
+ "Indian/Christmas",
+ "Indian/Cocos",
+ "Indian/Comoro",
+ "Indian/Kerguelen",
+ "Indian/Mahe",
+ "Indian/Maldives",
+ "Indian/Mauritius",
+ "Indian/Mayotte",
+ "Indian/Reunion",
+ "Iran",
+ "Israel",
+ "Jamaica",
+ "Japan",
+ "Kwajalein",
+ "Libya",
+ "MET",
+ "MST",
+ "MST7MDT",
+ "Mexico/BajaNorte",
+ "Mexico/BajaSur",
+ "Mexico/General",
+ "NZ",
+ "NZ-CHAT",
+ "Navajo",
+ "PRC",
+ "PST8PDT",
+ "Pacific/Apia",
+ "Pacific/Auckland",
+ "Pacific/Bougainville",
+ "Pacific/Chatham",
+ "Pacific/Chuuk",
+ "Pacific/Easter",
+ "Pacific/Efate",
+ "Pacific/Enderbury",
+ "Pacific/Fakaofo",
+ "Pacific/Fiji",
+ "Pacific/Funafuti",
+ "Pacific/Galapagos",
+ "Pacific/Gambier",
+ "Pacific/Guadalcanal",
+ "Pacific/Guam",
+ "Pacific/Honolulu",
+ "Pacific/Johnston",
+ "Pacific/Kanton",
+ "Pacific/Kiritimati",
+ "Pacific/Kosrae",
+ "Pacific/Kwajalein",
+ "Pacific/Majuro",
+ "Pacific/Marquesas",
+ "Pacific/Midway",
+ "Pacific/Nauru",
+ "Pacific/Niue",
+ "Pacific/Norfolk",
+ "Pacific/Noumea",
+ "Pacific/Pago_Pago",
+ "Pacific/Palau",
+ "Pacific/Pitcairn",
+ "Pacific/Pohnpei",
+ "Pacific/Ponape",
+ "Pacific/Port_Moresby",
+ "Pacific/Rarotonga",
+ "Pacific/Saipan",
+ "Pacific/Samoa",
+ "Pacific/Tahiti",
+ "Pacific/Tarawa",
+ "Pacific/Tongatapu",
+ "Pacific/Truk",
+ "Pacific/Wake",
+ "Pacific/Wallis",
+ "Pacific/Yap",
+ "Poland",
+ "Portugal",
+ "ROC",
+ "ROK",
+ "Singapore",
+ "Turkey",
+ "UCT",
+ "US/Alaska",
+ "US/Aleutian",
+ "US/Arizona",
+ "US/Central",
+ "US/East-Indiana",
+ "US/Eastern",
+ "US/Hawaii",
+ "US/Indiana-Starke",
+ "US/Michigan",
+ "US/Mountain",
+ "US/Pacific",
+ "US/Samoa",
+ "UTC",
+ "Universal",
+ "W-SU",
+ "WET",
+ "Zulu",
+ )
+}
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index c883858..9784c63 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -1077,10 +1077,10 @@ class Tokenizer(metaclass=_Tokenizer):
literal = ""
while self._peek.strip() and self._peek not in self.SINGLE_TOKENS:
- literal += self._peek.upper()
+ literal += self._peek
self._advance()
- token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal, ""))
+ token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal.upper(), ""))
if token_type:
self._add(TokenType.NUMBER, number_text)
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 8feee52..e0fd68f 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -164,8 +164,9 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
+ """Convert explode/posexplode into unnest (used in hive -> presto)."""
+
def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
- """Convert explode/posexplode into unnest (used in hive -> presto)."""
if isinstance(expression, exp.Select):
from sqlglot.optimizer.scope import Scope
@@ -297,6 +298,7 @@ PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
+ """Transforms percentiles by adding a WITHIN GROUP clause to them."""
if (
isinstance(expression, PERCENTILES)
and not isinstance(expression.parent, exp.WithinGroup)
@@ -311,6 +313,7 @@ def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expressi
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
+ """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
if (
isinstance(expression, exp.WithinGroup)
and isinstance(expression.this, PERCENTILES)
@@ -324,6 +327,7 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
+ """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
if isinstance(expression, exp.With) and expression.recursive:
next_name = name_sequence("_c_")
@@ -342,6 +346,7 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
+ """Replace 'epoch' in casts by the equivalent date literal."""
if (
isinstance(expression, (exp.Cast, exp.TryCast))
and expression.name.lower() == "epoch"
@@ -352,16 +357,8 @@ def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
return expression
-def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
- if isinstance(expression, exp.Timestamp) and not expression.expression:
- return exp.cast(
- expression.this,
- to=exp.DataType.Type.TIMESTAMP,
- )
- return expression
-
-
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
+ """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
if isinstance(expression, exp.Select):
for join in expression.args.get("joins") or []:
on = join.args.get("on")