author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-10-04 12:14:45 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-10-04 12:14:45 +0000
commit     a34653eb21369376f0e054dd989311afcb167f5b (patch)
tree       5a0280adce195af0be654f79fd99395fd2932c19 /sqlglot
parent     Releasing debian version 18.7.0-1. (diff)
Merging upstream version 18.11.2.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/dialects/bigquery.py                |  30
-rw-r--r--  sqlglot/dialects/clickhouse.py              |  30
-rw-r--r--  sqlglot/dialects/databricks.py              |  20
-rw-r--r--  sqlglot/dialects/dialect.py                 |   9
-rw-r--r--  sqlglot/dialects/hive.py                    |  33
-rw-r--r--  sqlglot/dialects/mysql.py                   | 107
-rw-r--r--  sqlglot/dialects/oracle.py                  |   8
-rw-r--r--  sqlglot/dialects/postgres.py                |  14
-rw-r--r--  sqlglot/dialects/presto.py                  |   3
-rw-r--r--  sqlglot/dialects/redshift.py                |   2
-rw-r--r--  sqlglot/dialects/snowflake.py               |  47
-rw-r--r--  sqlglot/dialects/spark.py                   |  11
-rw-r--r--  sqlglot/dialects/spark2.py                  |  26
-rw-r--r--  sqlglot/dialects/tsql.py                    |   9
-rw-r--r--  sqlglot/executor/env.py                     |   1
-rw-r--r--  sqlglot/expressions.py                      |  81
-rw-r--r--  sqlglot/generator.py                        |  67
-rw-r--r--  sqlglot/optimizer/annotate_types.py         | 115
-rw-r--r--  sqlglot/optimizer/canonicalize.py           |  19
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py       |   2
-rw-r--r--  sqlglot/optimizer/normalize_identifiers.py  |  18
-rw-r--r--  sqlglot/optimizer/simplify.py               |  97
-rw-r--r--  sqlglot/parser.py                           |  87
-rw-r--r--  sqlglot/tokens.py                           |  18
24 files changed, 701 insertions, 153 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 1349c56..0d741b5 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -190,6 +190,16 @@ class BigQuery(Dialect):
"%D": "%m/%d/%y",
}
+ ESCAPE_SEQUENCES = {
+ "\\a": "\a",
+ "\\b": "\b",
+ "\\f": "\f",
+ "\\n": "\n",
+ "\\r": "\r",
+ "\\t": "\t",
+ "\\v": "\v",
+ }
+
FORMAT_MAPPING = {
"DD": "%d",
"MM": "%m",
@@ -212,15 +222,14 @@ class BigQuery(Dialect):
@classmethod
def normalize_identifier(cls, expression: E) -> E:
- # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
- # The following check is essentially a heuristic to detect tables based on whether or
- # not they're qualified.
if isinstance(expression, exp.Identifier):
parent = expression.parent
-
while isinstance(parent, exp.Dot):
parent = parent.parent
+ # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
+ # The following check is essentially a heuristic to detect tables based on whether or
+ # not they're qualified. It also avoids normalizing UDFs, because they're case-sensitive.
if (
not isinstance(parent, exp.UserDefinedFunction)
and not (isinstance(parent, exp.Table) and parent.db)
@@ -419,6 +428,7 @@ class BigQuery(Dialect):
RENAME_TABLE_WITH_DB = False
NVL2_SUPPORTED = False
UNNEST_WITH_ORDINALITY = False
+ COLLATE_IS_FUNC = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -520,18 +530,6 @@ class BigQuery(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- UNESCAPED_SEQUENCE_TABLE = str.maketrans( # type: ignore
- {
- "\a": "\\a",
- "\b": "\\b",
- "\f": "\\f",
- "\n": "\\n",
- "\r": "\\r",
- "\t": "\\t",
- "\v": "\\v",
- }
- )
-
# from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords
RESERVED_KEYWORDS = {
*generator.Generator.RESERVED_KEYWORDS,
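
Note: a minimal sketch of what the dialect-level escape mapping enables, using sqlglot's public transpile API (the output in the comment is the expected result, not verified here):

    import sqlglot

    # The tokenizer maps "\n" in a BigQuery literal to a real newline via
    # ESCAPE_SEQUENCES, and the generator writes it back out through the
    # autofilled INVERSE_ESCAPE_SEQUENCES, so the literal round-trips.
    sql = sqlglot.transpile(r"SELECT '\n'", read="bigquery", write="bigquery")[0]
    print(sql)  # expected: SELECT '\n'
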
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 7446081..e9d9326 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
inline_array_sql,
@@ -21,18 +21,33 @@ def _lower_func(sql: str) -> str:
return sql[:index].lower() + sql[index:]
+def _quantile_sql(self: ClickHouse.Generator, e: exp.Quantile) -> str:
+ quantile = e.args["quantile"]
+ args = f"({self.sql(e, 'this')})"
+ if isinstance(quantile, exp.Array):
+ func = self.func("quantiles", *quantile)
+ else:
+ func = self.func("quantile", quantile)
+ return func + args
+
+
class ClickHouse(Dialect):
NORMALIZE_FUNCTIONS: bool | str = False
NULL_ORDERING = "nulls_are_last"
STRICT_STRING_CONCAT = True
SUPPORTS_USER_DEFINED_TYPES = False
+ ESCAPE_SEQUENCES = {
+ "\\0": "\0",
+ }
+
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
STRING_ESCAPES = ["'", "\\"]
BIT_STRINGS = [("0b", "")]
HEX_STRINGS = [("0x", ""), ("0X", "")]
+ HEREDOC_STRINGS = ["$"]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -55,6 +70,7 @@ class ClickHouse(Dialect):
"LOWCARDINALITY": TokenType.LOWCARDINALITY,
"MAP": TokenType.MAP,
"NESTED": TokenType.NESTED,
+ "SAMPLE": TokenType.TABLE_SAMPLE,
"TUPLE": TokenType.STRUCT,
"UINT128": TokenType.UINT128,
"UINT16": TokenType.USMALLINT,
@@ -64,6 +80,11 @@ class ClickHouse(Dialect):
"UINT8": TokenType.UTINYINT,
}
+ SINGLE_TOKENS = {
+ **tokens.Tokenizer.SINGLE_TOKENS,
+ "$": TokenType.HEREDOC_STRING,
+ }
+
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -301,6 +322,7 @@ class ClickHouse(Dialect):
QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
NVL2_SUPPORTED = False
+ TABLESAMPLE_REQUIRES_PARENS = False
STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
@@ -348,6 +370,7 @@ class ClickHouse(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.AnyValue: rename_func("any"),
exp.ApproxDistinct: rename_func("uniq"),
exp.Array: inline_array_sql,
@@ -359,12 +382,13 @@ class ClickHouse(Dialect):
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
+ exp.IsNan: rename_func("isNaN"),
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Pivot: no_pivot_sql,
- exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile"))
- + f"({self.sql(e, 'this')})",
+ exp.Quantile: _quantile_sql,
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
+ exp.StartsWith: rename_func("startsWith"),
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
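
Note: a sketch of the new quantile generation, built with sqlglot's expression API; an array quantile argument should now produce the parameterized quantiles(...)(...) form:

    from sqlglot import exp
    from sqlglot.dialects.clickhouse import ClickHouse

    node = exp.Quantile(
        this=exp.column("x"),
        quantile=exp.Array(expressions=[exp.Literal.number(0.25), exp.Literal.number(0.75)]),
    )
    print(ClickHouse().generate(node))  # expected: quantiles(0.25, 0.75)(x)
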
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 39daad7..a044bc0 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -51,6 +51,26 @@ class Databricks(Spark):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
+ def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
+ constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint)
+ kind = expression.args.get("kind")
+ if (
+ constraint
+ and isinstance(kind, exp.DataType)
+ and kind.this in exp.DataType.INTEGER_TYPES
+ ):
+ # only BIGINT generated identity constraints are supported
+ expression = expression.copy()
+ expression.set("kind", exp.DataType.build("bigint"))
+ return super().columndef_sql(expression, sep)
+
+ def generatedasidentitycolumnconstraint_sql(
+ self, expression: exp.GeneratedAsIdentityColumnConstraint
+ ) -> str:
+ expression = expression.copy()
+ expression.set("this", True) # trigger ALWAYS in super class
+ return super().generatedasidentitycolumnconstraint_sql(expression)
+
class Tokenizer(Spark.Tokenizer):
HEX_STRINGS = []
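
Note: a hedged sketch of the identity-column coercion above; non-BIGINT integer kinds should be upcast when generating Databricks DDL (expected output, not verified here):

    import sqlglot

    sql = sqlglot.transpile(
        "CREATE TABLE t (id INT GENERATED ALWAYS AS IDENTITY)",
        read="databricks",
        write="databricks",
    )[0]
    print(sql)  # expected: CREATE TABLE t (id BIGINT GENERATED ALWAYS AS IDENTITY)
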
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index ccf04da..bd839af 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -81,6 +81,8 @@ class _Dialect(type):
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
+ klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
+
klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
klass.parser_class = getattr(klass, "Parser", Parser)
klass.generator_class = getattr(klass, "Generator", Generator)
@@ -188,6 +190,9 @@ class Dialect(metaclass=_Dialect):
# special syntax cast(x as date format 'yyyy') defaults to time_mapping
FORMAT_MAPPING: t.Dict[str, str] = {}
+ # Mapping of an unescaped escape sequence to the corresponding character
+ ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+
# Columns that are auto-generated by the engine corresponding to this dialect
# Such columns may be excluded from SELECT * queries, for example
PSEUDOCOLUMNS: t.Set[str] = set()
@@ -204,6 +209,8 @@ class Dialect(metaclass=_Dialect):
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
+ INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+
def __eq__(self, other: t.Any) -> bool:
return type(self) == other
@@ -245,7 +252,7 @@ class Dialect(metaclass=_Dialect):
"""
Normalizes an unquoted identifier to either lower or upper case, thus essentially
making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
- they will be normalized regardless of being quoted or not.
+ they will be normalized to lowercase regardless of being quoted or not.
"""
if isinstance(expression, exp.Identifier) and (
not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index a427870..3f925a7 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -51,6 +51,32 @@ TIME_DIFF_FACTOR = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
+def _create_sql(self: Hive.Generator, expression: exp.Create) -> str:
+ expression = expression.copy()
+
+ # remove UNIQUE column constraints
+ for constraint in expression.find_all(exp.UniqueColumnConstraint):
+ if constraint.parent:
+ constraint.parent.pop()
+
+ properties = expression.args.get("properties")
+ temporary = any(
+ isinstance(prop, exp.TemporaryProperty)
+ for prop in (properties.expressions if properties else [])
+ )
+
+ # CTAS with temp tables map to CREATE TEMPORARY VIEW
+ kind = expression.args["kind"]
+ if kind.upper() == "TABLE" and temporary:
+ if expression.expression:
+ return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}"
+ else:
+ # CREATE TEMPORARY TABLE may require storage provider
+ expression = self.temporary_storage_provider(expression)
+
+ return create_with_partitions_sql(self, expression)
+
+
def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
@@ -429,7 +455,7 @@ class Hive(Dialect):
if e.args.get("allow_null")
else "NOT NULL",
exp.VarMap: var_map_sql,
- exp.Create: create_with_partitions_sql,
+ exp.Create: _create_sql,
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpExtract: regexp_extract_sql,
@@ -478,8 +504,13 @@ class Hive(Dialect):
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.WithDataProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
+ # Hive has no temporary storage provider (though there are related Hive settings)
+ return expression
+
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
parent = expression.parent
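
Note: a sketch of the CTAS rewrite in _create_sql; a temporary CTAS has no direct Hive DDL, so it should come out as CREATE TEMPORARY VIEW (expected output, not verified here):

    import sqlglot

    sql = sqlglot.transpile("CREATE TEMPORARY TABLE t AS SELECT 1", read="spark", write="hive")[0]
    print(sql)  # expected: CREATE TEMPORARY VIEW t AS SELECT 1
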
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 554241d..59a0a2a 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -66,7 +66,9 @@ def _str_to_date(args: t.List) -> exp.StrToDate:
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
-def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
+def _str_to_date_sql(
+ self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
+) -> str:
date_format = self.format_time(expression)
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
@@ -86,8 +88,10 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
-def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(
+ kind: str,
+) -> t.Callable[[MySQL.Generator, exp.Expression], str]:
+ def func(self: MySQL.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@@ -95,6 +99,30 @@ def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.D
return func
+def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str:
+ time_format = expression.args.get("format")
+ if time_format:
+ return _str_to_date_sql(self, expression)
+ return f"DATE({self.sql(expression, 'this')})"
+
+
+def _remove_ts_or_ds_to_date(
+ to_sql: t.Optional[t.Callable[[MySQL.Generator, exp.Expression], str]] = None,
+ args: t.Tuple[str, ...] = ("this",),
+) -> t.Callable[[MySQL.Generator, exp.Func], str]:
+ def func(self: MySQL.Generator, expression: exp.Func) -> str:
+ expression = expression.copy()
+
+ for arg_key in args:
+ arg = expression.args.get(arg_key)
+ if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"):
+ expression.set(arg_key, arg.this)
+
+ return to_sql(self, expression) if to_sql else self.function_fallback_sql(expression)
+
+ return func
+
+
class MySQL(Dialect):
# https://dev.mysql.com/doc/refman/8.0/en/identifiers.html
IDENTIFIERS_CAN_START_WITH_DIGIT = True
@@ -233,6 +261,7 @@ class MySQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
+ "DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)),
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
@@ -240,14 +269,33 @@ class MySQL(Dialect):
"ISNULL": isnull_to_is_null,
"LOCATE": locate_to_strposition,
"MONTHNAME": lambda args: exp.TimeToStr(
- this=seq_get(args, 0),
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
format=exp.Literal.string("%B"),
),
"STR_TO_DATE": _str_to_date,
+ "TO_DAYS": lambda args: exp.paren(
+ exp.DateDiff(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ expression=exp.TsOrDsToDate(this=exp.Literal.string("0000-01-01")),
+ unit=exp.var("DAY"),
+ )
+ + 1
+ ),
+ "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "WEEK": lambda args: exp.Week(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1)
+ ),
+ "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
+ "CHAR": lambda self: self._parse_chr(),
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@@ -531,6 +579,18 @@ class MySQL(Dialect):
return super()._parse_type(parse_interval=parse_interval)
+ def _parse_chr(self) -> t.Optional[exp.Expression]:
+ expressions = self._parse_csv(self._parse_conjunction)
+ kwargs: t.Dict[str, t.Any] = {"this": seq_get(expressions, 0)}
+
+ if len(expressions) > 1:
+ kwargs["expressions"] = expressions[1:]
+
+ if self._match(TokenType.USING):
+ kwargs["charset"] = self._parse_var()
+
+ return self.expression(exp.Chr, **kwargs)
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
@@ -544,25 +604,33 @@ class MySQL(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
- exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
- exp.DateAdd: _date_add_sql("ADD"),
+ exp.DateDiff: _remove_ts_or_ds_to_date(
+ lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
+ ),
+ exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")),
exp.DateStrToDate: datestrtodate_sql,
- exp.DateSub: _date_add_sql("SUB"),
+ exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")),
exp.DateTrunc: _date_trunc_sql,
- exp.DayOfMonth: rename_func("DAYOFMONTH"),
- exp.DayOfWeek: rename_func("DAYOFWEEK"),
- exp.DayOfYear: rename_func("DAYOFYEAR"),
+ exp.Day: _remove_ts_or_ds_to_date(),
+ exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
+ exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
+ exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
+ exp.Month: _remove_ts_or_ds_to_date(),
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
- [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+ [
+ transforms.eliminate_distinct_on,
+ transforms.eliminate_semi_and_anti_joins,
+ transforms.eliminate_qualify,
+ ]
),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
@@ -573,10 +641,16 @@ class MySQL(Dialect):
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
- exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
+ exp.TimeToStr: _remove_ts_or_ds_to_date(
+ lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e))
+ ),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
- exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+ exp.TsOrDsAdd: _date_add_sql("ADD"),
+ exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.Week: _remove_ts_or_ds_to_date(),
+ exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
+ exp.Year: _remove_ts_or_ds_to_date(),
}
UNSIGNED_TYPE_MAPPING = {
@@ -585,6 +659,7 @@ class MySQL(Dialect):
exp.DataType.Type.UMEDIUMINT: "MEDIUMINT",
exp.DataType.Type.USMALLINT: "SMALLINT",
exp.DataType.Type.UTINYINT: "TINYINT",
+ exp.DataType.Type.UDECIMAL: "DECIMAL",
}
TIMESTAMP_TYPE_MAPPING = {
@@ -717,3 +792,9 @@ class MySQL(Dialect):
limit_offset = f"{offset}, {limit}" if offset else limit
return f" LIMIT {limit_offset}"
return ""
+
+ def chr_sql(self, expression: exp.Chr) -> str:
+ this = self.expressions(sqls=[expression.this] + expression.expressions)
+ charset = expression.args.get("charset")
+ using = f" USING {self.sql(charset)}" if charset else ""
+ return f"CHAR({this}{using})"
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 0a4926d..6a007ab 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -153,6 +153,7 @@ class Oracle(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
COLUMN_JOIN_MARKS_SUPPORTED = True
+ DATA_TYPE_SPECIFIERS_ALLOWED = True
LIMIT_FETCH = "FETCH"
@@ -179,7 +180,12 @@ class Oracle(Dialect):
),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.ILike: no_ilike_sql,
- exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+ exp.Select: transforms.preprocess(
+ [
+ transforms.eliminate_distinct_on,
+ transforms.eliminate_qualify,
+ ]
+ ),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
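
Note: several dialects in this patch (ClickHouse, Oracle, Postgres, MySQL, T-SQL) gain transforms.eliminate_qualify, which rewrites QUALIFY into a filtered subquery. A hedged sketch:

    import sqlglot

    sql = sqlglot.transpile(
        "SELECT x, ROW_NUMBER() OVER (ORDER BY x) AS rn FROM t QUALIFY rn = 1",
        read="snowflake",
        write="oracle",
    )[0]
    # expected shape: SELECT x, rn FROM (SELECT x, ROW_NUMBER() OVER (ORDER BY x) AS rn FROM t) ... WHERE rn = 1
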
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 342fd95..008727c 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -22,6 +22,7 @@ from sqlglot.dialects.dialect import (
rename_func,
simplify_literal,
str_position_sql,
+ struct_extract_sql,
timestamptrunc_sql,
timestrtotime_sql,
trim_sql,
@@ -248,11 +249,10 @@ class Postgres(Dialect):
}
class Tokenizer(tokens.Tokenizer):
- QUOTES = ["'", "$$"]
-
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
+ HEREDOC_STRINGS = ["$"]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -296,7 +296,7 @@ class Postgres(Dialect):
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
- "$": TokenType.PARAMETER,
+ "$": TokenType.HEREDOC_STRING,
}
VAR_SINGLE_TOKENS = {"$"}
@@ -420,9 +420,15 @@ class Postgres(Dialect):
exp.Pow: lambda self, e: self.binary(e, "^"),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
- exp.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]),
+ exp.Select: transforms.preprocess(
+ [
+ transforms.eliminate_semi_and_anti_joins,
+ transforms.eliminate_qualify,
+ ]
+ ),
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
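
Note: a sketch of the tokenizer change; dollar-quoted strings are now scanned as heredocs keyed on "$", which should also cover tagged forms like $tag$...$tag$ (expected output, not verified here):

    import sqlglot

    ast = sqlglot.parse_one("SELECT $tag$hello$tag$", read="postgres")
    print(ast.sql(dialect="postgres"))  # expected: SELECT 'hello'
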
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 0d8d4ab..e5cfa1c 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -309,6 +309,9 @@ class Presto(Dialect):
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
+ exp.GroupConcat: lambda self, e: self.func(
+ "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
+ ),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
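
Note: a sketch of the new GROUP_CONCAT transform; Presto lacks the function, so it should be rewritten through ARRAY_AGG and ARRAY_JOIN:

    import sqlglot

    sql = sqlglot.transpile(
        "SELECT GROUP_CONCAT(x SEPARATOR ',') FROM t", read="mysql", write="presto"
    )[0]
    print(sql)  # expected: SELECT ARRAY_JOIN(ARRAY_AGG(x), ',') FROM t
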
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 2145844..88e4448 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -83,7 +83,7 @@ class Redshift(Postgres):
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []
- STRING_ESCAPES = ["\\"]
+ STRING_ESCAPES = ["\\", "'"]
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS,
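
Note: adding "'" back to STRING_ESCAPES means a doubled quote inside a Redshift string literal is treated as an escaped quote again; a sketch:

    import sqlglot

    ast = sqlglot.parse_one("SELECT 'it''s'", read="redshift")
    print(ast.sql(dialect="redshift"))  # expected: SELECT 'it''s'
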
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 5c49331..fc3e0fa 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -239,6 +239,8 @@ class Snowflake(Dialect):
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
+ TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW}
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
@@ -318,6 +320,43 @@ class Snowflake(Dialect):
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
}
+ STAGED_FILE_SINGLE_TOKENS = {
+ TokenType.DOT,
+ TokenType.MOD,
+ TokenType.SLASH,
+ }
+
+ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+ # https://docs.snowflake.com/en/user-guide/querying-stage
+ table: t.Optional[exp.Expression] = None
+ if self._match_text_seq("@"):
+ table_name = "@"
+ while True:
+ self._advance()
+ table_name += self._prev.text
+ if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
+ break
+ while self._match_set(self.STAGED_FILE_SINGLE_TOKENS):
+ table_name += self._prev.text
+
+ table = exp.var(table_name)
+ elif self._match(TokenType.STRING, advance=False):
+ table = self._parse_string()
+
+ if table:
+ file_format = None
+ pattern = None
+
+ if self._match_text_seq("(", "FILE_FORMAT", "=>"):
+ file_format = self._parse_string() or super()._parse_table_parts()
+ if self._match_text_seq(",", "PATTERN", "=>"):
+ pattern = self._parse_string()
+ self._match_r_paren()
+
+ return self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
+
+ return super()._parse_table_parts(schema=schema)
+
def _parse_id_var(
self,
any_token: bool = True,
@@ -394,6 +433,8 @@ class Snowflake(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
AGGREGATE_FILTER_SUPPORTED = False
+ SUPPORTS_TABLE_COPY = False
+ COLLATE_IS_FUNC = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -423,6 +464,12 @@ class Snowflake(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.PercentileCont: transforms.preprocess(
+ [transforms.add_within_group_for_percentiles]
+ ),
+ exp.PercentileDisc: transforms.preprocess(
+ [transforms.add_within_group_for_percentiles]
+ ),
exp.RegexpILike: _regexpilike_sql,
exp.Select: transforms.preprocess(
[
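
Note: a sketch of the staged-file support added above; _parse_table_parts should accept a stage reference plus FILE_FORMAT/PATTERN named args, and the matching generator change (in generator.py below) renders them back (round trip is the expected behavior, not verified here):

    import sqlglot

    sql = "SELECT * FROM @my_stage/file.csv (FILE_FORMAT => 'csv', PATTERN => '.*')"
    print(sqlglot.parse_one(sql, read="snowflake").sql(dialect="snowflake"))
    # expected to round-trip unchanged
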
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 9d4a1ab..2eaa2ae 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -54,6 +54,14 @@ class Spark(Spark2):
FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("ANY_VALUE")
+ def _parse_generated_as_identity(
+ self,
+ ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint:
+ this = super()._parse_generated_as_identity()
+ if this.expression:
+ return self.expression(exp.ComputedColumnConstraint, this=this.expression)
+ return this
+
class Generator(Spark2.Generator):
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
@@ -73,6 +81,9 @@ class Spark(Spark2):
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)
+ def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
+ return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})"
+
def anyvalue_sql(self, expression: exp.AnyValue) -> str:
return self.function_fallback_sql(expression)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 3dc9838..4130375 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -5,7 +5,6 @@ import typing as t
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
binary_from_function,
- create_with_partitions_sql,
format_time_lambda,
is_parse_json,
move_insert_cte_sql,
@@ -17,22 +16,6 @@ from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
-def _create_sql(self: Spark2.Generator, e: exp.Create) -> str:
- kind = e.args["kind"]
- properties = e.args.get("properties")
-
- if (
- kind.upper() == "TABLE"
- and e.expression
- and any(
- isinstance(prop, exp.TemporaryProperty)
- for prop in (properties.expressions if properties else [])
- )
- ):
- return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
- return create_with_partitions_sql(self, e)
-
-
def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
keys = expression.args.get("keys")
values = expression.args.get("values")
@@ -118,6 +101,8 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
class Spark2(Hive):
class Parser(Hive.Parser):
+ TRIM_PATTERN_FIRST = True
+
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,
"AGGREGATE": exp.Reduce.from_arg_list,
@@ -192,7 +177,6 @@ class Spark2(Hive):
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
- exp.Create: _create_sql,
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -236,6 +220,12 @@ class Spark2(Hive):
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
+ def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
+ # Spark 2, Spark 3 and Databricks require a storage provider for temporary tables
+ provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
+ expression.args["properties"].append("expressions", provider)
+ return expression
+
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if is_parse_json(expression.this):
schema = f"'{self.sql(expression, 'to')}'"
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index fa62e78..6aa49e4 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import (
parse_date_delta,
rename_func,
timestrtotime_sql,
+ ts_or_ds_to_date_sql,
)
from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
@@ -590,6 +591,7 @@ class TSQL(Dialect):
NVL2_SUPPORTED = False
ALTER_TABLE_ADD_COLUMN_KEYWORD = False
LIMIT_FETCH = "FETCH"
+ COMPUTED_COLUMN_WITH_TYPE = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -619,7 +621,11 @@ class TSQL(Dialect):
exp.Min: min_or_least,
exp.NumberToStr: _format_sql,
exp.Select: transforms.preprocess(
- [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+ [
+ transforms.eliminate_distinct_on,
+ transforms.eliminate_semi_and_anti_joins,
+ transforms.eliminate_qualify,
+ ]
),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
@@ -630,6 +636,7 @@ class TSQL(Dialect):
exp.TemporaryProperty: lambda self, e: "",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: _format_sql,
+ exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"),
}
TRANSFORMS.pop(exp.ReturnsProperty)
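
Note: a sketch of COMPUTED_COLUMN_WITH_TYPE = False; T-SQL computed columns carry no declared type, so the kind should be dropped on the way out (expected shape, not verified here):

    import sqlglot

    sql = sqlglot.transpile(
        "CREATE TABLE t (a INT, b INT GENERATED ALWAYS AS (a + 1))",
        read="spark",
        write="tsql",
    )[0]
    print(sql)  # expected shape: CREATE TABLE t (a INTEGER, b AS (a + 1))
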
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index 9f63100..bf2941c 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -202,4 +202,5 @@ ENV = {
"CURRENTTIME": datetime.datetime.now,
"CURRENTDATE": datetime.date.today,
"STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)),
+ "TRIM": null_if_any(lambda this, e=None: this.strip(e)),
}
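
Note: a sketch of the new executor builtin, which maps TRIM onto str.strip with NULL propagation via null_if_any:

    from sqlglot.executor import execute

    result = execute("SELECT TRIM(x) AS x FROM t", tables={"t": [{"x": "  a  "}]})
    print(result.rows)  # expected: [('a',)]
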
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 8e9575e..1e4aad6 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -52,6 +52,9 @@ class _Expression(type):
return klass
+SQLGLOT_META = "sqlglot.meta"
+
+
class Expression(metaclass=_Expression):
"""
The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary
@@ -266,7 +269,14 @@ class Expression(metaclass=_Expression):
if self.comments is None:
self.comments = []
if comments:
- self.comments.extend(comments)
+ for comment in comments:
+ _, *meta = comment.split(SQLGLOT_META)
+ if meta:
+ for kv in "".join(meta).split(","):
+ k, *v = kv.split("=")
+ value = v[0].strip() if v else True
+ self.meta[k.strip()] = value
+ self.comments.append(comment)
def append(self, arg_key: str, value: t.Any) -> None:
"""
@@ -1036,11 +1046,14 @@ class Create(DDL):
"indexes": False,
"no_schema_binding": False,
"begin": False,
+ "end": False,
"clone": False,
}
# https://docs.snowflake.com/en/sql-reference/sql/create-clone
+# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement
+# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy
class Clone(Expression):
arg_types = {
"this": True,
@@ -1048,6 +1061,7 @@ class Clone(Expression):
"kind": False,
"shallow": False,
"expression": False,
+ "copy": False,
}
@@ -1610,6 +1624,11 @@ class Identifier(Expression):
return self.name
+# https://www.postgresql.org/docs/current/indexes-opclass.html
+class Opclass(Expression):
+ arg_types = {"this": True, "expression": True}
+
+
class Index(Expression):
arg_types = {
"this": False,
@@ -2156,6 +2175,10 @@ class QueryTransform(Expression):
}
+class SampleProperty(Property):
+ arg_types = {"this": True}
+
+
class SchemaCommentProperty(Property):
arg_types = {"this": True}
@@ -2440,6 +2463,8 @@ class Table(Expression):
"hints": False,
"system_time": False,
"version": False,
+ "format": False,
+ "pattern": False,
}
@property
@@ -2465,17 +2490,17 @@ class Table(Expression):
return []
@property
- def parts(self) -> t.List[Identifier]:
+ def parts(self) -> t.List[Expression]:
"""Return the parts of a table in order catalog, db, table."""
- parts: t.List[Identifier] = []
+ parts: t.List[Expression] = []
for arg in ("catalog", "db", "this"):
part = self.args.get(arg)
- if isinstance(part, Identifier):
- parts.append(part)
- elif isinstance(part, Dot):
+ if isinstance(part, Dot):
parts.extend(part.flatten())
+ elif isinstance(part, Expression):
+ parts.append(part)
return parts
@@ -2910,6 +2935,7 @@ class Select(Subqueryable):
prefix="OFFSET",
dialect=dialect,
copy=copy,
+ into_arg="expression",
**opts,
)
@@ -3572,6 +3598,7 @@ class DataType(Expression):
UINT128 = auto()
UINT256 = auto()
UMEDIUMINT = auto()
+ UDECIMAL = auto()
UNIQUEIDENTIFIER = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation
USERDEFINED = "USER-DEFINED"
@@ -3693,13 +3720,13 @@ class DataType(Expression):
# https://www.postgresql.org/docs/15/datatype-pseudo.html
-class PseudoType(Expression):
- pass
+class PseudoType(DataType):
+ arg_types = {"this": True}
# https://www.postgresql.org/docs/15/datatype-oid.html
-class ObjectIdentifier(Expression):
- pass
+class ObjectIdentifier(DataType):
+ arg_types = {"this": True}
# WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...)
@@ -4027,10 +4054,20 @@ class TimeUnit(Expression):
return self.args.get("unit")
+class IntervalOp(TimeUnit):
+ arg_types = {"unit": True, "expression": True}
+
+ def interval(self):
+ return Interval(
+ this=self.expression.copy(),
+ unit=self.unit.copy(),
+ )
+
+
# https://www.oracletutorial.com/oracle-basics/oracle-interval/
# https://trino.io/docs/current/language/types.html#interval-day-to-second
# https://docs.databricks.com/en/sql/language-manual/data-types/interval-type.html
-class IntervalSpan(Expression):
+class IntervalSpan(DataType):
arg_types = {"this": True, "expression": True}
@@ -4269,7 +4306,7 @@ class CastToStrType(Func):
arg_types = {"this": True, "to": True}
-class Collate(Binary):
+class Collate(Binary, Func):
pass
@@ -4284,6 +4321,12 @@ class Coalesce(Func):
_sql_names = ["COALESCE", "IFNULL", "NVL"]
+class Chr(Func):
+ arg_types = {"this": True, "charset": False, "expressions": False}
+ is_var_len_args = True
+ _sql_names = ["CHR", "CHAR"]
+
+
class Concat(Func):
arg_types = {"expressions": True}
is_var_len_args = True
@@ -4326,11 +4369,11 @@ class CurrentUser(Func):
arg_types = {"this": False}
-class DateAdd(Func, TimeUnit):
+class DateAdd(Func, IntervalOp):
arg_types = {"this": True, "expression": True, "unit": False}
-class DateSub(Func, TimeUnit):
+class DateSub(Func, IntervalOp):
arg_types = {"this": True, "expression": True, "unit": False}
@@ -4347,11 +4390,11 @@ class DateTrunc(Func):
return self.args["unit"]
-class DatetimeAdd(Func, TimeUnit):
+class DatetimeAdd(Func, IntervalOp):
arg_types = {"this": True, "expression": True, "unit": False}
-class DatetimeSub(Func, TimeUnit):
+class DatetimeSub(Func, IntervalOp):
arg_types = {"this": True, "expression": True, "unit": False}
@@ -4375,6 +4418,10 @@ class DayOfYear(Func):
_sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"]
+class ToDays(Func):
+ pass
+
+
class WeekOfYear(Func):
_sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"]
@@ -6160,7 +6207,7 @@ def table_name(table: Table | str, dialect: DialectType = None) -> str:
The table name.
"""
- table = maybe_parse(table, into=Table)
+ table = maybe_parse(table, into=Table, dialect=dialect)
if not table:
raise ValueError(f"Cannot parse {table}")
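
Note: a sketch of the table_name fix; forwarding dialect to maybe_parse lets dialect-specific quoting tokenize, where the default dialect may fail (the name below is hypothetical):

    from sqlglot.expressions import table_name

    # without dialect="bigquery" the backticks may not tokenize at all
    name = table_name("`my-project.dataset.tbl`", dialect="bigquery")
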
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index b1ee783..edc6939 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -86,6 +86,7 @@ class Generator:
exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
+ exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}",
exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET",
exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}",
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
@@ -204,6 +205,21 @@ class Generator:
# Whether or not session variables / parameters are supported, e.g. @x in T-SQL
SUPPORTS_PARAMETERS = True
+ # Whether or not to include the type of a computed column in the CREATE DDL
+ COMPUTED_COLUMN_WITH_TYPE = True
+
+ # Whether or not CREATE TABLE .. COPY .. is supported. False means we'll generate CLONE instead of COPY
+ SUPPORTS_TABLE_COPY = True
+
+ # Whether or not parentheses are required around the table sample's expression
+ TABLESAMPLE_REQUIRES_PARENS = True
+
+ # Whether or not COLLATE is a function instead of a binary operator
+ COLLATE_IS_FUNC = False
+
+ # Whether or not data types support additional specifiers, e.g. CHAR or BYTE (Oracle)
+ DATA_TYPE_SPECIFIERS_ALLOWED = False
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -282,6 +298,7 @@ class Generator:
exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.SampleProperty: exp.Properties.Location.POST_SCHEMA,
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA,
exp.Set: exp.Properties.Location.POST_SCHEMA,
@@ -324,13 +341,12 @@ class Generator:
exp.Paren,
)
- UNESCAPED_SEQUENCE_TABLE = None # type: ignore
-
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
# Autofilled
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
+ INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
INDEX_OFFSET = 0
UNNEST_COLUMN_ONLY = False
ALIAS_POST_TABLESAMPLE = False
@@ -480,8 +496,7 @@ class Generator:
if not comments or isinstance(expression, exp.Binary):
return sql
- sep = "\n" if self.pretty else " "
- comments_sql = sep.join(
+ comments_sql = " ".join(
f"/*{self.pad_comment(comment)}*/" for comment in comments if comment
)
@@ -649,6 +664,9 @@ class Generator:
position = self.sql(expression, "position")
position = f" {position}" if position else ""
+ if expression.find(exp.ComputedColumnConstraint) and not self.COMPUTED_COLUMN_WITH_TYPE:
+ kind = ""
+
return f"{exists}{column}{kind}{constraints}{position}"
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
@@ -750,9 +768,11 @@ class Generator:
)
begin = " BEGIN" if expression.args.get("begin") else ""
+ end = " END" if expression.args.get("end") else ""
+
expression_sql = self.sql(expression, "expression")
if expression_sql:
- expression_sql = f"{begin}{self.sep()}{expression_sql}"
+ expression_sql = f"{begin}{self.sep()}{expression_sql}{end}"
if self.CREATE_FUNCTION_RETURN_AS or not isinstance(expression.expression, exp.Return):
if properties_locs.get(exp.Properties.Location.POST_ALIAS):
@@ -817,7 +837,8 @@ class Generator:
def clone_sql(self, expression: exp.Clone) -> str:
this = self.sql(expression, "this")
shallow = "SHALLOW " if expression.args.get("shallow") else ""
- this = f"{shallow}CLONE {this}"
+ keyword = "COPY" if expression.args.get("copy") and self.SUPPORTS_TABLE_COPY else "CLONE"
+ this = f"{shallow}{keyword} {this}"
when = self.sql(expression, "when")
if when:
@@ -877,7 +898,7 @@ class Generator:
def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str:
this = self.sql(expression, "this")
specifier = self.sql(expression, "expression")
- specifier = f" {specifier}" if specifier else ""
+ specifier = f" {specifier}" if specifier and self.DATA_TYPE_SPECIFIERS_ALLOWED else ""
return f"{this}{specifier}"
def datatype_sql(self, expression: exp.DataType) -> str:
@@ -1329,8 +1350,13 @@ class Generator:
pivots = f" {pivots}" if pivots else ""
joins = self.expressions(expression, key="joins", sep="", skip_first=True)
laterals = self.expressions(expression, key="laterals", sep="")
+ file_format = self.sql(expression, "format")
+ if file_format:
+ pattern = self.sql(expression, "pattern")
+ pattern = f", PATTERN => {pattern}" if pattern else ""
+ file_format = f" (FILE_FORMAT => {file_format}{pattern})"
- return f"{table}{version}{alias}{hints}{pivots}{joins}{laterals}"
+ return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
@@ -1343,6 +1369,7 @@ class Generator:
else:
this = self.sql(expression, "this")
alias = ""
+
method = self.sql(expression, "method")
method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else ""
numerator = self.sql(expression, "bucket_numerator")
@@ -1354,13 +1381,20 @@ class Generator:
percent = f"{percent} PERCENT" if percent else ""
rows = self.sql(expression, "rows")
rows = f"{rows} ROWS" if rows else ""
+
size = self.sql(expression, "size")
if size and self.TABLESAMPLE_SIZE_IS_PERCENT:
size = f"{size} PERCENT"
+
seed = self.sql(expression, "seed")
seed = f" {seed_prefix} ({seed})" if seed else ""
kind = expression.args.get("kind", "TABLESAMPLE")
- return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}"
+
+ expr = f"{bucket}{percent}{rows}{size}"
+ if self.TABLESAMPLE_REQUIRES_PARENS:
+ expr = f"({expr})"
+
+ return f"{this} {kind} {method}{expr}{seed}{alias}"
def pivot_sql(self, expression: exp.Pivot) -> str:
expressions = self.expressions(expression, flat=True)
@@ -1638,8 +1672,8 @@ class Generator:
def escape_str(self, text: str) -> str:
text = text.replace(self.QUOTE_END, self._escaped_quote_end)
- if self.UNESCAPED_SEQUENCE_TABLE:
- text = text.translate(self.UNESCAPED_SEQUENCE_TABLE)
+ if self.INVERSE_ESCAPE_SEQUENCES:
+ text = "".join(self.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text)
elif self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
return text
@@ -2301,6 +2335,8 @@ class Generator:
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
def collate_sql(self, expression: exp.Collate) -> str:
+ if self.COLLATE_IS_FUNC:
+ return self.function_fallback_sql(expression)
return self.binary(expression, "COLLATE")
def command_sql(self, expression: exp.Command) -> str:
@@ -2359,7 +2395,7 @@ class Generator:
collate = f" COLLATE {collate}" if collate else ""
using = self.sql(expression, "using")
using = f" USING {using}" if using else ""
- return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}"
+ return f"ALTER COLUMN {this} SET DATA TYPE {dtype}{collate}{using}"
default = self.sql(expression, "default")
if default:
@@ -2396,7 +2432,7 @@ class Generator:
elif isinstance(actions[0], exp.Delete):
actions = self.expressions(expression, key="actions", flat=True)
else:
- actions = self.expressions(expression, key="actions")
+ actions = self.expressions(expression, key="actions", flat=True)
exists = " IF EXISTS" if expression.args.get("exists") else ""
only = " ONLY" if expression.args.get("only") else ""
@@ -2593,7 +2629,7 @@ class Generator:
self,
expression: t.Optional[exp.Expression] = None,
key: t.Optional[str] = None,
- sqls: t.Optional[t.List[str]] = None,
+ sqls: t.Optional[t.Collection[str | exp.Expression]] = None,
flat: bool = False,
indent: bool = True,
skip_first: bool = False,
@@ -2841,6 +2877,9 @@ class Generator:
def columnprefix_sql(self, expression: exp.ColumnPrefix) -> str:
return f"{self.sql(expression, 'this')}({self.sql(expression, 'expression')})"
+ def opclass_sql(self, expression: exp.Opclass) -> str:
+ return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
+
def cached_generator(
cache: t.Optional[t.Dict[int, str]] = None
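
Note: a sketch of one of the new generator flags, COLLATE_IS_FUNC, which switches exp.Collate between function and binary-operator forms per dialect:

    from sqlglot import exp
    from sqlglot.dialects.bigquery import BigQuery
    from sqlglot.dialects.postgres import Postgres

    node = exp.Collate(this=exp.column("x"), expression=exp.Literal.string("und"))
    print(BigQuery().generate(node))  # expected: COLLATE(x, 'und')
    print(Postgres().generate(node))  # expected: x COLLATE 'und'
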
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index afc6995..17af6ac 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import datetime
+import functools
import typing as t
from sqlglot import exp
@@ -11,6 +13,16 @@ from sqlglot.schema import Schema, ensure_schema
if t.TYPE_CHECKING:
B = t.TypeVar("B", bound=exp.Binary)
+ BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
+ BinaryCoercions = t.Dict[
+ t.Tuple[exp.DataType.Type, exp.DataType.Type],
+ BinaryCoercionFunc,
+ ]
+
+
+# Interval units that operate on date components
+DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
+
def annotate_types(
expression: E,
@@ -48,6 +60,59 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type
return lambda self, e: self._annotate_with_type(e, data_type)
+def _is_iso_date(text: str) -> bool:
+ try:
+ datetime.date.fromisoformat(text)
+ return True
+ except ValueError:
+ return False
+
+
+def _is_iso_datetime(text: str) -> bool:
+ try:
+ datetime.datetime.fromisoformat(text)
+ return True
+ except ValueError:
+ return False
+
+
+def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
+ date_text = l.name
+ unit = r.text("unit").lower()
+
+ is_iso_date = _is_iso_date(date_text)
+
+ if is_iso_date and unit in DATE_UNITS:
+ l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE))
+ return exp.DataType.Type.DATE
+
+ # An ISO date is also an ISO datetime, but not vice versa
+ if is_iso_date or _is_iso_datetime(date_text):
+ l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME))
+ return exp.DataType.Type.DATETIME
+
+ return exp.DataType.Type.UNKNOWN
+
+
+def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
+ unit = r.text("unit").lower()
+ if unit not in DATE_UNITS:
+ return exp.DataType.Type.DATETIME
+ return l.type.this if l.type else exp.DataType.Type.UNKNOWN
+
+
+def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
+ @functools.wraps(func)
+ def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
+ return func(r, l)
+
+ return _swapped
+
+
+def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
+ return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}}
+
+
class _TypeAnnotator(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
@@ -104,10 +169,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DataType.Type.DATE: {
exp.CurrentDate,
exp.Date,
- exp.DateAdd,
exp.DateFromParts,
exp.DateStrToDate,
- exp.DateSub,
exp.DateTrunc,
exp.DiToDate,
exp.StrToDate,
@@ -212,6 +275,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
+ exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
+ exp.DateSub: lambda self, e: self._annotate_dateadd(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
@@ -234,21 +299,41 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
# Specifies what types a given type can be coerced into (autofilled)
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
+ # Coercion functions for binary operations.
+ # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type.
+ BINARY_COERCIONS: BinaryCoercions = {
+ **swap_all(
+ {
+ (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
+ for t in exp.DataType.TEXT_TYPES
+ }
+ ),
+ **swap_all(
+ {
+ (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
+ }
+ ),
+ }
+
def __init__(
self,
schema: Schema,
annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
+ binary_coercions: t.Optional[BinaryCoercions] = None,
) -> None:
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
+ self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
# Caches the ids of annotated sub-Expressions, to ensure we only visit them once
self._visited: t.Set[int] = set()
- def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
- expression.type = target_type
+ def _set_type(
+ self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
+ ) -> None:
+ expression.type = target_type # type: ignore
self._visited.add(id(expression))
def annotate(self, expression: E) -> E:
@@ -342,8 +427,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)
- left_type = expression.left.type.this
- right_type = expression.right.type.this
+ left, right = expression.left, expression.right
+ left_type, right_type = left.type.this, right.type.this
if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@@ -357,6 +442,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, exp.DataType.Type.BOOLEAN)
elif isinstance(expression, exp.Predicate):
self._set_type(expression, exp.DataType.Type.BOOLEAN)
+ elif (left_type, right_type) in self.binary_coercions:
+ self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
else:
self._set_type(expression, self._maybe_coerce(left_type, right_type))
@@ -421,3 +508,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
)
return expression
+
+ def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
+ self._annotate_args(expression)
+
+ if expression.this.type.this in exp.DataType.TEXT_TYPES:
+ datatype = _coerce_literal_and_interval(expression.this, expression.interval())
+ elif (
+ expression.this.type.is_type(exp.DataType.Type.DATE)
+ and expression.text("unit").lower() not in DATE_UNITS
+ ):
+ datatype = exp.DataType.Type.DATETIME
+ else:
+ datatype = expression.this.type
+
+ self._set_type(expression, datatype)
+ return expression
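
Note: a sketch of the new binary coercions; a text literal plus a date-unit interval should be typed (and cast) as DATE, while non-date units yield DATETIME:

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types

    e = annotate_types(parse_one("SELECT '2023-01-01' + INTERVAL '1' day"))
    print(e.expressions[0].type.this)  # expected: Type.DATE
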
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index e45d1e3..ec3b3af 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -45,9 +45,11 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
- elif isinstance(node, exp.Extract):
- if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
- _replace_cast(node.expression, "datetime")
+ elif isinstance(node, exp.Extract) and not node.expression.type.is_type(
+ *exp.DataType.TEMPORAL_TYPES
+ ):
+ _replace_cast(node.expression, exp.DataType.Type.DATETIME)
+
return node
@@ -67,7 +69,7 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
_replace_int_predicate(expression.left)
_replace_int_predicate(expression.right)
- elif isinstance(expression, (exp.Where, exp.Having)):
+ elif isinstance(expression, (exp.Where, exp.Having, exp.If)):
_replace_int_predicate(expression.this)
return expression
@@ -89,13 +91,16 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
and b.type
and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
):
- _replace_cast(b, "date")
+ _replace_cast(b, exp.DataType.Type.DATE)
-def _replace_cast(node: exp.Expression, to: str) -> None:
+def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
node.replace(exp.cast(node.copy(), to=to))
def _replace_int_predicate(expression: exp.Expression) -> None:
- if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
+ if isinstance(expression, exp.Coalesce):
+ for _, child in expression.iter_expressions():
+ _replace_int_predicate(child)
+ elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
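
Note: a sketch of the COALESCE recursion in _replace_int_predicate; integer branches of a boolean-position COALESCE should each become explicit <> 0 comparisons (assuming a schema so x annotates as INT; expected output, not verified here):

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.canonicalize import canonicalize

    e = annotate_types(parse_one("SELECT 1 FROM t WHERE COALESCE(x, 0)"), schema={"t": {"x": "int"}})
    print(canonicalize(e).sql())  # expected: ... WHERE COALESCE(x <> 0, 0 <> 0)
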
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 976c9ad..b0b2b3d 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -181,7 +181,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not outer_scope.pivots
- and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
+ and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions)
and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1)
and not (
isinstance(from_or_join, exp.Join)
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index 54cf02b..32f3a92 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -22,6 +22,13 @@ def normalize_identifiers(expression, dialect=None):
Normalize all unquoted identifiers to either lower or upper case, depending
on the dialect. This essentially makes those identifiers case-insensitive.
+ It's possible to make this a no-op by adding a special comment next to the
+ identifier of interest:
+
+ SELECT a /* sqlglot.meta case_sensitive */ FROM table
+
+ In this example, the identifier `a` will not be normalized.
+
Note:
Some dialects (e.g. BigQuery) treat identifiers as case-insensitive even
when they're quoted, so in these cases all identifiers are normalized.
@@ -43,4 +50,13 @@ def normalize_identifiers(expression, dialect=None):
"""
if isinstance(expression, str):
expression = exp.to_identifier(expression)
- return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
+
+ dialect = Dialect.get_or_raise(dialect)
+
+ def _normalize(node: E) -> E:
+ if not node.meta.get("case_sensitive"):
+ exp.replace_children(node, _normalize)
+ node = dialect.normalize_identifier(node)
+ return node
+
+ return _normalize(expression)
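
Note: a sketch of the case_sensitive escape hatch described in the docstring (the default dialect lowercases unquoted identifiers):

    from sqlglot import parse_one
    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    ast = parse_one("SELECT Foo /* sqlglot.meta case_sensitive */, Bar FROM t")
    print(normalize_identifiers(ast).sql())
    # expected: SELECT Foo /* sqlglot.meta case_sensitive */, bar FROM t
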
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index d08c692..51214c4 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -387,10 +387,6 @@ def _is_number(expression: exp.Expression) -> bool:
return expression.is_number
-def _is_date(expression: exp.Expression) -> bool:
- return isinstance(expression, exp.Cast) and extract_date(expression) is not None
-
-
def _is_interval(expression: exp.Expression) -> bool:
return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
@@ -422,18 +418,15 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression:
if r.is_number:
a_predicate = _is_number
b_predicate = _is_number
- elif _is_date(r):
- a_predicate = _is_date
+ elif _is_date_literal(r):
+ a_predicate = _is_date_literal
b_predicate = _is_interval
else:
return expression
if l.__class__ in INVERSE_DATE_OPS:
a = l.this
- b = exp.Interval(
- this=l.expression.copy(),
- unit=l.unit.copy(),
- )
+ b = l.interval()
else:
a, b = l.left, l.right
@@ -509,14 +502,14 @@ def _simplify_binary(expression, a, b):
if boolean:
return boolean
- elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
+ elif _is_date_literal(a) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if a and b:
if isinstance(expression, exp.Add):
return date_literal(a + b)
if isinstance(expression, exp.Sub):
return date_literal(a - b)
- elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
+ elif isinstance(a, exp.Interval) and _is_date_literal(b):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
@@ -702,11 +695,7 @@ DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
- return (
- isinstance(left, (exp.DateTrunc, exp.TimestampTrunc))
- and isinstance(right, exp.Cast)
- and right.is_type(*exp.DataType.TEMPORAL_TYPES)
- )
+ return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
@catch(ModuleNotFoundError, UnsupportedUnit)
@@ -731,15 +720,26 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
unit = l.unit.name.lower()
date = extract_date(r)
+ if not date:
+ return expression
+
return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
elif isinstance(expression, exp.In):
l = expression.this
rs = expression.expressions
- if all(_is_datetrunc_predicate(l, r) for r in rs):
+ if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
unit = l.unit.name.lower()
- ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r]
+ ranges = []
+ for r in rs:
+ date = extract_date(r)
+ if not date:
+ return expression
+ drange = _datetrunc_range(date, unit)
+ if drange:
+ ranges.append(drange)
+
if not ranges:
return expression
@@ -811,18 +811,59 @@ def eval_boolean(expression, a, b):
return None
-def extract_date(cast):
- # The "fromisoformat" conversion could fail if the cast is used on an identifier,
- # so in that case we can't extract the date.
+def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
+ if isinstance(value, datetime.datetime):
+ return value.date()
+ if isinstance(value, datetime.date):
+ return value
try:
- if cast.args["to"].this == exp.DataType.Type.DATE:
- return datetime.date.fromisoformat(cast.name)
- if cast.args["to"].this == exp.DataType.Type.DATETIME:
- return datetime.datetime.fromisoformat(cast.name)
+ return datetime.datetime.fromisoformat(value).date()
except ValueError:
return None
+def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
+ if isinstance(value, datetime.datetime):
+ return value
+ if isinstance(value, datetime.date):
+ return datetime.datetime(year=value.year, month=value.month, day=value.day)
+ try:
+ return datetime.datetime.fromisoformat(value)
+ except ValueError:
+ return None
+
+
+def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
+ if not value:
+ return None
+ if to.is_type(exp.DataType.Type.DATE):
+ return cast_as_date(value)
+ if to.is_type(*exp.DataType.TEMPORAL_TYPES):
+ return cast_as_datetime(value)
+ return None
+
+
+def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
+ if isinstance(cast, exp.Cast):
+ to = cast.to
+ elif isinstance(cast, exp.TsOrDsToDate):
+ to = exp.DataType.build(exp.DataType.Type.DATE)
+ else:
+ return None
+
+ if isinstance(cast.this, exp.Literal):
+ value: t.Any = cast.this.name
+ elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
+ value = extract_date(cast.this)
+ else:
+ return None
+ return cast_value(value, to)
+
+
+def _is_date_literal(expression: exp.Expression) -> bool:
+ return extract_date(expression) is not None
+
+
def extract_interval(expression):
n = int(expression.name)
unit = expression.text("unit").lower()
@@ -836,7 +877,9 @@ def extract_interval(expression):
def date_literal(date):
return exp.cast(
exp.Literal.string(date),
- "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
+ exp.DataType.Type.DATETIME
+ if isinstance(date, datetime.datetime)
+ else exp.DataType.Type.DATE,
)
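
The net effect of the extract_date rework is that constant folding in simplify now recognizes date values behind nested casts and TsOrDsToDate, not just a bare CAST of a literal. A minimal sketch (output illustrative):

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    # A date literal plus an interval folds into a single date literal.
    ast = sqlglot.parse_one("SELECT CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY")
    print(simplify(ast).sql())  # SELECT CAST('2023-01-02' AS DATE)
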
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 84b2639..5e56961 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -178,6 +178,7 @@ class Parser(metaclass=_Parser):
TokenType.DATERANGE,
TokenType.DATEMULTIRANGE,
TokenType.DECIMAL,
+ TokenType.UDECIMAL,
TokenType.BIGDECIMAL,
TokenType.UUID,
TokenType.GEOGRAPHY,
@@ -215,6 +216,7 @@ class Parser(metaclass=_Parser):
TokenType.MEDIUMINT: TokenType.UMEDIUMINT,
TokenType.SMALLINT: TokenType.USMALLINT,
TokenType.TINYINT: TokenType.UTINYINT,
+ TokenType.DECIMAL: TokenType.UDECIMAL,
}
SUBQUERY_PREDICATES = {
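
With DECIMAL added to the signed-to-unsigned mapping, a dialect that spells unsigned types as a suffix can map DECIMAL ... UNSIGNED onto the new UDECIMAL type. A hedged sketch, assuming MySQL wires this mapping up:

    import sqlglot

    # UNSIGNED should flip the column type to UDECIMAL internally
    # while round-tripping back to the MySQL spelling.
    ast = sqlglot.parse_one("CREATE TABLE t (c DECIMAL(10, 2) UNSIGNED)", read="mysql")
    print(ast.sql(dialect="mysql"))
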
@@ -338,6 +340,7 @@ class Parser(metaclass=_Parser):
TRIM_TYPES = {"LEADING", "TRAILING", "BOTH"}
FUNC_TOKENS = {
+ TokenType.COLLATE,
TokenType.COMMAND,
TokenType.CURRENT_DATE,
TokenType.CURRENT_DATETIME,
@@ -590,6 +593,9 @@ class Parser(metaclass=_Parser):
exp.National, this=token.text
),
TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text),
+ TokenType.HEREDOC_STRING: lambda self, token: self.expression(
+ exp.RawString, this=token.text
+ ),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
@@ -666,6 +672,9 @@ class Parser(metaclass=_Parser):
"RETURNS": lambda self: self._parse_returns(),
"ROW": lambda self: self._parse_row(),
"ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty),
+ "SAMPLE": lambda self: self.expression(
+ exp.SampleProperty, this=self._match_text_seq("BY") and self._parse_bitwise()
+ ),
"SET": lambda self: self.expression(exp.SetProperty, multi=False),
"SETTINGS": lambda self: self.expression(
exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item)
@@ -847,8 +856,11 @@ class Parser(metaclass=_Parser):
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
+ CLONE_KEYWORDS = {"CLONE", "COPY"}
CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"}
+ OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"}
+
TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE}
WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
@@ -863,6 +875,8 @@ class Parser(metaclass=_Parser):
NULL_TOKENS = {TokenType.NULL}
+ UNNEST_OFFSET_ALIAS_TOKENS = ID_VAR_TOKENS - SET_OPERATIONS
+
STRICT_CAST = True
# A NULL arg in CONCAT yields NULL by default
@@ -880,9 +894,12 @@ class Parser(metaclass=_Parser):
# Whether or not the table sample clause expects CSV syntax
TABLESAMPLE_CSV = False
- # Whether or not the SET command needs a delimiter (e.g. "=") for assignments.
+ # Whether or not the SET command needs a delimiter (e.g. "=") for assignments
SET_REQUIRES_ASSIGNMENT_DELIMITER = True
+ # Whether the TRIM function expects the characters to trim as its first argument
+ TRIM_PATTERN_FIRST = False
+
__slots__ = (
"error_level",
"error_message_context",
@@ -1268,6 +1285,7 @@ class Parser(metaclass=_Parser):
indexes = None
no_schema_binding = None
begin = None
+ end = None
clone = None
def extend_props(temp_props: t.Optional[exp.Properties]) -> None:
@@ -1299,6 +1317,8 @@ class Parser(metaclass=_Parser):
else:
expression = self._parse_statement()
+ end = self._match_text_seq("END")
+
if return_:
expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
@@ -1344,7 +1364,8 @@ class Parser(metaclass=_Parser):
shallow = self._match_text_seq("SHALLOW")
- if self._match_text_seq("CLONE"):
+ if self._match_texts(self.CLONE_KEYWORDS):
+ copy = self._prev.text.lower() == "copy"
clone = self._parse_table(schema=True)
when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper()
clone_kind = (
@@ -1361,6 +1382,7 @@ class Parser(metaclass=_Parser):
kind=clone_kind,
shallow=shallow,
expression=clone_expression,
+ copy=copy,
)
return self.expression(
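
CLONE_KEYWORDS means the COPY spelling of table cloning now parses alongside CLONE, with the chosen keyword recorded on the Clone expression. A small sketch against the base parser:

    import sqlglot

    # Both spellings produce an exp.Create whose "clone" arg is a Clone
    # node; the new "copy" arg records which keyword was used.
    for sql in ("CREATE TABLE t2 CLONE t1", "CREATE TABLE t2 COPY t1"):
        clone = sqlglot.parse_one(sql).args["clone"]
        print(clone.args.get("copy"))  # False, then True
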
@@ -1376,6 +1398,7 @@ class Parser(metaclass=_Parser):
indexes=indexes,
no_schema_binding=no_schema_binding,
begin=begin,
+ end=end,
clone=clone,
)
@@ -2445,21 +2468,32 @@ class Parser(metaclass=_Parser):
kwargs["using"] = self._parse_wrapped_id_vars()
elif not (kind and kind.token_type == TokenType.CROSS):
index = self._index
- joins = self._parse_joins()
+ join = self._parse_join()
- if joins and self._match(TokenType.ON):
+ if join and self._match(TokenType.ON):
kwargs["on"] = self._parse_conjunction()
- elif joins and self._match(TokenType.USING):
+ elif join and self._match(TokenType.USING):
kwargs["using"] = self._parse_wrapped_id_vars()
else:
- joins = None
+ join = None
self._retreat(index)
- kwargs["this"].set("joins", joins)
+ kwargs["this"].set("joins", [join] if join else None)
comments = [c for token in (method, side, kind) if token for c in token.comments]
return self.expression(exp.Join, comments=comments, **kwargs)
+ def _parse_opclass(self) -> t.Optional[exp.Expression]:
+ this = self._parse_conjunction()
+ if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False):
+ return this
+
+ opclass = self._parse_var(any_token=True)
+ if opclass:
+ return self.expression(exp.Opclass, this=this, expression=opclass)
+
+ return this
+
def _parse_index(
self,
index: t.Optional[exp.Expression] = None,
@@ -2486,7 +2520,7 @@ class Parser(metaclass=_Parser):
using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None
if self._match(TokenType.L_PAREN, advance=False):
- columns = self._parse_wrapped_csv(self._parse_ordered)
+ columns = self._parse_wrapped_csv(lambda: self._parse_ordered(self._parse_opclass))
else:
columns = None
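
_parse_opclass exists so that index columns can carry a trailing operator class (a PostgreSQL notion) without the parser mistaking ASC, DESC, or NULLS for one. An illustrative round-trip:

    import sqlglot

    # text_pattern_ops is captured as an exp.Opclass wrapping the column,
    # while a following ASC would still be parsed as ordering.
    sql = "CREATE INDEX ix ON t (col text_pattern_ops)"
    print(sqlglot.parse_one(sql, read="postgres").sql(dialect="postgres"))
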
@@ -2677,7 +2711,9 @@ class Parser(metaclass=_Parser):
if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET):
self._match(TokenType.ALIAS)
- offset = self._parse_id_var() or exp.to_identifier("offset")
+ offset = self._parse_id_var(
+ any_token=False, tokens=self.UNNEST_OFFSET_ALIAS_TOKENS
+ ) or exp.to_identifier("offset")
return self.expression(exp.Unnest, expressions=expressions, alias=alias, offset=offset)
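
Restricting the offset alias tokens keeps a following set operation from being swallowed as the alias. A hedged BigQuery-flavored sketch:

    import sqlglot

    # Without the restriction, UNION could be consumed as the offset alias;
    # with it, the statement parses as a set operation as intended.
    sql = "SELECT * FROM UNNEST([1]) WITH OFFSET UNION ALL SELECT 1"
    print(type(sqlglot.parse_one(sql, read="bigquery")).__name__)  # Union
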
@@ -2715,14 +2751,18 @@ class Parser(metaclass=_Parser):
)
method = self._parse_var(tokens=(TokenType.ROW,))
- self._match(TokenType.L_PAREN)
+ matched_l_paren = self._match(TokenType.L_PAREN)
if self.TABLESAMPLE_CSV:
num = None
expressions = self._parse_csv(self._parse_primary)
else:
expressions = None
- num = self._parse_primary()
+ num = (
+ self._parse_factor()
+ if self._match(TokenType.NUMBER, advance=False)
+ else self._parse_primary()
+ )
if self._match_text_seq("BUCKET"):
bucket_numerator = self._parse_number()
@@ -2737,7 +2777,8 @@ class Parser(metaclass=_Parser):
elif num:
size = num
- self._match(TokenType.R_PAREN)
+ if matched_l_paren:
+ self._match_r_paren()
if self._match(TokenType.L_PAREN):
method = self._parse_var()
@@ -2965,8 +3006,8 @@ class Parser(metaclass=_Parser):
return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
- def _parse_ordered(self) -> exp.Ordered:
- this = self._parse_conjunction()
+ def _parse_ordered(self, parse_method: t.Optional[t.Callable] = None) -> exp.Ordered:
+ this = parse_method() if parse_method else self._parse_conjunction()
asc = self._match(TokenType.ASC)
desc = self._match(TokenType.DESC) or (asc and False)
@@ -3144,7 +3185,7 @@ class Parser(metaclass=_Parser):
if self._match_text_seq("DISTINCT", "FROM"):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
- return self.expression(klass, this=this, expression=self._parse_expression())
+ return self.expression(klass, this=this, expression=self._parse_conjunction())
expression = self._parse_null() or self._parse_boolean()
if not expression:
@@ -3760,7 +3801,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise())
- def _parse_generated_as_identity(self) -> exp.GeneratedAsIdentityColumnConstraint:
+ def _parse_generated_as_identity(
+ self,
+ ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint:
if self._match_text_seq("BY", "DEFAULT"):
on_null = self._match_pair(TokenType.ON, TokenType.NULL)
this = self.expression(
@@ -4382,16 +4425,18 @@ class Parser(metaclass=_Parser):
position = None
collation = None
+ expression = None
if self._match_texts(self.TRIM_TYPES):
position = self._prev.text.upper()
- expression = self._parse_bitwise()
+ this = self._parse_bitwise()
if self._match_set((TokenType.FROM, TokenType.COMMA)):
- this = self._parse_bitwise()
- else:
- this = expression
- expression = None
+ invert_order = self._prev.token_type == TokenType.FROM or self.TRIM_PATTERN_FIRST
+ expression = self._parse_bitwise()
+
+ if invert_order:
+ this, expression = expression, this
if self._match(TokenType.COLLATE):
collation = self._parse_bitwise()
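
The reworked TRIM parsing normalizes both surface orders onto the same AST: in the FROM form the characters precede the target, while in the comma form the target comes first unless a dialect sets TRIM_PATTERN_FIRST. A short sketch (renderings illustrative):

    import sqlglot

    # FROM order: strip characters first, target second.
    print(sqlglot.parse_one("SELECT TRIM(BOTH 'x' FROM col)").sql())

    # Comma order: target first by default (TRIM_PATTERN_FIRST = False).
    print(sqlglot.parse_one("SELECT TRIM(col, 'x')").sql())
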
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 4d5f198..080a86b 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -77,6 +77,7 @@ class TokenType(AutoName):
BYTE_STRING = auto()
NATIONAL_STRING = auto()
RAW_STRING = auto()
+ HEREDOC_STRING = auto()
# types
BIT = auto()
@@ -98,6 +99,7 @@ class TokenType(AutoName):
FLOAT = auto()
DOUBLE = auto()
DECIMAL = auto()
+ UDECIMAL = auto()
BIGDECIMAL = auto()
CHAR = auto()
NCHAR = auto()
@@ -418,6 +420,7 @@ class _Tokenizer(type):
**_quotes_to_format(TokenType.BYTE_STRING, klass.BYTE_STRINGS),
**_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS),
**_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS),
+ **_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS),
}
klass._STRING_ESCAPES = set(klass.STRING_ESCAPES)
@@ -484,11 +487,13 @@ class Tokenizer(metaclass=_Tokenizer):
BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
RAW_STRINGS: t.List[str | t.Tuple[str, str]] = []
+ HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = []
IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
IDENTIFIER_ESCAPES = ['"']
QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
STRING_ESCAPES = ["'"]
VAR_SINGLE_TOKENS: t.Set[str] = set()
+ ESCAPE_SEQUENCES: t.Dict[str, str] = {}
# Autofilled
IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False
@@ -997,9 +1002,11 @@ class Tokenizer(metaclass=_Tokenizer):
word = word.upper()
self._add(self.KEYWORDS[word], text=word)
return
+
if self._char in self.SINGLE_TOKENS:
self._add(self.SINGLE_TOKENS[self._char], text=self._char)
return
+
self._scan_var()
def _scan_comment(self, comment_start: str) -> bool:
@@ -1126,6 +1133,10 @@ class Tokenizer(metaclass=_Tokenizer):
base = 16
elif token_type == TokenType.BIT_STRING:
base = 2
+ elif token_type == TokenType.HEREDOC_STRING:
+ self._advance()
+ tag = "" if self._char == end else self._extract_string(end)
+ end = f"{start}{tag}{end}"
else:
return False
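
The heredoc branch reads an optional tag after the opening delimiter and then scans to the matching close, so both $$...$$ and $tag$...$tag$ tokenize as one raw string. A hedged sketch, assuming a dialect that registers "$" in HEREDOC_STRINGS (PostgreSQL-style dollar quoting):

    import sqlglot

    # The tagged form lets the body contain bare "$" characters.
    for sql in ("SELECT $$a 'b' c$$", "SELECT $tag$a $ b$tag$"):
        print(sqlglot.parse_one(sql, read="postgres").sql(dialect="postgres"))
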
@@ -1193,6 +1204,13 @@ class Tokenizer(metaclass=_Tokenizer):
if self._end:
raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}")
+ if self.ESCAPE_SEQUENCES and self._peek and self._char in self.STRING_ESCAPES:
+ escaped_sequence = self.ESCAPE_SEQUENCES.get(self._char + self._peek)
+ if escaped_sequence:
+ self._advance(2)
+ text += escaped_sequence
+ continue
+
current = self._current - 1
self._advance(alnum=True)
text += self.sql[current : self._current - 1]
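
This final hunk consumes the ESCAPE_SEQUENCES mapping declared above: when the escape character is followed by a mapped two-character sequence, the tokenizer substitutes the real character into the token text. A hedged sketch, assuming BigQuery populates the mapping with the usual backslash escapes:

    import sqlglot
    from sqlglot import exp

    # In BigQuery, the two characters \n inside a string should now
    # tokenize to a single newline character in the literal's value.
    ast = sqlglot.parse_one(r"SELECT 'a\nb'", read="bigquery")
    print(repr(ast.find(exp.Literal).this))  # 'a\nb' with a real newline
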