Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dialects/bigquery.py                 2
-rw-r--r--  sqlglot/dialects/clickhouse.py              56
-rw-r--r--  sqlglot/dialects/dialect.py                 27
-rw-r--r--  sqlglot/dialects/duckdb.py                   2
-rw-r--r--  sqlglot/dialects/hive.py                    41
-rw-r--r--  sqlglot/dialects/mysql.py                    3
-rw-r--r--  sqlglot/dialects/postgres.py                 7
-rw-r--r--  sqlglot/dialects/presto.py                  11
-rw-r--r--  sqlglot/dialects/redshift.py                20
-rw-r--r--  sqlglot/dialects/snowflake.py                4
-rw-r--r--  sqlglot/dialects/spark2.py                  23
-rw-r--r--  sqlglot/dialects/sqlite.py                   2
-rw-r--r--  sqlglot/dialects/teradata.py                17
-rw-r--r--  sqlglot/expressions.py                      72
-rw-r--r--  sqlglot/generator.py                        93
-rw-r--r--  sqlglot/optimizer/optimize_joins.py          2
-rw-r--r--  sqlglot/optimizer/pushdown_predicates.py     4
-rw-r--r--  sqlglot/parser.py                          184
-rw-r--r--  sqlglot/serde.py                             8
-rw-r--r--  sqlglot/tokens.py                           45
-rw-r--r--  sqlglot/transforms.py                       11
21 files changed, 465 insertions, 169 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 1a58337..5b10852 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -327,6 +327,8 @@ class BigQuery(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ RESERVED_KEYWORDS = {*generator.Generator.RESERVED_KEYWORDS, "hash"}
+
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
if isinstance(first_arg, exp.Subqueryable):
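Note: a quick sketch of the effect of the new RESERVED_KEYWORDS entry, assuming
the top-level sqlglot.transpile entry point (output shown is approximate):

    import sqlglot

    # With "hash" now reserved for BigQuery, the generator should quote it
    # as an identifier instead of emitting it bare.
    print(sqlglot.transpile("SELECT hash FROM t", write="bigquery")[0])
    # expected (roughly): SELECT `hash` FROM t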
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index c8a9525..fc48379 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -27,14 +27,15 @@ class ClickHouse(Dialect):
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
+ STRING_ESCAPES = ["'", "\\"]
BIT_STRINGS = [("0b", "")]
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
- "ASOF": TokenType.ASOF,
"ATTACH": TokenType.COMMAND,
"DATETIME64": TokenType.DATETIME64,
+ "DICTIONARY": TokenType.DICTIONARY,
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
"FLOAT64": TokenType.DOUBLE,
@@ -97,7 +98,6 @@ class ClickHouse(Dialect):
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {
TokenType.ANY,
- TokenType.ASOF,
TokenType.SEMI,
TokenType.ANTI,
TokenType.SETTINGS,
@@ -182,7 +182,7 @@ class ClickHouse(Dialect):
return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
- def _parse_join_side_and_kind(
+ def _parse_join_parts(
self,
) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
is_global = self._match(TokenType.GLOBAL) and self._prev
@@ -201,7 +201,7 @@ class ClickHouse(Dialect):
join = super()._parse_join(skip_join_token)
if join:
- join.set("global", join.args.pop("natural", None))
+ join.set("global", join.args.pop("method", None))
return join
def _parse_function(
@@ -245,6 +245,23 @@ class ClickHouse(Dialect):
) -> t.List[t.Optional[exp.Expression]]:
return super()._parse_wrapped_id_vars(optional=True)
+ def _parse_primary_key(
+ self, wrapped_optional: bool = False, in_props: bool = False
+ ) -> exp.Expression:
+ return super()._parse_primary_key(
+ wrapped_optional=wrapped_optional or in_props, in_props=in_props
+ )
+
+ def _parse_on_property(self) -> t.Optional[exp.Property]:
+ index = self._index
+ if self._match_text_seq("CLUSTER"):
+ this = self._parse_id_var()
+ if this:
+ return self.expression(exp.OnCluster, this=this)
+ else:
+ self._retreat(index)
+ return None
+
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
@@ -292,6 +309,7 @@ class ClickHouse(Dialect):
**generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.OnCluster: exp.Properties.Location.POST_NAME,
}
JOIN_HINTS = False
@@ -299,6 +317,18 @@ class ClickHouse(Dialect):
EXPLICIT_UNION = True
GROUPINGS_SEP = ""
+ # There's no explicit list in the docs, but it can be found in the ClickHouse code
+ # see `ClickHouse/src/Parsers/ParserCreate*.cpp`
+ ON_CLUSTER_TARGETS = {
+ "DATABASE",
+ "TABLE",
+ "VIEW",
+ "DICTIONARY",
+ "INDEX",
+ "FUNCTION",
+ "NAMED COLLECTION",
+ }
+
def cte_sql(self, expression: exp.CTE) -> str:
if isinstance(expression.this, exp.Alias):
return self.sql(expression, "this")
@@ -321,3 +351,21 @@ class ClickHouse(Dialect):
def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f"{{{expression.name}: {self.sql(expression, 'kind')}}}"
+
+ def oncluster_sql(self, expression: exp.OnCluster) -> str:
+ return f"ON CLUSTER {self.sql(expression, 'this')}"
+
+ def createable_sql(
+ self,
+ expression: exp.Create,
+ locations: dict[exp.Properties.Location, list[exp.Property]],
+ ) -> str:
+ kind = self.sql(expression, "kind").upper()
+ if kind in self.ON_CLUSTER_TARGETS and locations.get(exp.Properties.Location.POST_NAME):
+ this_name = self.sql(expression.this, "this")
+ this_properties = " ".join(
+ [self.sql(prop) for prop in locations[exp.Properties.Location.POST_NAME]]
+ )
+ this_schema = self.schema_columns_sql(expression.this)
+ return f"{this_name}{self.sep()}{this_properties}{self.sep()}{this_schema}"
+ return super().createable_sql(expression, locations)
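Note: a minimal sketch of the new ON CLUSTER support, assuming sqlglot.transpile;
the DDL below is illustrative:

    import sqlglot

    # The OnCluster property is parsed by _parse_on_property and emitted right
    # after the table name by the createable_sql override.
    sql = "CREATE TABLE foo ON CLUSTER cluster1 (id UInt64)"
    print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
    # expected (roughly): CREATE TABLE foo ON CLUSTER cluster1 (id UInt64)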
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 890a3c3..4958bc6 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -104,6 +104,10 @@ class _Dialect(type):
klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING)
klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING)
+ klass.tokenizer_class.identifiers_can_start_with_digit = (
+ klass.identifiers_can_start_with_digit
+ )
+
return klass
@@ -111,6 +115,7 @@ class Dialect(metaclass=_Dialect):
index_offset = 0
unnest_column_only = False
alias_post_tablesample = False
+ identifiers_can_start_with_digit = False
normalize_functions: t.Optional[str] = "upper"
null_ordering = "nulls_are_small"
@@ -231,6 +236,7 @@ class Dialect(metaclass=_Dialect):
"time_trie": self.inverse_time_trie,
"unnest_column_only": self.unnest_column_only,
"alias_post_tablesample": self.alias_post_tablesample,
+ "identifiers_can_start_with_digit": self.identifiers_can_start_with_digit,
"normalize_functions": self.normalize_functions,
"null_ordering": self.null_ordering,
**opts,
@@ -443,7 +449,7 @@ def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
unit = seq_get(args, 0)
this = seq_get(args, 1)
- if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE):
+ if isinstance(this, exp.Cast) and this.is_type("date"):
return exp.DateTrunc(unit=unit, this=this)
return exp.TimestampTrunc(this=this, unit=unit)
@@ -468,6 +474,25 @@ def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> s
)
+def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
+ expression = expression.copy()
+ return self.sql(
+ exp.Substring(
+ this=expression.this, start=exp.Literal.number(1), length=expression.expression
+ )
+ )
+
+
+def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
+ expression = expression.copy()
+ return self.sql(
+ exp.Substring(
+ this=expression.this,
+ start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
+ )
+ )
+
+
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 662882d..f31da73 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -71,7 +71,7 @@ def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str:
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
- if expression.this == exp.DataType.Type.ARRAY:
+ if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index fbd626a..650a1e1 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
create_with_partitions_sql,
format_time_lambda,
if_sql,
+ left_to_substring_sql,
locate_to_strposition,
max_or_greatest,
min_or_least,
@@ -17,6 +18,7 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
no_trycast_sql,
rename_func,
+ right_to_substring_sql,
strposition_to_locate_sql,
struct_extract_sql,
timestrtotime_sql,
@@ -89,7 +91,7 @@ def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str:
annotate_types(this)
- if this.type.is_type(exp.DataType.Type.JSON):
+ if this.type.is_type("json"):
return self.sql(this)
return self.func("TO_JSON", this, expression.args.get("options"))
@@ -149,6 +151,7 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
class Hive(Dialect):
alias_post_tablesample = True
+ identifiers_can_start_with_digit = True
time_mapping = {
"y": "%Y",
@@ -190,7 +193,6 @@ class Hive(Dialect):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
- IDENTIFIER_CAN_START_WITH_DIGIT = True
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -276,6 +278,39 @@ class Hive(Dialect):
"cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"),
}
+ def _parse_types(
+ self, check_func: bool = False, schema: bool = False
+ ) -> t.Optional[exp.Expression]:
+ """
+ Spark (and most likely Hive) treats casts to CHAR(length) and VARCHAR(length) as casts to
+ STRING in all contexts except for schema definitions. For example, this is in Spark v3.4.0:
+
+ spark-sql (default)> select cast(1234 as varchar(2));
+ 23/06/06 15:51:18 WARN CharVarcharUtils: The Spark cast operator does not support
+ char/varchar type and simply treats them as string type. Please use string type
+ directly to avoid confusion. Otherwise, you can set spark.sql.legacy.charVarcharAsString
+ to true, so that Spark treat them as string type as same as Spark 3.0 and earlier
+
+ 1234
+ Time taken: 4.265 seconds, Fetched 1 row(s)
+
+ This shows that Spark doesn't truncate the value into '12', which is inconsistent with
+ what other dialects (e.g. postgres) do, so we need to drop the length to transpile correctly.
+
+ Reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
+ """
+ this = super()._parse_types(check_func=check_func, schema=schema)
+
+ if this and not schema:
+ return this.transform(
+ lambda node: node.replace(exp.DataType.build("text"))
+ if isinstance(node, exp.DataType) and node.is_type("char", "varchar")
+ else node,
+ copy=False,
+ )
+
+ return this
+
class Generator(generator.Generator):
LIMIT_FETCH = "LIMIT"
TABLESAMPLE_WITH_METHOD = False
@@ -323,6 +358,7 @@ class Hive(Dialect):
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.JSONFormat: _json_format_sql,
+ exp.Left: left_to_substring_sql,
exp.Map: var_map_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
@@ -332,6 +368,7 @@ class Hive(Dialect):
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
+ exp.Right: right_to_substring_sql,
exp.SafeDivide: no_safe_divide_sql,
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
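Note: a sketch of the CHAR/VARCHAR behavior described in the docstring above,
assuming sqlglot.transpile (output shown is approximate):

    import sqlglot

    # Outside of schema definitions, CHAR(n)/VARCHAR(n) casts collapse to
    # STRING, matching Spark's runtime behavior.
    print(sqlglot.transpile("SELECT CAST(1234 AS VARCHAR(2))", read="spark", write="spark")[0])
    # expected (roughly): SELECT CAST(1234 AS STRING)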
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 2b41860..75023ff 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -186,9 +186,6 @@ class MySQL(Dialect):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
- "LEFT": lambda args: exp.Substring(
- this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
- ),
"LOCATE": locate_to_strposition,
"STR_TO_DATE": _str_to_date,
}
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index ab61880..8d84024 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -18,7 +18,9 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
timestamptrunc_sql,
+ timestrtotime_sql,
trim_sql,
+ ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import binary_range_parser
@@ -104,7 +106,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
- if expression.this == exp.DataType.Type.ARRAY:
+ if expression.is_type("array"):
return f"{self.expressions(expression, flat=True)}[]"
return self.datatype_sql(expression)
@@ -353,12 +355,13 @@ class Postgres(Dialect):
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
exp.TimestampTrunc: timestamptrunc_sql,
- exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+ exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql,
+ exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 52a04a4..d839864 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -8,10 +8,12 @@ from sqlglot.dialects.dialect import (
date_trunc_to_time,
format_time_lambda,
if_sql,
+ left_to_substring_sql,
no_ilike_sql,
no_pivot_sql,
no_safe_divide_sql,
rename_func,
+ right_to_substring_sql,
struct_extract_sql,
timestamptrunc_sql,
timestrtotime_sql,
@@ -30,7 +32,7 @@ def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str:
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
sql = self.datatype_sql(expression)
- if expression.this == exp.DataType.Type.TIMESTAMPTZ:
+ if expression.is_type("timestamptz"):
sql = f"{sql} WITH TIME ZONE"
return sql
@@ -240,6 +242,7 @@ class Presto(Dialect):
INTERVAL_ALLOWS_PLURAL_FORM = False
JOIN_HINTS = False
TABLE_HINTS = False
+ IS_BOOL = False
STRUCT_DELIMITER = ("(", ")")
PROPERTIES_LOCATION = {
@@ -272,6 +275,7 @@ class Presto(Dialect):
exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: self.func(
@@ -292,11 +296,13 @@ class Presto(Dialect):
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
+ exp.Left: left_to_substring_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
+ exp.Right: right_to_substring_sql,
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.Select: transforms.preprocess(
@@ -319,6 +325,7 @@ class Presto(Dialect):
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToUnix: rename_func("TO_UNIXTIME"),
+ exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]),
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
@@ -356,7 +363,7 @@ class Presto(Dialect):
else:
target_type = None
- if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP):
+ if target_type and target_type.is_type("timestamp"):
to = target_type.copy()
if target_type is start.to:
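Note: a sketch of the new IS_BOOL flag, assuming sqlglot.transpile; Presto has
no IS TRUE / IS FALSE, so is_sql unwraps boolean comparisons:

    import sqlglot

    print(sqlglot.transpile("SELECT a IS TRUE, b IS FALSE", write="presto")[0])
    # expected (roughly): SELECT a, NOT b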
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 55e393a..b0a6774 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -3,6 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, transforms
+from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -24,26 +25,29 @@ class Redshift(Postgres):
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS,
"DATEADD": lambda args: exp.DateAdd(
- this=seq_get(args, 2),
+ this=exp.TsOrDsToDate(this=seq_get(args, 2)),
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
"DATEDIFF": lambda args: exp.DateDiff(
- this=seq_get(args, 2),
- expression=seq_get(args, 1),
+ this=exp.TsOrDsToDate(this=seq_get(args, 2)),
+ expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
unit=seq_get(args, 0),
),
"NVL": exp.Coalesce.from_arg_list,
+ "STRTOL": exp.FromBase.from_arg_list,
}
CONVERT_TYPE_FIRST = True
- def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
- this = super()._parse_types(check_func=check_func)
+ def _parse_types(
+ self, check_func: bool = False, schema: bool = False
+ ) -> t.Optional[exp.Expression]:
+ this = super()._parse_types(check_func=check_func, schema=schema)
if (
isinstance(this, exp.DataType)
- and this.this == exp.DataType.Type.VARCHAR
+ and this.is_type("varchar")
and this.expressions
and this.expressions[0].this == exp.column("MAX")
):
@@ -99,10 +103,12 @@ class Redshift(Postgres):
),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
+ exp.FromBase: rename_func("STRTOL"),
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
+ exp.TsOrDsToDate: lambda self, e: self.sql(e.this),
}
# Postgres maps exp.Pivot to no_pivot_sql, but Redshift supports pivots
@@ -158,7 +164,7 @@ class Redshift(Postgres):
without precision we convert it to `VARCHAR(max)` and if it does have precision then we just convert
`TEXT` to `VARCHAR`.
"""
- if expression.this == exp.DataType.Type.TEXT:
+ if expression.is_type("text"):
expression = expression.copy()
expression.set("this", exp.DataType.Type.VARCHAR)
precision = expression.args.get("expressions")
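Note: a sketch of the new FromBase round trip, assuming sqlglot.transpile
(output shown is approximate):

    import sqlglot

    # STRTOL now parses into exp.FromBase and renders back as STRTOL, so it
    # can also be retargeted to dialects with a different spelling.
    print(sqlglot.transpile("SELECT STRTOL('abc', 16)", read="redshift", write="redshift")[0])
    # expected (roughly): SELECT STRTOL('abc', 16)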
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 756e8e9..821d991 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -153,9 +153,9 @@ def _nullifzero_to_if(args: t.List) -> exp.Expression:
def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str:
- if expression.this == exp.DataType.Type.ARRAY:
+ if expression.is_type("array"):
return "ARRAY"
- elif expression.this == exp.DataType.Type.MAP:
+ elif expression.is_type("map"):
return "OBJECT"
return self.datatype_sql(expression)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 912b86b..bf24240 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -110,11 +110,6 @@ class Spark2(Hive):
**Hive.Parser.FUNCTIONS,
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
- "LEFT": lambda args: exp.Substring(
- this=seq_get(args, 0),
- start=exp.Literal.number(1),
- length=seq_get(args, 1),
- ),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
this=seq_get(args, 0),
expression=seq_get(args, 1),
@@ -123,14 +118,6 @@ class Spark2(Hive):
this=seq_get(args, 0),
expression=seq_get(args, 1),
),
- "RIGHT": lambda args: exp.Substring(
- this=seq_get(args, 0),
- start=exp.Sub(
- this=exp.Length(this=seq_get(args, 0)),
- expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
- ),
- length=seq_get(args, 1),
- ),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
"AGGREGATE": exp.Reduce.from_arg_list,
@@ -240,17 +227,17 @@ class Spark2(Hive):
TRANSFORMS.pop(exp.ArrayJoin)
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
+ TRANSFORMS.pop(exp.Left)
+ TRANSFORMS.pop(exp.Right)
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
def cast_sql(self, expression: exp.Cast) -> str:
- if isinstance(expression.this, exp.Cast) and expression.this.is_type(
- exp.DataType.Type.JSON
- ):
+ if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"):
schema = f"'{self.sql(expression, 'to')}'"
return self.func("FROM_JSON", expression.this.this, schema)
- if expression.to.is_type(exp.DataType.Type.JSON):
+ if expression.is_type("json"):
return self.func("TO_JSON", expression.this)
return super(Hive.Generator, self).cast_sql(expression)
@@ -260,7 +247,7 @@ class Spark2(Hive):
expression,
sep=": "
if isinstance(expression.parent, exp.DataType)
- and expression.parent.is_type(exp.DataType.Type.STRUCT)
+ and expression.parent.is_type("struct")
else sep,
)
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 56e7773..4e800b0 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -132,7 +132,7 @@ class SQLite(Dialect):
LIMIT_FETCH = "LIMIT"
def cast_sql(self, expression: exp.Cast) -> str:
- if expression.to.this == exp.DataType.Type.DATE:
+ if expression.is_type("date"):
return self.func("DATE", expression.this)
return super().cast_sql(expression)
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 9b39178..514aecb 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -183,3 +183,20 @@ class Teradata(Dialect):
each_sql = f" EACH {each_sql}" if each_sql else ""
return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})"
+
+ def createable_sql(
+ self,
+ expression: exp.Create,
+ locations: dict[exp.Properties.Location, list[exp.Property]],
+ ) -> str:
+ kind = self.sql(expression, "kind").upper()
+ if kind == "TABLE" and locations.get(exp.Properties.Location.POST_NAME):
+ this_name = self.sql(expression.this, "this")
+ this_properties = self.properties(
+ exp.Properties(expressions=locations[exp.Properties.Location.POST_NAME]),
+ wrapped=False,
+ prefix=",",
+ )
+ this_schema = self.schema_columns_sql(expression.this)
+ return f"{this_name}{this_properties}{self.sep()}{this_schema}"
+ return super().createable_sql(expression, locations)
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index a4c4e95..da4a4ed 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1653,12 +1653,16 @@ class Join(Expression):
"side": False,
"kind": False,
"using": False,
- "natural": False,
+ "method": False,
"global": False,
"hint": False,
}
@property
+ def method(self) -> str:
+ return self.text("method").upper()
+
+ @property
def kind(self) -> str:
return self.text("kind").upper()
@@ -1913,6 +1917,24 @@ class LanguageProperty(Property):
arg_types = {"this": True}
+class DictProperty(Property):
+ arg_types = {"this": True, "kind": True, "settings": False}
+
+
+class DictSubProperty(Property):
+ pass
+
+
+class DictRange(Property):
+ arg_types = {"this": True, "min": True, "max": True}
+
+
+# Clickhouse CREATE ... ON CLUSTER modifier
+# https://clickhouse.com/docs/en/sql-reference/distributed-ddl
+class OnCluster(Property):
+ arg_types = {"this": True}
+
+
class LikeProperty(Property):
arg_types = {"this": True, "expressions": False}
@@ -2797,12 +2819,12 @@ class Select(Subqueryable):
Returns:
Select: the modified expression.
"""
- parse_args = {"dialect": dialect, **opts}
+ parse_args: t.Dict[str, t.Any] = {"dialect": dialect, **opts}
try:
- expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) # type: ignore
+ expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args)
except ParseError:
- expression = maybe_parse(expression, into=(Join, Expression), **parse_args) # type: ignore
+ expression = maybe_parse(expression, into=(Join, Expression), **parse_args)
join = expression if isinstance(expression, Join) else Join(this=expression)
@@ -2810,14 +2832,14 @@ class Select(Subqueryable):
join.this.replace(join.this.subquery())
if join_type:
- natural: t.Optional[Token]
+ method: t.Optional[Token]
side: t.Optional[Token]
kind: t.Optional[Token]
- natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
+ method, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
- if natural:
- join.set("natural", True)
+ if method:
+ join.set("method", method.text)
if side:
join.set("side", side.text)
if kind:
@@ -3222,6 +3244,18 @@ class DataType(Expression):
DATE = auto()
DATETIME = auto()
DATETIME64 = auto()
+ INT4RANGE = auto()
+ INT4MULTIRANGE = auto()
+ INT8RANGE = auto()
+ INT8MULTIRANGE = auto()
+ NUMRANGE = auto()
+ NUMMULTIRANGE = auto()
+ TSRANGE = auto()
+ TSMULTIRANGE = auto()
+ TSTZRANGE = auto()
+ TSTZMULTIRANGE = auto()
+ DATERANGE = auto()
+ DATEMULTIRANGE = auto()
DECIMAL = auto()
DOUBLE = auto()
FLOAT = auto()
@@ -3331,8 +3365,8 @@ class DataType(Expression):
return DataType(**{**data_type_exp.args, **kwargs})
- def is_type(self, dtype: DataType.Type) -> bool:
- return self.this == dtype
+ def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
+ return any(self.this == DataType.build(dtype).this for dtype in dtypes)
# https://www.postgresql.org/docs/15/datatype-pseudo.html
@@ -3846,8 +3880,8 @@ class Cast(Func):
def output_name(self) -> str:
return self.name
- def is_type(self, dtype: DataType.Type) -> bool:
- return self.to.is_type(dtype)
+ def is_type(self, *dtypes: str | DataType | DataType.Type) -> bool:
+ return self.to.is_type(*dtypes)
class CastToStrType(Func):
@@ -4130,8 +4164,16 @@ class Least(Func):
is_var_len_args = True
+class Left(Func):
+ arg_types = {"this": True, "expression": True}
+
+
+class Right(Func):
+ arg_types = {"this": True, "expression": True}
+
+
class Length(Func):
- pass
+ _sql_names = ["LENGTH", "LEN"]
class Levenshtein(Func):
@@ -4356,6 +4398,10 @@ class NumberToStr(Func):
arg_types = {"this": True, "format": True}
+class FromBase(Func):
+ arg_types = {"this": True, "expression": True}
+
+
class Struct(Func):
arg_types = {"expressions": True}
is_var_len_args = True
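Note: a sketch of the widened is_type API; dtype arguments may now be strings,
DataType instances, or DataType.Type members, with any-of semantics:

    from sqlglot import exp

    dt = exp.DataType.build("varchar(10)")
    assert dt.is_type("varchar")                  # string spec
    assert dt.is_type("char", "varchar")          # true if any match
    assert not dt.is_type(exp.DataType.Type.INT)  # enum members still work

    # exp.Cast.is_type delegates to the cast's target type:
    cast = exp.cast(exp.column("x"), "date")
    assert cast.is_type("date")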
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index f1ec398..97cbe15 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -44,6 +44,8 @@ class Generator:
Default: "upper"
alias_post_tablesample (bool): if the table alias comes after tablesample
Default: False
+ identifiers_can_start_with_digit (bool): if an unquoted identifier can start with a digit
+ Default: False
unsupported_level (ErrorLevel): determines the generator's behavior when it encounters
unsupported expressions. Default ErrorLevel.WARN.
null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
@@ -188,6 +190,8 @@ class Generator:
exp.Cluster: exp.Properties.Location.POST_SCHEMA,
exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME,
exp.DefinerProperty: exp.Properties.Location.POST_CREATE,
+ exp.DictRange: exp.Properties.Location.POST_SCHEMA,
+ exp.DictProperty: exp.Properties.Location.POST_SCHEMA,
exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA,
exp.EngineProperty: exp.Properties.Location.POST_SCHEMA,
@@ -233,6 +237,7 @@ class Generator:
JOIN_HINTS = True
TABLE_HINTS = True
+ IS_BOOL = True
RESERVED_KEYWORDS: t.Set[str] = set()
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
@@ -264,6 +269,7 @@ class Generator:
"index_offset",
"unnest_column_only",
"alias_post_tablesample",
+ "identifiers_can_start_with_digit",
"normalize_functions",
"unsupported_level",
"unsupported_messages",
@@ -304,6 +310,7 @@ class Generator:
index_offset=0,
unnest_column_only=False,
alias_post_tablesample=False,
+ identifiers_can_start_with_digit=False,
normalize_functions="upper",
unsupported_level=ErrorLevel.WARN,
null_ordering=None,
@@ -337,6 +344,7 @@ class Generator:
self.index_offset = index_offset
self.unnest_column_only = unnest_column_only
self.alias_post_tablesample = alias_post_tablesample
+ self.identifiers_can_start_with_digit = identifiers_can_start_with_digit
self.normalize_functions = normalize_functions
self.unsupported_level = unsupported_level
self.unsupported_messages = []
@@ -634,35 +642,31 @@ class Generator:
this = f" {this}" if this else ""
return f"UNIQUE{this}"
+ def createable_sql(
+ self, expression: exp.Create, locations: dict[exp.Properties.Location, list[exp.Property]]
+ ) -> str:
+ return self.sql(expression, "this")
+
def create_sql(self, expression: exp.Create) -> str:
kind = self.sql(expression, "kind").upper()
properties = expression.args.get("properties")
- properties_exp = expression.copy()
properties_locs = self.locate_properties(properties) if properties else {}
+
+ this = self.createable_sql(expression, properties_locs)
+
+ properties_sql = ""
if properties_locs.get(exp.Properties.Location.POST_SCHEMA) or properties_locs.get(
exp.Properties.Location.POST_WITH
):
- properties_exp.set(
- "properties",
+ properties_sql = self.sql(
exp.Properties(
expressions=[
*properties_locs[exp.Properties.Location.POST_SCHEMA],
*properties_locs[exp.Properties.Location.POST_WITH],
]
- ),
+ )
)
- if kind == "TABLE" and properties_locs.get(exp.Properties.Location.POST_NAME):
- this_name = self.sql(expression.this, "this")
- this_properties = self.properties(
- exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_NAME]),
- wrapped=False,
- )
- this_schema = f"({self.expressions(expression.this)})"
- this = f"{this_name}, {this_properties} {this_schema}"
- properties_sql = ""
- else:
- this = self.sql(expression, "this")
- properties_sql = self.sql(properties_exp, "properties")
+
begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression")
if expression_sql:
@@ -894,6 +898,7 @@ class Generator:
expression.quoted
or should_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
+ or (not self.identifiers_can_start_with_digit and text[:1].isdigit())
):
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
@@ -1082,7 +1087,7 @@ class Generator:
def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
kind = expression.args.get("kind")
- this: str = f" {this}" if expression.this else ""
+ this = f" {self.sql(expression, 'this')}" if expression.this else ""
for_or_in = expression.args.get("for_or_in")
lock_type = expression.args.get("lock_type")
override = " OVERRIDE" if expression.args.get("override") else ""
@@ -1313,7 +1318,7 @@ class Generator:
op_sql = " ".join(
op
for op in (
- "NATURAL" if expression.args.get("natural") else None,
+ expression.method,
"GLOBAL" if expression.args.get("global") else None,
expression.side,
expression.kind,
@@ -1573,9 +1578,12 @@ class Generator:
def schema_sql(self, expression: exp.Schema) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else ""
- sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
+ sql = self.schema_columns_sql(expression)
return f"{this}{sql}"
+ def schema_columns_sql(self, expression: exp.Schema) -> str:
+ return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
+
def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else ""
@@ -1643,32 +1651,26 @@ class Generator:
def window_sql(self, expression: exp.Window) -> str:
this = self.sql(expression, "this")
-
partition = self.partition_by_sql(expression)
-
order = expression.args.get("order")
- order_sql = self.order_sql(order, flat=True) if order else ""
-
- partition_sql = partition + " " if partition and order else partition
-
- spec = expression.args.get("spec")
- spec_sql = " " + self.windowspec_sql(spec) if spec else ""
-
+ order = self.order_sql(order, flat=True) if order else ""
+ spec = self.sql(expression, "spec")
alias = self.sql(expression, "alias")
over = self.sql(expression, "over") or "OVER"
+
this = f"{this} {'AS' if expression.arg_key == 'windows' else over}"
first = expression.args.get("first")
- if first is not None:
- first = " FIRST " if first else " LAST "
- first = first or ""
+ if first is None:
+ first = ""
+ else:
+ first = "FIRST" if first else "LAST"
if not partition and not order and not spec and alias:
return f"{this} {alias}"
- window_args = alias + first + partition_sql + order_sql + spec_sql
-
- return f"{this} ({window_args.strip()})"
+ args = " ".join(arg for arg in (alias, first, partition, order, spec) if arg)
+ return f"{this} ({args})"
def partition_by_sql(self, expression: exp.Window | exp.MatchRecognize) -> str:
partition = self.expressions(expression, key="partition_by", flat=True)
@@ -2125,6 +2127,10 @@ class Generator:
return self.binary(expression, "ILIKE ANY")
def is_sql(self, expression: exp.Is) -> str:
+ if not self.IS_BOOL and isinstance(expression.expression, exp.Boolean):
+ return self.sql(
+ expression.this if expression.expression.this else exp.not_(expression.this)
+ )
return self.binary(expression, "IS")
def like_sql(self, expression: exp.Like) -> str:
@@ -2322,6 +2328,25 @@ class Generator:
return self.sql(exp.cast(expression.this, "text"))
+ def dictproperty_sql(self, expression: exp.DictProperty) -> str:
+ this = self.sql(expression, "this")
+ kind = self.sql(expression, "kind")
+ settings_sql = self.expressions(expression, key="settings", sep=" ")
+ args = f"({self.sep('')}{settings_sql}{self.seg(')', sep='')}" if settings_sql else "()"
+ return f"{this}({kind}{args})"
+
+ def dictrange_sql(self, expression: exp.DictRange) -> str:
+ this = self.sql(expression, "this")
+ max = self.sql(expression, "max")
+ min = self.sql(expression, "min")
+ return f"{this}(MIN {min} MAX {max})"
+
+ def dictsubproperty_sql(self, expression: exp.DictSubProperty) -> str:
+ return f"{self.sql(expression, 'this')} {self.sql(expression, 'value')}"
+
+ def oncluster_sql(self, expression: exp.OnCluster) -> str:
+ return ""
+
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 43436cb..4e0c3a1 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -1,7 +1,7 @@
from sqlglot import exp
from sqlglot.helper import tsort
-JOIN_ATTRS = ("on", "side", "kind", "using", "natural")
+JOIN_ATTRS = ("on", "side", "kind", "using", "method")
def optimize_joins(expression):
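Note: since "natural" is now folded into the generic "method" arg (alongside
ASOF), a join's method survives parsing, optimization, and generation. A rough
sketch:

    import sqlglot

    e = sqlglot.parse_one("SELECT * FROM t1 ASOF JOIN t2 ON t1.id = t2.id")
    join = e.args["joins"][0]
    print(join.method)  # ASOF
    print(e.sql())      # round-trips with the join method preserved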
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 96dda33..b89a82b 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -10,10 +10,10 @@ def pushdown_predicates(expression):
Example:
>>> import sqlglot
- >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
+ >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_predicates(expression).sql()
- 'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
+ 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE'
Args:
expression (sqlglot.Expression): expression to optimize
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index e77bb5a..96bd6e3 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -155,6 +155,18 @@ class Parser(metaclass=_Parser):
TokenType.DATETIME,
TokenType.DATETIME64,
TokenType.DATE,
+ TokenType.INT4RANGE,
+ TokenType.INT4MULTIRANGE,
+ TokenType.INT8RANGE,
+ TokenType.INT8MULTIRANGE,
+ TokenType.NUMRANGE,
+ TokenType.NUMMULTIRANGE,
+ TokenType.TSRANGE,
+ TokenType.TSMULTIRANGE,
+ TokenType.TSTZRANGE,
+ TokenType.TSTZMULTIRANGE,
+ TokenType.DATERANGE,
+ TokenType.DATEMULTIRANGE,
TokenType.DECIMAL,
TokenType.BIGDECIMAL,
TokenType.UUID,
@@ -193,6 +205,7 @@ class Parser(metaclass=_Parser):
TokenType.SCHEMA,
TokenType.TABLE,
TokenType.VIEW,
+ TokenType.DICTIONARY,
}
CREATABLES = {
@@ -220,6 +233,7 @@ class Parser(metaclass=_Parser):
TokenType.DELETE,
TokenType.DESC,
TokenType.DESCRIBE,
+ TokenType.DICTIONARY,
TokenType.DIV,
TokenType.END,
TokenType.EXECUTE,
@@ -272,6 +286,7 @@ class Parser(metaclass=_Parser):
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {
TokenType.APPLY,
+ TokenType.ASOF,
TokenType.FULL,
TokenType.LEFT,
TokenType.LOCK,
@@ -375,6 +390,11 @@ class Parser(metaclass=_Parser):
TokenType.EXCEPT,
}
+ JOIN_METHODS = {
+ TokenType.NATURAL,
+ TokenType.ASOF,
+ }
+
JOIN_SIDES = {
TokenType.LEFT,
TokenType.RIGHT,
@@ -465,7 +485,7 @@ class Parser(metaclass=_Parser):
exp.Where: lambda self: self._parse_where(),
exp.Window: lambda self: self._parse_named_window(),
exp.With: lambda self: self._parse_with(),
- "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
+ "JOIN_TYPE": lambda self: self._parse_join_parts(),
}
STATEMENT_PARSERS = {
@@ -580,6 +600,8 @@ class Parser(metaclass=_Parser):
),
"JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs),
"LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty),
+ "LAYOUT": lambda self: self._parse_dict_property(this="LAYOUT"),
+ "LIFETIME": lambda self: self._parse_dict_range(this="LIFETIME"),
"LIKE": lambda self: self._parse_create_like(),
"LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty),
"LOCK": lambda self: self._parse_locking(),
@@ -594,7 +616,8 @@ class Parser(metaclass=_Parser):
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
- "PRIMARY KEY": lambda self: self._parse_primary_key(),
+ "PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True),
+ "RANGE": lambda self: self._parse_dict_range(this="RANGE"),
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
"ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
@@ -603,6 +626,7 @@ class Parser(metaclass=_Parser):
exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)
),
"SORTKEY": lambda self: self._parse_sortkey(),
+ "SOURCE": lambda self: self._parse_dict_property(this="SOURCE"),
"STABLE": lambda self: self.expression(
exp.StabilityProperty, this=exp.Literal.string("STABLE")
),
@@ -1133,13 +1157,16 @@ class Parser(metaclass=_Parser):
begin = None
clone = None
+ def extend_props(temp_props: t.Optional[exp.Expression]) -> None:
+ nonlocal properties
+ if properties and temp_props:
+ properties.expressions.extend(temp_props.expressions)
+ elif temp_props:
+ properties = temp_props
+
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function(kind=create_token.token_type)
- temp_properties = self._parse_properties()
- if properties and temp_properties:
- properties.expressions.extend(temp_properties.expressions)
- elif temp_properties:
- properties = temp_properties
+ extend_props(self._parse_properties())
self._match(TokenType.ALIAS)
begin = self._match(TokenType.BEGIN)
@@ -1154,21 +1181,13 @@ class Parser(metaclass=_Parser):
table_parts = self._parse_table_parts(schema=True)
# exp.Properties.Location.POST_NAME
- if self._match(TokenType.COMMA):
- temp_properties = self._parse_properties(before=True)
- if properties and temp_properties:
- properties.expressions.extend(temp_properties.expressions)
- elif temp_properties:
- properties = temp_properties
+ self._match(TokenType.COMMA)
+ extend_props(self._parse_properties(before=True))
this = self._parse_schema(this=table_parts)
# exp.Properties.Location.POST_SCHEMA and POST_WITH
- temp_properties = self._parse_properties()
- if properties and temp_properties:
- properties.expressions.extend(temp_properties.expressions)
- elif temp_properties:
- properties = temp_properties
+ extend_props(self._parse_properties())
self._match(TokenType.ALIAS)
@@ -1178,11 +1197,7 @@ class Parser(metaclass=_Parser):
or self._match(TokenType.WITH, advance=False)
or self._match(TokenType.L_PAREN, advance=False)
):
- temp_properties = self._parse_properties()
- if properties and temp_properties:
- properties.expressions.extend(temp_properties.expressions)
- elif temp_properties:
- properties = temp_properties
+ extend_props(self._parse_properties())
expression = self._parse_ddl_select()
@@ -1192,11 +1207,7 @@ class Parser(metaclass=_Parser):
index = self._parse_index()
# exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX
- temp_properties = self._parse_properties()
- if properties and temp_properties:
- properties.expressions.extend(temp_properties.expressions)
- elif temp_properties:
- properties = temp_properties
+ extend_props(self._parse_properties())
if not index:
break
@@ -1888,8 +1899,16 @@ class Parser(metaclass=_Parser):
this = self._parse_query_modifiers(this)
elif (table or nested) and self._match(TokenType.L_PAREN):
- this = self._parse_table() if table else self._parse_select(nested=True)
- this = self._parse_set_operations(self._parse_query_modifiers(this))
+ if self._match(TokenType.PIVOT):
+ this = self._parse_simplified_pivot()
+ elif self._match(TokenType.FROM):
+ this = exp.select("*").from_(
+ t.cast(exp.From, self._parse_from(skip_from_token=True))
+ )
+ else:
+ this = self._parse_table() if table else self._parse_select(nested=True)
+ this = self._parse_set_operations(self._parse_query_modifiers(this))
+
self._match_r_paren()
# early return so that subquery unions aren't parsed again
@@ -1902,10 +1921,6 @@ class Parser(metaclass=_Parser):
expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
- elif self._match(TokenType.PIVOT):
- this = self._parse_simplified_pivot()
- elif self._match(TokenType.FROM):
- this = exp.select("*").from_(t.cast(exp.From, self._parse_from(skip_from_token=True)))
else:
this = None
@@ -2154,11 +2169,11 @@ class Parser(metaclass=_Parser):
return expression
- def _parse_join_side_and_kind(
+ def _parse_join_parts(
self,
) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
return (
- self._match(TokenType.NATURAL) and self._prev,
+ self._match_set(self.JOIN_METHODS) and self._prev,
self._match_set(self.JOIN_SIDES) and self._prev,
self._match_set(self.JOIN_KINDS) and self._prev,
)
@@ -2168,14 +2183,14 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Join, this=self._parse_table())
index = self._index
- natural, side, kind = self._parse_join_side_and_kind()
+ method, side, kind = self._parse_join_parts()
hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None
join = self._match(TokenType.JOIN)
if not skip_join_token and not join:
self._retreat(index)
kind = None
- natural = None
+ method = None
side = None
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False)
@@ -2187,12 +2202,10 @@ class Parser(metaclass=_Parser):
if outer_apply:
side = Token(TokenType.LEFT, "LEFT")
- kwargs: t.Dict[
- str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]]
- ] = {"this": self._parse_table()}
+ kwargs: t.Dict[str, t.Any] = {"this": self._parse_table()}
- if natural:
- kwargs["natural"] = True
+ if method:
+ kwargs["method"] = method.text
if side:
kwargs["side"] = side.text
if kind:
@@ -2205,7 +2218,7 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.USING):
kwargs["using"] = self._parse_wrapped_id_vars()
- return self.expression(exp.Join, **kwargs) # type: ignore
+ return self.expression(exp.Join, **kwargs)
def _parse_index(
self,
@@ -2886,7 +2899,9 @@ class Parser(metaclass=_Parser):
exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True)
)
- def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
+ def _parse_types(
+ self, check_func: bool = False, schema: bool = False
+ ) -> t.Optional[exp.Expression]:
index = self._index
prefix = self._match_text_seq("SYSUDTLIB", ".")
@@ -2908,7 +2923,9 @@ class Parser(metaclass=_Parser):
if is_struct:
expressions = self._parse_csv(self._parse_struct_types)
elif nested:
- expressions = self._parse_csv(self._parse_types)
+ expressions = self._parse_csv(
+ lambda: self._parse_types(check_func=check_func, schema=schema)
+ )
else:
expressions = self._parse_csv(self._parse_type_size)
@@ -2943,7 +2960,9 @@ class Parser(metaclass=_Parser):
if is_struct:
expressions = self._parse_csv(self._parse_struct_types)
else:
- expressions = self._parse_csv(self._parse_types)
+ expressions = self._parse_csv(
+ lambda: self._parse_types(check_func=check_func, schema=schema)
+ )
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
@@ -3038,11 +3057,7 @@ class Parser(metaclass=_Parser):
else exp.Literal.string(value)
)
else:
- field = (
- self._parse_star()
- or self._parse_function(anonymous=True)
- or self._parse_id_var()
- )
+ field = self._parse_field(anonymous_func=True)
if isinstance(field, exp.Func):
# bigquery allows function calls like x.y.count(...)
@@ -3113,10 +3128,11 @@ class Parser(metaclass=_Parser):
self,
any_token: bool = False,
tokens: t.Optional[t.Collection[TokenType]] = None,
+ anonymous_func: bool = False,
) -> t.Optional[exp.Expression]:
return (
self._parse_primary()
- or self._parse_function()
+ or self._parse_function(anonymous=anonymous_func)
or self._parse_id_var(any_token=any_token, tokens=tokens)
)
@@ -3270,7 +3286,7 @@ class Parser(metaclass=_Parser):
# column defs are not really columns, they're identifiers
if isinstance(this, exp.Column):
this = this.this
- kind = self._parse_types()
+ kind = self._parse_types(schema=True)
if self._match_text_seq("FOR", "ORDINALITY"):
return self.expression(exp.ColumnDef, this=this, ordinality=True)
@@ -3483,16 +3499,18 @@ class Parser(metaclass=_Parser):
exp.ForeignKey, expressions=expressions, reference=reference, **options # type: ignore
)
- def _parse_primary_key(self) -> exp.Expression:
+ def _parse_primary_key(
+ self, wrapped_optional: bool = False, in_props: bool = False
+ ) -> exp.Expression:
desc = (
self._match_set((TokenType.ASC, TokenType.DESC))
and self._prev.token_type == TokenType.DESC
)
- if not self._match(TokenType.L_PAREN, advance=False):
+ if not in_props and not self._match(TokenType.L_PAREN, advance=False):
return self.expression(exp.PrimaryKeyColumnConstraint, desc=desc)
- expressions = self._parse_wrapped_csv(self._parse_field)
+ expressions = self._parse_wrapped_csv(self._parse_field, optional=wrapped_optional)
options = self._parse_key_constraint_options()
return self.expression(exp.PrimaryKey, expressions=expressions, options=options)
@@ -3509,10 +3527,11 @@ class Parser(metaclass=_Parser):
return this
bracket_kind = self._prev.token_type
- expressions: t.List[t.Optional[exp.Expression]]
if self._match(TokenType.COLON):
- expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())]
+ expressions: t.List[t.Optional[exp.Expression]] = [
+ self.expression(exp.Slice, expression=self._parse_conjunction())
+ ]
else:
expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction()))
@@ -4011,22 +4030,15 @@ class Parser(metaclass=_Parser):
self,
any_token: bool = True,
tokens: t.Optional[t.Collection[TokenType]] = None,
- prefix_tokens: t.Optional[t.Collection[TokenType]] = None,
) -> t.Optional[exp.Expression]:
identifier = self._parse_identifier()
if identifier:
return identifier
- prefix = ""
-
- if prefix_tokens:
- while self._match_set(prefix_tokens):
- prefix += self._prev.text
-
if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
quoted = self._prev.token_type == TokenType.STRING
- return exp.Identifier(this=prefix + self._prev.text, quoted=quoted)
+ return exp.Identifier(this=self._prev.text, quoted=quoted)
return None
@@ -4472,6 +4484,44 @@ class Parser(metaclass=_Parser):
size = len(start.text)
return exp.Command(this=text[:size], expression=text[size:])
+ def _parse_dict_property(self, this: str) -> exp.DictProperty:
+ settings = []
+
+ self._match_l_paren()
+ kind = self._parse_id_var()
+
+ if self._match(TokenType.L_PAREN):
+ while True:
+ key = self._parse_id_var()
+ value = self._parse_primary()
+
+ if not key and value is None:
+ break
+ settings.append(self.expression(exp.DictSubProperty, this=key, value=value))
+ self._match(TokenType.R_PAREN)
+
+ self._match_r_paren()
+
+ return self.expression(
+ exp.DictProperty,
+ this=this,
+ kind=kind.this if kind else None,
+ settings=settings,
+ )
+
+ def _parse_dict_range(self, this: str) -> exp.DictRange:
+ self._match_l_paren()
+ has_min = self._match_text_seq("MIN")
+ if has_min:
+ min = self._parse_var() or self._parse_primary()
+ self._match_text_seq("MAX")
+ max = self._parse_var() or self._parse_primary()
+ else:
+ max = self._parse_var() or self._parse_primary()
+ min = exp.Literal.number(0)
+ self._match_r_paren()
+ return self.expression(exp.DictRange, this=this, min=min, max=max)
+
def _find_parser(
self, parsers: t.Dict[str, t.Callable], trie: t.Dict
) -> t.Optional[t.Callable]:
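Note: a sketch of what the new dictionary-related parsers target, namely
ClickHouse CREATE DICTIONARY statements with SOURCE, LAYOUT, LIFETIME and an
unwrapped PRIMARY KEY (illustrative DDL, output approximate):

    import sqlglot

    sql = """
    CREATE DICTIONARY d (id UInt64, value String)
    PRIMARY KEY id
    SOURCE(CLICKHOUSE(TABLE 'src'))
    LIFETIME(MIN 0 MAX 300)
    LAYOUT(HASHED())
    """
    print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])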
diff --git a/sqlglot/serde.py b/sqlglot/serde.py
index c5203a7..b019035 100644
--- a/sqlglot/serde.py
+++ b/sqlglot/serde.py
@@ -5,7 +5,7 @@ import typing as t
from sqlglot import expressions as exp
if t.TYPE_CHECKING:
- JSON = t.Union[dict, list, str, float, int, bool]
+ JSON = t.Union[dict, list, str, float, int, bool, None]
Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON]
@@ -24,12 +24,12 @@ def dump(node: Node) -> JSON:
klass = node.__class__.__qualname__
if node.__class__.__module__ != exp.__name__:
klass = f"{node.__module__}.{klass}"
- obj = {
+ obj: t.Dict = {
"class": klass,
"args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []},
}
if node.type:
- obj["type"] = node.type.sql()
+ obj["type"] = dump(node.type)
if node.comments:
obj["comments"] = node.comments
if node._meta is not None:
@@ -60,7 +60,7 @@ def load(obj: JSON) -> Node:
klass = getattr(module, class_name)
expression = klass(**{k: load(v) for k, v in obj["args"].items()})
- expression.type = obj.get("type")
+ expression.type = t.cast(exp.DataType, load(obj.get("type")))
expression.comments = obj.get("comments")
expression._meta = obj.get("meta")
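Note: a sketch of the serde change: node.type is now dumped as a full DataType
node rather than a SQL string, so annotated types should survive a round trip
(assuming the optimizer's annotate_types):

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.serde import dump, load

    e = annotate_types(parse_one("SELECT 1 + 1"))
    assert load(dump(e)) == e  # types included in the round trip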
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index ad329d2..a30ec24 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -113,6 +113,18 @@ class TokenType(AutoName):
DATETIME = auto()
DATETIME64 = auto()
DATE = auto()
+ INT4RANGE = auto()
+ INT4MULTIRANGE = auto()
+ INT8RANGE = auto()
+ INT8MULTIRANGE = auto()
+ NUMRANGE = auto()
+ NUMMULTIRANGE = auto()
+ TSRANGE = auto()
+ TSMULTIRANGE = auto()
+ TSTZRANGE = auto()
+ TSTZMULTIRANGE = auto()
+ DATERANGE = auto()
+ DATEMULTIRANGE = auto()
UUID = auto()
GEOGRAPHY = auto()
NULLABLE = auto()
@@ -167,6 +179,7 @@ class TokenType(AutoName):
DELETE = auto()
DESC = auto()
DESCRIBE = auto()
+ DICTIONARY = auto()
DISTINCT = auto()
DIV = auto()
DROP = auto()
@@ -480,6 +493,7 @@ class Tokenizer(metaclass=_Tokenizer):
"ANY": TokenType.ANY,
"ASC": TokenType.ASC,
"AS": TokenType.ALIAS,
+ "ASOF": TokenType.ASOF,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
"AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
"BEGIN": TokenType.BEGIN,
@@ -669,6 +683,18 @@ class Tokenizer(metaclass=_Tokenizer):
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
"DATE": TokenType.DATE,
"DATETIME": TokenType.DATETIME,
+ "INT4RANGE": TokenType.INT4RANGE,
+ "INT4MULTIRANGE": TokenType.INT4MULTIRANGE,
+ "INT8RANGE": TokenType.INT8RANGE,
+ "INT8MULTIRANGE": TokenType.INT8MULTIRANGE,
+ "NUMRANGE": TokenType.NUMRANGE,
+ "NUMMULTIRANGE": TokenType.NUMMULTIRANGE,
+ "TSRANGE": TokenType.TSRANGE,
+ "TSMULTIRANGE": TokenType.TSMULTIRANGE,
+ "TSTZRANGE": TokenType.TSTZRANGE,
+ "TSTZMULTIRANGE": TokenType.TSTZMULTIRANGE,
+ "DATERANGE": TokenType.DATERANGE,
+ "DATEMULTIRANGE": TokenType.DATEMULTIRANGE,
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
@@ -709,8 +735,6 @@ class Tokenizer(metaclass=_Tokenizer):
COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")]
KEYWORD_TRIE: t.Dict = {} # autofilled
- IDENTIFIER_CAN_START_WITH_DIGIT = False
-
__slots__ = (
"sql",
"size",
@@ -724,6 +748,7 @@ class Tokenizer(metaclass=_Tokenizer):
"_end",
"_peek",
"_prev_token_line",
+ "identifiers_can_start_with_digit",
)
def __init__(self) -> None:
@@ -826,6 +851,12 @@ class Tokenizer(metaclass=_Tokenizer):
def _text(self) -> str:
return self.sql[self._start : self._current]
+ def peek(self, i: int = 0) -> str:
+ i = self._current + i
+ if i < self.size:
+ return self.sql[i]
+ return ""
+
def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
self.tokens.append(
@@ -962,8 +993,12 @@ class Tokenizer(metaclass=_Tokenizer):
if self._peek.isdigit():
self._advance()
elif self._peek == "." and not decimal:
- decimal = True
- self._advance()
+ after = self.peek(1)
+ if after.isdigit() or not after.strip():
+ decimal = True
+ self._advance()
+ else:
+ return self._add(TokenType.VAR)
elif self._peek in ("-", "+") and scientific == 1:
scientific += 1
self._advance()
@@ -984,7 +1019,7 @@ class Tokenizer(metaclass=_Tokenizer):
self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal)
- elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
+ elif self.identifiers_can_start_with_digit: # type: ignore
return self._add(TokenType.VAR)
self._add(TokenType.NUMBER, number_text)
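Note: a sketch of the tokenizer changes; the digit-leading-identifier rule is
now a dialect-level flag, and the new peek() lookahead keeps non-numeric text
after a dot from being lexed as a decimal. Example, assuming the Hive reader:

    import sqlglot

    # Hive allows unquoted identifiers that start with a digit:
    print(sqlglot.parse_one("SELECT 1a FROM t", read="hive").sql(dialect="hive"))
    # expected (roughly): SELECT 1a FROM t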
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index a1ec1bd..ba72616 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -268,6 +268,17 @@ def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
return expression
+def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
+ if (
+ isinstance(expression, (exp.Cast, exp.TryCast))
+ and expression.name.lower() == "epoch"
+ and expression.to.this in exp.DataType.TEMPORAL_TYPES
+ ):
+ expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
+
+ return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
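Note: a sketch of epoch_cast_to_ts in action via the Presto generator, which
registers it for both Cast and TryCast (output shown is approximate):

    import sqlglot

    # 'epoch' is a Postgres/Redshift-style shorthand that Presto lacks, so it
    # is replaced with an explicit timestamp literal before generation.
    print(sqlglot.transpile("SELECT CAST('epoch' AS TIMESTAMP)", write="presto")[0])
    # expected (roughly): SELECT CAST('1970-01-01 00:00:00' AS TIMESTAMP)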