summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-10-04 12:14:45 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-10-04 12:14:45 +0000
commita34653eb21369376f0e054dd989311afcb167f5b (patch)
tree5a0280adce195af0be654f79fd99395fd2932c19 /sqlglot/dialects
parentReleasing debian version 18.7.0-1. (diff)
downloadsqlglot-a34653eb21369376f0e054dd989311afcb167f5b.tar.xz
sqlglot-a34653eb21369376f0e054dd989311afcb167f5b.zip
Merging upstream version 18.11.2.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--sqlglot/dialects/bigquery.py30
-rw-r--r--sqlglot/dialects/clickhouse.py30
-rw-r--r--sqlglot/dialects/databricks.py20
-rw-r--r--sqlglot/dialects/dialect.py9
-rw-r--r--sqlglot/dialects/hive.py33
-rw-r--r--sqlglot/dialects/mysql.py107
-rw-r--r--sqlglot/dialects/oracle.py8
-rw-r--r--sqlglot/dialects/postgres.py14
-rw-r--r--sqlglot/dialects/presto.py3
-rw-r--r--sqlglot/dialects/redshift.py2
-rw-r--r--sqlglot/dialects/snowflake.py47
-rw-r--r--sqlglot/dialects/spark.py11
-rw-r--r--sqlglot/dialects/spark2.py26
-rw-r--r--sqlglot/dialects/tsql.py9
14 files changed, 290 insertions, 59 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 1349c56..0d741b5 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -190,6 +190,16 @@ class BigQuery(Dialect):
"%D": "%m/%d/%y",
}
+ ESCAPE_SEQUENCES = {
+ "\\a": "\a",
+ "\\b": "\b",
+ "\\f": "\f",
+ "\\n": "\n",
+ "\\r": "\r",
+ "\\t": "\t",
+ "\\v": "\v",
+ }
+
FORMAT_MAPPING = {
"DD": "%d",
"MM": "%m",
@@ -212,15 +222,14 @@ class BigQuery(Dialect):
@classmethod
def normalize_identifier(cls, expression: E) -> E:
- # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
- # The following check is essentially a heuristic to detect tables based on whether or
- # not they're qualified.
if isinstance(expression, exp.Identifier):
parent = expression.parent
-
while isinstance(parent, exp.Dot):
parent = parent.parent
+ # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
+ # The following check is essentially a heuristic to detect tables based on whether or
+ # not they're qualified. It also avoids normalizing UDFs, because they're case-sensitive.
if (
not isinstance(parent, exp.UserDefinedFunction)
and not (isinstance(parent, exp.Table) and parent.db)
@@ -419,6 +428,7 @@ class BigQuery(Dialect):
RENAME_TABLE_WITH_DB = False
NVL2_SUPPORTED = False
UNNEST_WITH_ORDINALITY = False
+ COLLATE_IS_FUNC = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -520,18 +530,6 @@ class BigQuery(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- UNESCAPED_SEQUENCE_TABLE = str.maketrans( # type: ignore
- {
- "\a": "\\a",
- "\b": "\\b",
- "\f": "\\f",
- "\n": "\\n",
- "\r": "\\r",
- "\t": "\\t",
- "\v": "\\v",
- }
- )
-
# from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords
RESERVED_KEYWORDS = {
*generator.Generator.RESERVED_KEYWORDS,
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 7446081..e9d9326 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
inline_array_sql,
@@ -21,18 +21,33 @@ def _lower_func(sql: str) -> str:
return sql[:index].lower() + sql[index:]
+def _quantile_sql(self, e):
+ quantile = e.args["quantile"]
+ args = f"({self.sql(e, 'this')})"
+ if isinstance(quantile, exp.Array):
+ func = self.func("quantiles", *quantile)
+ else:
+ func = self.func("quantile", quantile)
+ return func + args
+
+
class ClickHouse(Dialect):
NORMALIZE_FUNCTIONS: bool | str = False
NULL_ORDERING = "nulls_are_last"
STRICT_STRING_CONCAT = True
SUPPORTS_USER_DEFINED_TYPES = False
+ ESCAPE_SEQUENCES = {
+ "\\0": "\0",
+ }
+
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
STRING_ESCAPES = ["'", "\\"]
BIT_STRINGS = [("0b", "")]
HEX_STRINGS = [("0x", ""), ("0X", "")]
+ HEREDOC_STRINGS = ["$"]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -55,6 +70,7 @@ class ClickHouse(Dialect):
"LOWCARDINALITY": TokenType.LOWCARDINALITY,
"MAP": TokenType.MAP,
"NESTED": TokenType.NESTED,
+ "SAMPLE": TokenType.TABLE_SAMPLE,
"TUPLE": TokenType.STRUCT,
"UINT128": TokenType.UINT128,
"UINT16": TokenType.USMALLINT,
@@ -64,6 +80,11 @@ class ClickHouse(Dialect):
"UINT8": TokenType.UTINYINT,
}
+ SINGLE_TOKENS = {
+ **tokens.Tokenizer.SINGLE_TOKENS,
+ "$": TokenType.HEREDOC_STRING,
+ }
+
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -301,6 +322,7 @@ class ClickHouse(Dialect):
QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
NVL2_SUPPORTED = False
+ TABLESAMPLE_REQUIRES_PARENS = False
STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
@@ -348,6 +370,7 @@ class ClickHouse(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.Select: transforms.preprocess([transforms.eliminate_qualify]),
exp.AnyValue: rename_func("any"),
exp.ApproxDistinct: rename_func("uniq"),
exp.Array: inline_array_sql,
@@ -359,12 +382,13 @@ class ClickHouse(Dialect):
"DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
),
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
+ exp.IsNan: rename_func("isNaN"),
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Pivot: no_pivot_sql,
- exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile"))
- + f"({self.sql(e, 'this')})",
+ exp.Quantile: _quantile_sql,
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
+ exp.StartsWith: rename_func("startsWith"),
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions),
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 39daad7..a044bc0 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -51,6 +51,26 @@ class Databricks(Spark):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
+ def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
+ constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint)
+ kind = expression.args.get("kind")
+ if (
+ constraint
+ and isinstance(kind, exp.DataType)
+ and kind.this in exp.DataType.INTEGER_TYPES
+ ):
+ # only BIGINT generated identity constraints are supported
+ expression = expression.copy()
+ expression.set("kind", exp.DataType.build("bigint"))
+ return super().columndef_sql(expression, sep)
+
+ def generatedasidentitycolumnconstraint_sql(
+ self, expression: exp.GeneratedAsIdentityColumnConstraint
+ ) -> str:
+ expression = expression.copy()
+ expression.set("this", True) # trigger ALWAYS in super class
+ return super().generatedasidentitycolumnconstraint_sql(expression)
+
class Tokenizer(Spark.Tokenizer):
HEX_STRINGS = []
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index ccf04da..bd839af 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -81,6 +81,8 @@ class _Dialect(type):
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
+ klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
+
klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
klass.parser_class = getattr(klass, "Parser", Parser)
klass.generator_class = getattr(klass, "Generator", Generator)
@@ -188,6 +190,9 @@ class Dialect(metaclass=_Dialect):
# special syntax cast(x as date format 'yyyy') defaults to time_mapping
FORMAT_MAPPING: t.Dict[str, str] = {}
+ # Mapping of an unescaped escape sequence to the corresponding character
+ ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+
# Columns that are auto-generated by the engine corresponding to this dialect
# Such columns may be excluded from SELECT * queries, for example
PSEUDOCOLUMNS: t.Set[str] = set()
@@ -204,6 +209,8 @@ class Dialect(metaclass=_Dialect):
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
+ INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+
def __eq__(self, other: t.Any) -> bool:
return type(self) == other
@@ -245,7 +252,7 @@ class Dialect(metaclass=_Dialect):
"""
Normalizes an unquoted identifier to either lower or upper case, thus essentially
making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
- they will be normalized regardless of being quoted or not.
+ they will be normalized to lowercase regardless of being quoted or not.
"""
if isinstance(expression, exp.Identifier) and (
not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index a427870..3f925a7 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -51,6 +51,32 @@ TIME_DIFF_FACTOR = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
+def _create_sql(self, expression: exp.Create) -> str:
+ expression = expression.copy()
+
+ # remove UNIQUE column constraints
+ for constraint in expression.find_all(exp.UniqueColumnConstraint):
+ if constraint.parent:
+ constraint.parent.pop()
+
+ properties = expression.args.get("properties")
+ temporary = any(
+ isinstance(prop, exp.TemporaryProperty)
+ for prop in (properties.expressions if properties else [])
+ )
+
+ # CTAS with temp tables map to CREATE TEMPORARY VIEW
+ kind = expression.args["kind"]
+ if kind.upper() == "TABLE" and temporary:
+ if expression.expression:
+ return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}"
+ else:
+ # CREATE TEMPORARY TABLE may require storage provider
+ expression = self.temporary_storage_provider(expression)
+
+ return create_with_partitions_sql(self, expression)
+
+
def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
@@ -429,7 +455,7 @@ class Hive(Dialect):
if e.args.get("allow_null")
else "NOT NULL",
exp.VarMap: var_map_sql,
- exp.Create: create_with_partitions_sql,
+ exp.Create: _create_sql,
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
exp.RegexpExtract: regexp_extract_sql,
@@ -478,8 +504,13 @@ class Hive(Dialect):
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
+ exp.WithDataProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
+ # Hive has no temporary storage provider (there are hive settings though)
+ return expression
+
def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
parent = expression.parent
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 554241d..59a0a2a 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -66,7 +66,9 @@ def _str_to_date(args: t.List) -> exp.StrToDate:
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
-def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
+def _str_to_date_sql(
+ self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
+) -> str:
date_format = self.format_time(expression)
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
@@ -86,8 +88,10 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
-def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(
+ kind: str,
+) -> t.Callable[[MySQL.Generator, exp.Expression], str]:
+ def func(self: MySQL.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@@ -95,6 +99,30 @@ def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.D
return func
+def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str:
+ time_format = expression.args.get("format")
+ if time_format:
+ return _str_to_date_sql(self, expression)
+ return f"DATE({self.sql(expression, 'this')})"
+
+
+def _remove_ts_or_ds_to_date(
+ to_sql: t.Optional[t.Callable[[MySQL.Generator, exp.Expression], str]] = None,
+ args: t.Tuple[str, ...] = ("this",),
+) -> t.Callable[[MySQL.Generator, exp.Func], str]:
+ def func(self: MySQL.Generator, expression: exp.Func) -> str:
+ expression = expression.copy()
+
+ for arg_key in args:
+ arg = expression.args.get(arg_key)
+ if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"):
+ expression.set(arg_key, arg.this)
+
+ return to_sql(self, expression) if to_sql else self.function_fallback_sql(expression)
+
+ return func
+
+
class MySQL(Dialect):
# https://dev.mysql.com/doc/refman/8.0/en/identifiers.html
IDENTIFIERS_CAN_START_WITH_DIGIT = True
@@ -233,6 +261,7 @@ class MySQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
+ "DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)),
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
@@ -240,14 +269,33 @@ class MySQL(Dialect):
"ISNULL": isnull_to_is_null,
"LOCATE": locate_to_strposition,
"MONTHNAME": lambda args: exp.TimeToStr(
- this=seq_get(args, 0),
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
format=exp.Literal.string("%B"),
),
"STR_TO_DATE": _str_to_date,
+ "TO_DAYS": lambda args: exp.paren(
+ exp.DateDiff(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ expression=exp.TsOrDsToDate(this=exp.Literal.string("0000-01-01")),
+ unit=exp.var("DAY"),
+ )
+ + 1
+ ),
+ "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "WEEK": lambda args: exp.Week(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1)
+ ),
+ "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
+ "CHAR": lambda self: self._parse_chr(),
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@@ -531,6 +579,18 @@ class MySQL(Dialect):
return super()._parse_type(parse_interval=parse_interval)
+ def _parse_chr(self) -> t.Optional[exp.Expression]:
+ expressions = self._parse_csv(self._parse_conjunction)
+ kwargs: t.Dict[str, t.Any] = {"this": seq_get(expressions, 0)}
+
+ if len(expressions) > 1:
+ kwargs["expressions"] = expressions[1:]
+
+ if self._match(TokenType.USING):
+ kwargs["charset"] = self._parse_var()
+
+ return self.expression(exp.Chr, **kwargs)
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
@@ -544,25 +604,33 @@ class MySQL(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
- exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
- exp.DateAdd: _date_add_sql("ADD"),
+ exp.DateDiff: _remove_ts_or_ds_to_date(
+ lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
+ ),
+ exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")),
exp.DateStrToDate: datestrtodate_sql,
- exp.DateSub: _date_add_sql("SUB"),
+ exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")),
exp.DateTrunc: _date_trunc_sql,
- exp.DayOfMonth: rename_func("DAYOFMONTH"),
- exp.DayOfWeek: rename_func("DAYOFWEEK"),
- exp.DayOfYear: rename_func("DAYOFYEAR"),
+ exp.Day: _remove_ts_or_ds_to_date(),
+ exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
+ exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
+ exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
+ exp.Month: _remove_ts_or_ds_to_date(),
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
- [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+ [
+ transforms.eliminate_distinct_on,
+ transforms.eliminate_semi_and_anti_joins,
+ transforms.eliminate_qualify,
+ ]
),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
@@ -573,10 +641,16 @@ class MySQL(Dialect):
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
- exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
+ exp.TimeToStr: _remove_ts_or_ds_to_date(
+ lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e))
+ ),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
- exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+ exp.TsOrDsAdd: _date_add_sql("ADD"),
+ exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.Week: _remove_ts_or_ds_to_date(),
+ exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
+ exp.Year: _remove_ts_or_ds_to_date(),
}
UNSIGNED_TYPE_MAPPING = {
@@ -585,6 +659,7 @@ class MySQL(Dialect):
exp.DataType.Type.UMEDIUMINT: "MEDIUMINT",
exp.DataType.Type.USMALLINT: "SMALLINT",
exp.DataType.Type.UTINYINT: "TINYINT",
+ exp.DataType.Type.UDECIMAL: "DECIMAL",
}
TIMESTAMP_TYPE_MAPPING = {
@@ -717,3 +792,9 @@ class MySQL(Dialect):
limit_offset = f"{offset}, {limit}" if offset else limit
return f" LIMIT {limit_offset}"
return ""
+
+ def chr_sql(self, expression: exp.Chr) -> str:
+ this = self.expressions(sqls=[expression.this] + expression.expressions)
+ charset = expression.args.get("charset")
+ using = f" USING {self.sql(charset)}" if charset else ""
+ return f"CHAR({this}{using})"
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 0a4926d..6a007ab 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -153,6 +153,7 @@ class Oracle(Dialect):
JOIN_HINTS = False
TABLE_HINTS = False
COLUMN_JOIN_MARKS_SUPPORTED = True
+ DATA_TYPE_SPECIFIERS_ALLOWED = True
LIMIT_FETCH = "FETCH"
@@ -179,7 +180,12 @@ class Oracle(Dialect):
),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.ILike: no_ilike_sql,
- exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
+ exp.Select: transforms.preprocess(
+ [
+ transforms.eliminate_distinct_on,
+ transforms.eliminate_qualify,
+ ]
+ ),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
exp.Substring: rename_func("SUBSTR"),
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 342fd95..008727c 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -22,6 +22,7 @@ from sqlglot.dialects.dialect import (
rename_func,
simplify_literal,
str_position_sql,
+ struct_extract_sql,
timestamptrunc_sql,
timestrtotime_sql,
trim_sql,
@@ -248,11 +249,10 @@ class Postgres(Dialect):
}
class Tokenizer(tokens.Tokenizer):
- QUOTES = ["'", "$$"]
-
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
+ HEREDOC_STRINGS = ["$"]
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -296,7 +296,7 @@ class Postgres(Dialect):
SINGLE_TOKENS = {
**tokens.Tokenizer.SINGLE_TOKENS,
- "$": TokenType.PARAMETER,
+ "$": TokenType.HEREDOC_STRING,
}
VAR_SINGLE_TOKENS = {"$"}
@@ -420,9 +420,15 @@ class Postgres(Dialect):
exp.Pow: lambda self, e: self.binary(e, "^"),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
- exp.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]),
+ exp.Select: transforms.preprocess(
+ [
+ transforms.eliminate_semi_and_anti_joins,
+ transforms.eliminate_qualify,
+ ]
+ ),
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 0d8d4ab..e5cfa1c 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -309,6 +309,9 @@ class Presto(Dialect):
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
exp.Group: transforms.preprocess([transforms.unalias_group]),
+ exp.GroupConcat: lambda self, e: self.func(
+ "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
+ ),
exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 2145844..88e4448 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -83,7 +83,7 @@ class Redshift(Postgres):
class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []
- STRING_ESCAPES = ["\\"]
+ STRING_ESCAPES = ["\\", "'"]
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS,
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 5c49331..fc3e0fa 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -239,6 +239,8 @@ class Snowflake(Dialect):
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
+ TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW}
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
@@ -318,6 +320,43 @@ class Snowflake(Dialect):
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
}
+ STAGED_FILE_SINGLE_TOKENS = {
+ TokenType.DOT,
+ TokenType.MOD,
+ TokenType.SLASH,
+ }
+
+ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
+ # https://docs.snowflake.com/en/user-guide/querying-stage
+ table: t.Optional[exp.Expression] = None
+ if self._match_text_seq("@"):
+ table_name = "@"
+ while True:
+ self._advance()
+ table_name += self._prev.text
+ if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False):
+ break
+ while self._match_set(self.STAGED_FILE_SINGLE_TOKENS):
+ table_name += self._prev.text
+
+ table = exp.var(table_name)
+ elif self._match(TokenType.STRING, advance=False):
+ table = self._parse_string()
+
+ if table:
+ file_format = None
+ pattern = None
+
+ if self._match_text_seq("(", "FILE_FORMAT", "=>"):
+ file_format = self._parse_string() or super()._parse_table_parts()
+ if self._match_text_seq(",", "PATTERN", "=>"):
+ pattern = self._parse_string()
+ self._match_r_paren()
+
+ return self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
+
+ return super()._parse_table_parts(schema=schema)
+
def _parse_id_var(
self,
any_token: bool = True,
@@ -394,6 +433,8 @@ class Snowflake(Dialect):
TABLE_HINTS = False
QUERY_HINTS = False
AGGREGATE_FILTER_SUPPORTED = False
+ SUPPORTS_TABLE_COPY = False
+ COLLATE_IS_FUNC = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -423,6 +464,12 @@ class Snowflake(Dialect):
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
+ exp.PercentileCont: transforms.preprocess(
+ [transforms.add_within_group_for_percentiles]
+ ),
+ exp.PercentileDisc: transforms.preprocess(
+ [transforms.add_within_group_for_percentiles]
+ ),
exp.RegexpILike: _regexpilike_sql,
exp.Select: transforms.preprocess(
[
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 9d4a1ab..2eaa2ae 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -54,6 +54,14 @@ class Spark(Spark2):
FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("ANY_VALUE")
+ def _parse_generated_as_identity(
+ self,
+ ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint:
+ this = super()._parse_generated_as_identity()
+ if this.expression:
+ return self.expression(exp.ComputedColumnConstraint, this=this.expression)
+ return this
+
class Generator(Spark2.Generator):
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
@@ -73,6 +81,9 @@ class Spark(Spark2):
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)
+ def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
+ return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})"
+
def anyvalue_sql(self, expression: exp.AnyValue) -> str:
return self.function_fallback_sql(expression)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 3dc9838..4130375 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -5,7 +5,6 @@ import typing as t
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import (
binary_from_function,
- create_with_partitions_sql,
format_time_lambda,
is_parse_json,
move_insert_cte_sql,
@@ -17,22 +16,6 @@ from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
-def _create_sql(self: Spark2.Generator, e: exp.Create) -> str:
- kind = e.args["kind"]
- properties = e.args.get("properties")
-
- if (
- kind.upper() == "TABLE"
- and e.expression
- and any(
- isinstance(prop, exp.TemporaryProperty)
- for prop in (properties.expressions if properties else [])
- )
- ):
- return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
- return create_with_partitions_sql(self, e)
-
-
def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str:
keys = expression.args.get("keys")
values = expression.args.get("values")
@@ -118,6 +101,8 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
class Spark2(Hive):
class Parser(Hive.Parser):
+ TRIM_PATTERN_FIRST = True
+
FUNCTIONS = {
**Hive.Parser.FUNCTIONS,
"AGGREGATE": exp.Reduce.from_arg_list,
@@ -192,7 +177,6 @@ class Spark2(Hive):
exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
- exp.Create: _create_sql,
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
@@ -236,6 +220,12 @@ class Spark2(Hive):
WRAP_DERIVED_VALUES = False
CREATE_FUNCTION_RETURN_AS = False
+ def temporary_storage_provider(self, expression: exp.Create) -> exp.Create:
+ # spark2, spark, Databricks require a storage provider for temporary tables
+ provider = exp.FileFormatProperty(this=exp.Literal.string("parquet"))
+ expression.args["properties"].append("expressions", provider)
+ return expression
+
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
if is_parse_json(expression.this):
schema = f"'{self.sql(expression, 'to')}'"
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index fa62e78..6aa49e4 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import (
parse_date_delta,
rename_func,
timestrtotime_sql,
+ ts_or_ds_to_date_sql,
)
from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
@@ -590,6 +591,7 @@ class TSQL(Dialect):
NVL2_SUPPORTED = False
ALTER_TABLE_ADD_COLUMN_KEYWORD = False
LIMIT_FETCH = "FETCH"
+ COMPUTED_COLUMN_WITH_TYPE = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -619,7 +621,11 @@ class TSQL(Dialect):
exp.Min: min_or_least,
exp.NumberToStr: _format_sql,
exp.Select: transforms.preprocess(
- [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+ [
+ transforms.eliminate_distinct_on,
+ transforms.eliminate_semi_and_anti_joins,
+ transforms.eliminate_qualify,
+ ]
),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
@@ -630,6 +636,7 @@ class TSQL(Dialect):
exp.TemporaryProperty: lambda self, e: "",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: _format_sql,
+ exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"),
}
TRANSFORMS.pop(exp.ReturnsProperty)