path: root/sqlglot/dialects
author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-08 08:11:53 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-08 08:12:02 +0000
commit     8d36f5966675e23bee7026ba37ae0647fbf47300 (patch)
tree       df4227bbb3b07cb70df87237bcff03c8efd7822d /sqlglot/dialects
parent     Releasing debian version 22.2.0-1. (diff)
Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--  sqlglot/dialects/__init__.py      2
-rw-r--r--  sqlglot/dialects/athena.py       37
-rw-r--r--  sqlglot/dialects/bigquery.py     54
-rw-r--r--  sqlglot/dialects/clickhouse.py   77
-rw-r--r--  sqlglot/dialects/dialect.py     115
-rw-r--r--  sqlglot/dialects/doris.py        14
-rw-r--r--  sqlglot/dialects/drill.py        16
-rw-r--r--  sqlglot/dialects/duckdb.py       42
-rw-r--r--  sqlglot/dialects/hive.py          7
-rw-r--r--  sqlglot/dialects/mysql.py        82
-rw-r--r--  sqlglot/dialects/oracle.py        2
-rw-r--r--  sqlglot/dialects/postgres.py     12
-rw-r--r--  sqlglot/dialects/presto.py       40
-rw-r--r--  sqlglot/dialects/prql.py        109
-rw-r--r--  sqlglot/dialects/redshift.py     22
-rw-r--r--  sqlglot/dialects/snowflake.py   201
-rw-r--r--  sqlglot/dialects/spark.py        11
-rw-r--r--  sqlglot/dialects/spark2.py        9
-rw-r--r--  sqlglot/dialects/sqlite.py       12
-rw-r--r--  sqlglot/dialects/starrocks.py     7
-rw-r--r--  sqlglot/dialects/tableau.py       2
-rw-r--r--  sqlglot/dialects/teradata.py      9
-rw-r--r--  sqlglot/dialects/trino.py         1
-rw-r--r--  sqlglot/dialects/tsql.py         24
24 files changed, 685 insertions, 222 deletions
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 82552c9..29c6580 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -61,6 +61,7 @@ dialect implementations in order to understand how their various components can
----
"""
+from sqlglot.dialects.athena import Athena
from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
@@ -73,6 +74,7 @@ from sqlglot.dialects.mysql import MySQL
from sqlglot.dialects.oracle import Oracle
from sqlglot.dialects.postgres import Postgres
from sqlglot.dialects.presto import Presto
+from sqlglot.dialects.prql import PRQL
from sqlglot.dialects.redshift import Redshift
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.spark import Spark
diff --git a/sqlglot/dialects/athena.py b/sqlglot/dialects/athena.py
new file mode 100644
index 0000000..f2deec8
--- /dev/null
+++ b/sqlglot/dialects/athena.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from sqlglot import exp
+from sqlglot.dialects.trino import Trino
+from sqlglot.tokens import TokenType
+
+
+class Athena(Trino):
+ class Parser(Trino.Parser):
+ STATEMENT_PARSERS = {
+ **Trino.Parser.STATEMENT_PARSERS,
+ TokenType.USING: lambda self: self._parse_as_command(self._prev),
+ }
+
+ class Generator(Trino.Generator):
+ PROPERTIES_LOCATION = {
+ **Trino.Generator.PROPERTIES_LOCATION,
+ exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
+ }
+
+ TYPE_MAPPING = {
+ **Trino.Generator.TYPE_MAPPING,
+ exp.DataType.Type.TEXT: "STRING",
+ }
+
+ TRANSFORMS = {
+ **Trino.Generator.TRANSFORMS,
+ exp.FileFormatProperty: lambda self, e: f"'FORMAT'={self.sql(e, 'this')}",
+ }
+
+ def property_sql(self, expression: exp.Property) -> str:
+ return (
+ f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"
+ )
+
+ def with_properties(self, properties: exp.Properties) -> str:
+ return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
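
For context, an illustrative sketch of the new dialect in use (assuming sqlglot >= 23.7.0; output is approximate, not part of the upstream diff):

    import sqlglot

    # TEXT is rendered as STRING per the Athena TYPE_MAPPING above.
    print(sqlglot.transpile("CREATE TABLE t (c TEXT)", write="athena")[0])
    # CREATE TABLE t (c STRING)
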
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 5bfc3ea..2167ba2 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -24,6 +24,7 @@ from sqlglot.dialects.dialect import (
rename_func,
timestrtotime_sql,
ts_or_ds_add_cast,
+ unit_to_var,
)
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
@@ -41,14 +42,22 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
structs = []
alias = expression.args.get("alias")
for tup in expression.find_all(exp.Tuple):
- field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions)))
+ field_aliases = (
+ alias.columns
+ if alias and alias.columns
+ else (f"_c{i}" for i in range(len(tup.expressions)))
+ )
expressions = [
exp.PropertyEQ(this=exp.to_identifier(name), expression=fld)
for name, fld in zip(field_aliases, tup.expressions)
]
structs.append(exp.Struct(expressions=expressions))
- return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)]))
+ # Due to `UNNEST_COLUMN_ONLY`, it is expected that the table alias be contained in the columns expression
+ alias_name_only = exp.TableAlias(columns=[alias.this]) if alias else None
+ return self.unnest_sql(
+ exp.Unnest(expressions=[exp.array(*structs, copy=False)], alias=alias_name_only)
+ )
def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
@@ -190,7 +199,7 @@ def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> st
def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str:
expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True))
expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True))
- unit = expression.args.get("unit") or "DAY"
+ unit = unit_to_var(expression)
return self.func("DATE_DIFF", expression.this, expression.expression, unit)
@@ -238,16 +247,6 @@ class BigQuery(Dialect):
"%E6S": "%S.%f",
}
- ESCAPE_SEQUENCES = {
- "\\a": "\a",
- "\\b": "\b",
- "\\f": "\f",
- "\\n": "\n",
- "\\r": "\r",
- "\\t": "\t",
- "\\v": "\v",
- }
-
FORMAT_MAPPING = {
"DD": "%d",
"MM": "%m",
@@ -315,6 +314,7 @@ class BigQuery(Dialect):
"BEGIN TRANSACTION": TokenType.BEGIN,
"BYTES": TokenType.BINARY,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
+ "DATETIME": TokenType.TIMESTAMP,
"DECLARE": TokenType.COMMAND,
"ELSEIF": TokenType.COMMAND,
"EXCEPTION": TokenType.COMMAND,
@@ -486,14 +486,14 @@ class BigQuery(Dialect):
table.set("db", exp.Identifier(this=parts[0]))
table.set("this", exp.Identifier(this=parts[1]))
- if isinstance(table.this, exp.Identifier) and "." in table.name:
+ if any("." in p.name for p in table.parts):
catalog, db, this, *rest = (
- t.cast(t.Optional[exp.Expression], exp.to_identifier(x, quoted=True))
- for x in split_num_words(table.name, ".", 3)
+ exp.to_identifier(p, quoted=True)
+ for p in split_num_words(".".join(p.name for p in table.parts), ".", 3)
)
if rest and this:
- this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))
+ this = exp.Dot.build([this, *rest]) # type: ignore
table = exp.Table(this=this, db=db, catalog=catalog)
table.meta["quoted_table"] = True
@@ -527,7 +527,9 @@ class BigQuery(Dialect):
return json_object
- def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ def _parse_bracket(
+ self, this: t.Optional[exp.Expression] = None
+ ) -> t.Optional[exp.Expression]:
bracket = super()._parse_bracket(this)
if this is bracket:
@@ -566,6 +568,7 @@ class BigQuery(Dialect):
IGNORE_NULLS_IN_FUNC = True
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
CAN_IMPLEMENT_ARRAY_ANY = True
+ SUPPORTS_TO_NUMBER = False
NAMED_PLACEHOLDER_TOKEN = "@"
TRANSFORMS = {
@@ -588,7 +591,7 @@ class BigQuery(Dialect):
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", e.this, e.expression, e.unit or "DAY"
+ "DATE_DIFF", e.this, e.expression, unit_to_var(e)
),
exp.DateFromParts: rename_func("DATE"),
exp.DateStrToDate: datestrtodate_sql,
@@ -607,6 +610,7 @@ class BigQuery(Dialect):
exp.IntDiv: rename_func("DIV"),
exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.Max: max_or_greatest,
+ exp.Mod: rename_func("MOD"),
exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
exp.MD5Digest: rename_func("MD5"),
exp.Min: min_or_least,
@@ -847,10 +851,10 @@ class BigQuery(Dialect):
return inline_array_sql(self, expression)
def bracket_sql(self, expression: exp.Bracket) -> str:
- this = self.sql(expression, "this")
+ this = expression.this
expressions = expression.expressions
- if len(expressions) == 1:
+ if len(expressions) == 1 and this and this.is_type(exp.DataType.Type.STRUCT):
arg = expressions[0]
if arg.type is None:
from sqlglot.optimizer.annotate_types import annotate_types
@@ -858,10 +862,10 @@ class BigQuery(Dialect):
arg = annotate_types(arg)
if arg.type and arg.type.this in exp.DataType.TEXT_TYPES:
- # BQ doesn't support bracket syntax with string values
- return f"{this}.{arg.name}"
+ # BQ doesn't support bracket syntax with string values for structs
+ return f"{self.sql(this)}.{arg.name}"
- expressions_sql = ", ".join(self.sql(e) for e in expressions)
+ expressions_sql = self.expressions(expression, flat=True)
offset = expression.args.get("offset")
if offset == 0:
@@ -874,7 +878,7 @@ class BigQuery(Dialect):
if expression.args.get("safe"):
expressions_sql = f"SAFE_{expressions_sql}"
- return f"{this}[{expressions_sql}]"
+ return f"{self.sql(this)}[{expressions_sql}]"
def in_unnest_op(self, expression: exp.Unnest) -> str:
return self.sql(expression)
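
An illustrative sketch of the VALUES-to-UNNEST alias fix above (assuming sqlglot >= 23.7.0; output is approximate):

    import sqlglot

    # Because of UNNEST_COLUMN_ONLY, the alias name rides along in the column
    # list, so aliased VALUES now survive the rewrite to UNNEST.
    sql = "SELECT * FROM (VALUES (1), (2)) AS t(c)"
    print(sqlglot.transpile(sql, write="bigquery")[0])
    # SELECT * FROM UNNEST([STRUCT(1 AS c), STRUCT(2 AS c)]) AS t
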
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 90167f6..631dc30 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -15,7 +15,6 @@ from sqlglot.dialects.dialect import (
rename_func,
var_map_sql,
)
-from sqlglot.errors import ParseError
from sqlglot.helper import is_int, seq_get
from sqlglot.tokens import Token, TokenType
@@ -49,8 +48,9 @@ class ClickHouse(Dialect):
NULL_ORDERING = "nulls_are_last"
SUPPORTS_USER_DEFINED_TYPES = False
SAFE_DIVISION = True
+ LOG_BASE_FIRST: t.Optional[bool] = None
- ESCAPE_SEQUENCES = {
+ UNESCAPED_SEQUENCES = {
"\\0": "\0",
}
@@ -105,6 +105,7 @@ class ClickHouse(Dialect):
# * select x from t1 union all select x from t2 limit 1;
# * select x from t1 union all (select x from t2 limit 1);
MODIFIERS_ATTACHED_TO_UNION = False
+ INTERVAL_SPANS = False
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -260,6 +261,11 @@ class ClickHouse(Dialect):
"ArgMax",
]
+ FUNC_TOKENS = {
+ *parser.Parser.FUNC_TOKENS,
+ TokenType.SET,
+ }
+
AGG_FUNC_MAPPING = (
lambda functions, suffixes: {
f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
@@ -305,6 +311,10 @@ class ClickHouse(Dialect):
TokenType.SETTINGS,
}
+ ALIAS_TOKENS = parser.Parser.ALIAS_TOKENS - {
+ TokenType.FORMAT,
+ }
+
LOG_DEFAULTS_TO_LN = True
QUERY_MODIFIER_PARSERS = {
@@ -316,6 +326,17 @@ class ClickHouse(Dialect):
TokenType.FORMAT: lambda self: ("format", self._advance() or self._parse_id_var()),
}
+ CONSTRAINT_PARSERS = {
+ **parser.Parser.CONSTRAINT_PARSERS,
+ "INDEX": lambda self: self._parse_index_constraint(),
+ "CODEC": lambda self: self._parse_compress(),
+ }
+
+ SCHEMA_UNNAMED_CONSTRAINTS = {
+ *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
+ "INDEX",
+ }
+
def _parse_conjunction(self) -> t.Optional[exp.Expression]:
this = super()._parse_conjunction()
@@ -381,21 +402,20 @@ class ClickHouse(Dialect):
# https://clickhouse.com/docs/en/sql-reference/statements/select/with/
def _parse_cte(self) -> exp.CTE:
- index = self._index
- try:
- # WITH <identifier> AS <subquery expression>
- return super()._parse_cte()
- except ParseError:
- # WITH <expression> AS <identifier>
- self._retreat(index)
+ # WITH <identifier> AS <subquery expression>
+ cte: t.Optional[exp.CTE] = self._try_parse(super()._parse_cte)
- return self.expression(
+ if not cte:
+ # WITH <expression> AS <identifier>
+ cte = self.expression(
exp.CTE,
this=self._parse_conjunction(),
alias=self._parse_table_alias(),
scalar=True,
)
+ return cte
+
def _parse_join_parts(
self,
) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
@@ -508,6 +528,27 @@ class ClickHouse(Dialect):
self._retreat(index)
return None
+ def _parse_index_constraint(
+ self, kind: t.Optional[str] = None
+ ) -> exp.IndexColumnConstraint:
+ # INDEX name1 expr TYPE type1(args) GRANULARITY value
+ this = self._parse_id_var()
+ expression = self._parse_conjunction()
+
+ index_type = self._match_text_seq("TYPE") and (
+ self._parse_function() or self._parse_var()
+ )
+
+ granularity = self._match_text_seq("GRANULARITY") and self._parse_term()
+
+ return self.expression(
+ exp.IndexColumnConstraint,
+ this=this,
+ expression=expression,
+ index_type=index_type,
+ granularity=granularity,
+ )
+
class Generator(generator.Generator):
QUERY_HINTS = False
STRUCT_DELIMITER = ("(", ")")
@@ -517,6 +558,7 @@ class ClickHouse(Dialect):
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
CAN_IMPLEMENT_ARRAY_ANY = True
+ SUPPORTS_TO_NUMBER = False
STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
@@ -585,6 +627,9 @@ class ClickHouse(Dialect):
exp.Array: inline_array_sql,
exp.CastToStrType: rename_func("CAST"),
exp.CountIf: rename_func("countIf"),
+ exp.CompressColumnConstraint: lambda self,
+ e: f"CODEC({self.expressions(e, key='this', flat=True)})",
+ exp.ComputedColumnConstraint: lambda self, e: f"ALIAS {self.sql(e, 'this')}",
exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
exp.DateAdd: date_delta_sql("DATE_ADD"),
exp.DateDiff: date_delta_sql("DATE_DIFF"),
@@ -737,3 +782,15 @@ class ClickHouse(Dialect):
def prewhere_sql(self, expression: exp.PreWhere) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('PREWHERE')}{self.sep()}{this}"
+
+ def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+ expr = self.sql(expression, "expression")
+ expr = f" {expr}" if expr else ""
+ index_type = self.sql(expression, "index_type")
+ index_type = f" TYPE {index_type}" if index_type else ""
+ granularity = self.sql(expression, "granularity")
+ granularity = f" GRANULARITY {granularity}" if granularity else ""
+
+ return f"INDEX{this}{expr}{index_type}{granularity}"
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 599505c..81057c2 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -31,6 +31,7 @@ class Dialects(str, Enum):
DIALECT = ""
+ ATHENA = "athena"
BIGQUERY = "bigquery"
CLICKHOUSE = "clickhouse"
DATABRICKS = "databricks"
@@ -42,6 +43,7 @@ class Dialects(str, Enum):
ORACLE = "oracle"
POSTGRES = "postgres"
PRESTO = "presto"
+ PRQL = "prql"
REDSHIFT = "redshift"
SNOWFLAKE = "snowflake"
SPARK = "spark"
@@ -108,11 +110,18 @@ class _Dialect(type):
klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
- klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
+ base = seq_get(bases, 0)
+ base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
+ base_parser = (getattr(base, "parser_class", Parser),)
+ base_generator = (getattr(base, "generator_class", Generator),)
- klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
- klass.parser_class = getattr(klass, "Parser", Parser)
- klass.generator_class = getattr(klass, "Generator", Generator)
+ klass.tokenizer_class = klass.__dict__.get(
+ "Tokenizer", type("Tokenizer", base_tokenizer, {})
+ )
+ klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
+ klass.generator_class = klass.__dict__.get(
+ "Generator", type("Generator", base_generator, {})
+ )
klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
@@ -134,9 +143,31 @@ class _Dialect(type):
klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
+ if "\\" in klass.tokenizer_class.STRING_ESCAPES:
+ klass.UNESCAPED_SEQUENCES = {
+ "\\a": "\a",
+ "\\b": "\b",
+ "\\f": "\f",
+ "\\n": "\n",
+ "\\r": "\r",
+ "\\t": "\t",
+ "\\v": "\v",
+ "\\\\": "\\",
+ **klass.UNESCAPED_SEQUENCES,
+ }
+
+ klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
+
if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()
+ if enum not in ("", "databricks", "hive", "spark", "spark2"):
+ modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
+ for modifier in ("cluster", "distribute", "sort"):
+ modifier_transforms.pop(modifier, None)
+
+ klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
+
if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.ANTI,
@@ -189,8 +220,11 @@ class Dialect(metaclass=_Dialect):
False: Disables function name normalization.
"""
- LOG_BASE_FIRST = True
- """Whether the base comes first in the `LOG` function."""
+ LOG_BASE_FIRST: t.Optional[bool] = True
+ """
+ Whether the base comes first in the `LOG` function.
+ Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
+ """
NULL_ORDERING = "nulls_are_small"
"""
@@ -226,8 +260,8 @@ class Dialect(metaclass=_Dialect):
If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
"""
- ESCAPE_SEQUENCES: t.Dict[str, str] = {}
- """Mapping of an unescaped escape sequence to the corresponding character."""
+ UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
+ """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
PSEUDOCOLUMNS: t.Set[str] = set()
"""
@@ -266,7 +300,7 @@ class Dialect(metaclass=_Dialect):
INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
INVERSE_TIME_TRIE: t.Dict = {}
- INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+ ESCAPED_SEQUENCES: t.Dict[str, str] = {}
# Delimiters for string literals and identifiers
QUOTE_START = "'"
@@ -587,13 +621,21 @@ def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) ->
return ""
-def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
+def str_position_sql(
+ self: Generator, expression: exp.StrPosition, generate_instance: bool = False
+) -> str:
this = self.sql(expression, "this")
substr = self.sql(expression, "substr")
position = self.sql(expression, "position")
+ instance = expression.args.get("instance") if generate_instance else None
+ position_offset = ""
+
if position:
- return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
- return f"STRPOS({this}, {substr})"
+ # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
+ this = self.func("SUBSTR", this, position)
+ position_offset = f" + {position} - 1"
+
+ return self.func("STRPOS", this, substr, instance) + position_offset
def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
@@ -689,9 +731,7 @@ def build_date_delta_with_interval(
if expression and expression.is_string:
expression = exp.Literal.number(expression.this)
- return expression_class(
- this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
- )
+ return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
return _builder
@@ -710,18 +750,14 @@ def date_add_interval_sql(
) -> t.Callable[[Generator, exp.Expression], str]:
def func(self: Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
- unit = expression.args.get("unit")
- unit = exp.var(unit.name.upper() if unit else "DAY")
- interval = exp.Interval(this=expression.expression, unit=unit)
+ interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
return f"{data_type}_{kind}({this}, {self.sql(interval)})"
return func
def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
- return self.func(
- "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
- )
+ return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
@@ -956,7 +992,7 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return self.func(
name,
- exp.var(expression.text("unit").upper() or "DAY"),
+ unit_to_var(expression),
expression.expression,
expression.this,
)
@@ -964,6 +1000,24 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE
return _delta_sql
+def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+ unit = expression.args.get("unit")
+
+ if isinstance(unit, exp.Placeholder):
+ return unit
+ if unit:
+ return exp.Literal.string(unit.name)
+ return exp.Literal.string(default) if default else None
+
+
+def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+ unit = expression.args.get("unit")
+
+ if isinstance(unit, (exp.Var, exp.Placeholder)):
+ return unit
+ return exp.Var(this=default) if default else None
+
+
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
trunc_curr_date = exp.func("date_trunc", "month", expression.this)
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@@ -998,7 +1052,7 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
def build_json_extract_path(
- expr_type: t.Type[F], zero_based_indexing: bool = True
+ expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
def _builder(args: t.List) -> F:
segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
@@ -1018,7 +1072,11 @@ def build_json_extract_path(
# This is done to avoid failing in the expression validator due to the arg count
del args[2:]
- return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
+ return expr_type(
+ this=seq_get(args, 0),
+ expression=exp.JSONPath(expressions=segments),
+ only_json_types=arrow_req_json_type,
+ )
return _builder
@@ -1070,3 +1128,12 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> s
unnest = exp.Unnest(expressions=[expression.this])
filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
return self.sql(exp.Array(expressions=[filtered]))
+
+
+def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
+ return self.func(
+ "TO_NUMBER",
+ expression.this,
+ expression.args.get("format"),
+ expression.args.get("nlsparam"),
+ )
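
An illustrative sketch of the two new unit helpers (assuming sqlglot >= 23.7.0):

    from sqlglot import exp
    from sqlglot.dialects.dialect import unit_to_str, unit_to_var

    diff = exp.DateDiff(this=exp.column("a"), expression=exp.column("b"), unit=exp.var("MONTH"))
    print(unit_to_str(diff).sql())  # 'MONTH' -- string literal, for dialects like Presto
    print(unit_to_var(diff).sql())  # MONTH   -- bare keyword, for dialects like BigQuery
    # Both fall back to DAY (in their respective forms) when no unit is set.
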
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index 9a84848..f4ec0e5 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import (
build_timestamp_trunc,
rename_func,
time_format,
+ unit_to_str,
)
from sqlglot.dialects.mysql import MySQL
@@ -27,7 +28,7 @@ class Doris(MySQL):
}
class Generator(MySQL.Generator):
- CAST_MAPPING = {}
+ LAST_DAY_SUPPORTS_DATE_PART = False
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
@@ -36,8 +37,7 @@ class Doris(MySQL):
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
- LAST_DAY_SUPPORTS_DATE_PART = False
-
+ CAST_MAPPING = {}
TIMESTAMP_FUNC_TYPES = set()
TRANSFORMS = {
@@ -49,9 +49,7 @@ class Doris(MySQL):
exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayUniqueAgg: rename_func("COLLECT_SET"),
exp.CurrentTimestamp: lambda self, _: self.func("NOW"),
- exp.DateTrunc: lambda self, e: self.func(
- "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
- ),
+ exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.Map: rename_func("ARRAY_MAP"),
@@ -63,9 +61,7 @@ class Doris(MySQL):
exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression),
exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this),
exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
- exp.TimestampTrunc: lambda self, e: self.func(
- "DATE_TRUNC", e.this, "'" + e.text("unit") + "'"
- ),
+ exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)),
exp.UnixToStr: lambda self, e: self.func(
"FROM_UNIXTIME", e.this, time_format("doris")(self, e)
),
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index c1f6afa..0a00d92 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -1,6 +1,5 @@
from __future__ import annotations
-import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
@@ -12,18 +11,10 @@ from sqlglot.dialects.dialect import (
str_position_sql,
timestrtotime_sql,
)
+from sqlglot.dialects.mysql import date_add_sql
from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by
-def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
- this = self.sql(expression, "this")
- unit = exp.var(expression.text("unit").upper() or "DAY")
- return self.func(f"DATE_{kind}", this, exp.Interval(this=expression.expression, unit=unit))
-
- return func
-
-
def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
@@ -84,7 +75,6 @@ class Drill(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
- "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "drill"),
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": build_formatted_time(exp.TimeToStr, "drill"),
}
@@ -124,9 +114,9 @@ class Drill(Dialect):
exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
exp.ArraySize: rename_func("REPEATED_COUNT"),
exp.Create: preprocess([move_schema_columns_to_partitioned_by]),
- exp.DateAdd: _date_add_sql("ADD"),
+ exp.DateAdd: date_add_sql("ADD"),
exp.DateStrToDate: datestrtodate_sql,
- exp.DateSub: _date_add_sql("SUB"),
+ exp.DateSub: date_add_sql("SUB"),
exp.DateToDi: lambda self,
e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)",
exp.DiToDate: lambda self,
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f74dc97..6a1d07a 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -26,6 +26,7 @@ from sqlglot.dialects.dialect import (
str_to_time_sql,
timestamptrunc_sql,
timestrtotime_sql,
+ unit_to_var,
)
from sqlglot.helper import flatten, seq_get
from sqlglot.tokens import TokenType
@@ -33,15 +34,16 @@ from sqlglot.tokens import TokenType
def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
this = self.sql(expression, "this")
- unit = self.sql(expression, "unit").strip("'") or "DAY"
- interval = self.sql(exp.Interval(this=expression.expression, unit=unit))
+ interval = self.sql(exp.Interval(this=expression.expression, unit=unit_to_var(expression)))
return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}"
-def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_delta_sql(
+ self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub | exp.TimeAdd
+) -> str:
this = self.sql(expression, "this")
- unit = self.sql(expression, "unit").strip("'") or "DAY"
- op = "+" if isinstance(expression, exp.DateAdd) else "-"
+ unit = unit_to_var(expression)
+ op = "+" if isinstance(expression, (exp.DateAdd, exp.TimeAdd)) else "-"
return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
@@ -186,6 +188,11 @@ class DuckDB(Dialect):
return super().to_json_path(path)
class Tokenizer(tokens.Tokenizer):
+ HEREDOC_STRINGS = ["$"]
+
+ HEREDOC_TAG_IS_IDENTIFIER = True
+ HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER
+
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"//": TokenType.DIV,
@@ -199,6 +206,7 @@ class DuckDB(Dialect):
"LOGICAL": TokenType.BOOLEAN,
"ONLY": TokenType.ONLY,
"PIVOT_WIDER": TokenType.PIVOT,
+ "POSITIONAL": TokenType.POSITIONAL,
"SIGNED": TokenType.INT,
"STRING": TokenType.VARCHAR,
"UBIGINT": TokenType.UBIGINT,
@@ -227,8 +235,8 @@ class DuckDB(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAY_HAS": exp.ArrayContains.from_arg_list,
- "ARRAY_SORT": exp.SortArray.from_arg_list,
"ARRAY_REVERSE_SORT": _build_sort_array_desc,
+ "ARRAY_SORT": exp.SortArray.from_arg_list,
"DATEDIFF": _build_date_diff,
"DATE_DIFF": _build_date_diff,
"DATE_TRUNC": date_trunc_to_time,
@@ -285,6 +293,11 @@ class DuckDB(Dialect):
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("DECODE")
+ NO_PAREN_FUNCTION_PARSERS = {
+ **parser.Parser.NO_PAREN_FUNCTION_PARSERS,
+ "MAP": lambda self: self._parse_map(),
+ }
+
TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.SEMI,
TokenType.ANTI,
@@ -299,6 +312,13 @@ class DuckDB(Dialect):
),
}
+ def _parse_map(self) -> exp.ToMap | exp.Map:
+ if self._match(TokenType.L_BRACE, advance=False):
+ return self.expression(exp.ToMap, this=self._parse_bracket())
+
+ args = self._parse_wrapped_csv(self._parse_conjunction)
+ return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1))
+
def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
@@ -345,6 +365,7 @@ class DuckDB(Dialect):
SUPPORTS_CREATE_TABLE_LIKE = False
MULTI_ARG_DISTINCT = False
CAN_IMPLEMENT_ARRAY_ANY = True
+ SUPPORTS_TO_NUMBER = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -425,6 +446,7 @@ class DuckDB(Dialect):
"EPOCH", self.func("STRPTIME", e.this, self.format_time(e))
),
exp.Struct: _struct_sql,
+ exp.TimeAdd: _date_delta_sql,
exp.Timestamp: no_timestamp_sql,
exp.TimestampDiff: lambda self, e: self.func(
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
@@ -478,7 +500,7 @@ class DuckDB(Dialect):
STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
- UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)
+ UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren)
# DuckDB doesn't generally support CREATE TABLE .. properties
# https://duckdb.org/docs/sql/statements/create_table.html
@@ -569,3 +591,9 @@ class DuckDB(Dialect):
return rename_func("RANGE")(self, expression)
return super().generateseries_sql(expression)
+
+ def bracket_sql(self, expression: exp.Bracket) -> str:
+ if isinstance(expression.this, exp.Array):
+ expression.this.replace(exp.paren(expression.this))
+
+ return super().bracket_sql(expression)
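
An illustrative sketch of the new bracket_sql override (assuming sqlglot >= 23.7.0; output is approximate):

    import sqlglot

    # Array literals are wrapped in parentheses before being subscripted.
    print(sqlglot.transpile("SELECT [1, 2, 3][1]", read="duckdb", write="duckdb")[0])
    # SELECT ([1, 2, 3])[1]
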
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 55a9254..cc7debb 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -319,7 +319,9 @@ class Hive(Dialect):
"TO_DATE": build_formatted_time(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
- "UNIX_TIMESTAMP": build_formatted_time(exp.StrToUnix, "hive", True),
+ "UNIX_TIMESTAMP": lambda args: build_formatted_time(exp.StrToUnix, "hive", True)(
+ args or [exp.CurrentTimestamp()]
+ ),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
@@ -431,6 +433,7 @@ class Hive(Dialect):
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
+ SUPPORTS_TO_NUMBER = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Insert,
@@ -472,7 +475,7 @@ class Hive(Dialect):
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.ArrayConcat: rename_func("CONCAT"),
- exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
+ exp.ArrayToString: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort_sql,
exp.With: no_recursive_cte_sql,
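
An illustrative sketch of the UNIX_TIMESTAMP change (assuming sqlglot >= 23.7.0; output is approximate):

    import sqlglot

    # With no arguments, UNIX_TIMESTAMP now defaults to the current timestamp.
    print(sqlglot.transpile("SELECT UNIX_TIMESTAMP()", read="hive", write="hive")[0])
    # SELECT UNIX_TIMESTAMP(CURRENT_TIMESTAMP())
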
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 6ebae1e..1d53346 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -23,6 +23,7 @@ from sqlglot.dialects.dialect import (
build_date_delta_with_interval,
rename_func,
strposition_to_locate_sql,
+ unit_to_var,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -109,14 +110,14 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
-def _date_add_sql(
+def date_add_sql(
kind: str,
-) -> t.Callable[[MySQL.Generator, exp.Expression], str]:
- def func(self: MySQL.Generator, expression: exp.Expression) -> str:
- this = self.sql(expression, "this")
- unit = expression.text("unit").upper() or "DAY"
- return (
- f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+) -> t.Callable[[generator.Generator, exp.Expression], str]:
+ def func(self: generator.Generator, expression: exp.Expression) -> str:
+ return self.func(
+ f"DATE_{kind}",
+ expression.this,
+ exp.Interval(this=expression.expression, unit=unit_to_var(expression)),
)
return func
@@ -291,6 +292,7 @@ class MySQL(Dialect):
"DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
+ "FROM_UNIXTIME": build_formatted_time(exp.UnixToTime, "mysql"),
"ISNULL": isnull_to_is_null,
"LOCATE": locate_to_strposition,
"MAKETIME": exp.TimeFromParts.from_arg_list,
@@ -319,11 +321,7 @@ class MySQL(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
"CHAR": lambda self: self._parse_chr(),
- "GROUP_CONCAT": lambda self: self.expression(
- exp.GroupConcat,
- this=self._parse_lambda(),
- separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
- ),
+ "GROUP_CONCAT": lambda self: self._parse_group_concat(),
# https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
"VALUES": lambda self: self.expression(
exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()]
@@ -412,6 +410,11 @@ class MySQL(Dialect):
"SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"),
}
+ ALTER_PARSERS = {
+ **parser.Parser.ALTER_PARSERS,
+ "MODIFY": lambda self: self._parse_alter_table_alter(),
+ }
+
SCHEMA_UNNAMED_CONSTRAINTS = {
*parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
"FULLTEXT",
@@ -458,7 +461,7 @@ class MySQL(Dialect):
this = self._parse_id_var(any_token=False)
index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text
- schema = self._parse_schema()
+ expressions = self._parse_wrapped_csv(self._parse_ordered)
options = []
while True:
@@ -478,9 +481,6 @@ class MySQL(Dialect):
elif self._match_text_seq("ENGINE_ATTRIBUTE"):
self._match(TokenType.EQ)
opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
- elif self._match_text_seq("ENGINE_ATTRIBUTE"):
- self._match(TokenType.EQ)
- opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"):
self._match(TokenType.EQ)
opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string())
@@ -495,7 +495,7 @@ class MySQL(Dialect):
return self.expression(
exp.IndexColumnConstraint,
this=this,
- schema=schema,
+ expressions=expressions,
kind=kind,
index_type=index_type,
options=options,
@@ -617,6 +617,39 @@ class MySQL(Dialect):
return self.expression(exp.Chr, **kwargs)
+ def _parse_group_concat(self) -> t.Optional[exp.Expression]:
+ def concat_exprs(
+ node: t.Optional[exp.Expression], exprs: t.List[exp.Expression]
+ ) -> exp.Expression:
+ if isinstance(node, exp.Distinct) and len(node.expressions) > 1:
+ concat_exprs = [
+ self.expression(exp.Concat, expressions=node.expressions, safe=True)
+ ]
+ node.set("expressions", concat_exprs)
+ return node
+ if len(exprs) == 1:
+ return exprs[0]
+ return self.expression(exp.Concat, expressions=args, safe=True)
+
+ args = self._parse_csv(self._parse_lambda)
+
+ if args:
+ order = args[-1] if isinstance(args[-1], exp.Order) else None
+
+ if order:
+ # Order By is the last (or only) expression in the list and has consumed the 'expr' before it,
+ # remove 'expr' from exp.Order and add it back to args
+ args[-1] = order.this
+ order.set("this", concat_exprs(order.this, args))
+
+ this = order or concat_exprs(args[0], args)
+ else:
+ this = None
+
+ separator = self._parse_field() if self._match(TokenType.SEPARATOR) else None
+
+ return self.expression(exp.GroupConcat, this=this, separator=separator)
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = None
@@ -630,6 +663,7 @@ class MySQL(Dialect):
JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
JSON_KEY_VALUE_PAIR_SEP = ","
+ SUPPORTS_TO_NUMBER = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -637,9 +671,9 @@ class MySQL(Dialect):
exp.DateDiff: _remove_ts_or_ds_to_date(
lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
),
- exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")),
+ exp.DateAdd: _remove_ts_or_ds_to_date(date_add_sql("ADD")),
exp.DateStrToDate: datestrtodate_sql,
- exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")),
+ exp.DateSub: _remove_ts_or_ds_to_date(date_add_sql("SUB")),
exp.DateTrunc: _date_trunc_sql,
exp.Day: _remove_ts_or_ds_to_date(),
exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
@@ -672,7 +706,7 @@ class MySQL(Dialect):
exp.TimeFromParts: rename_func("MAKETIME"),
exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
exp.TimestampDiff: lambda self, e: self.func(
- "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+ "TIMESTAMPDIFF", unit_to_var(e), e.expression, e.this
),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -682,9 +716,10 @@ class MySQL(Dialect):
),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
- exp.TsOrDsAdd: _date_add_sql("ADD"),
+ exp.TsOrDsAdd: date_add_sql("ADD"),
exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.UnixToTime: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
exp.Week: _remove_ts_or_ds_to_date(),
exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
exp.Year: _remove_ts_or_ds_to_date(),
@@ -751,11 +786,6 @@ class MySQL(Dialect):
result = f"{result} UNSIGNED"
return result
- def xor_sql(self, expression: exp.Xor) -> str:
- if expression.expressions:
- return self.expressions(expression, sep=" XOR ")
- return super().xor_sql(expression)
-
def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index bccdad0..e038400 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
build_formatted_time,
no_ilike_sql,
rename_func,
+ to_number_with_nls_param,
trim_sql,
)
from sqlglot.helper import seq_get
@@ -246,6 +247,7 @@ class Oracle(Dialect):
exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY",
exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+ exp.ToNumber: to_number_with_nls_param,
exp.Trim: trim_sql,
exp.UnixToTime: lambda self,
e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index b53ae07..11398ed 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -278,6 +278,7 @@ class Postgres(Dialect):
"REVOKE": TokenType.COMMAND,
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
+ "NAME": TokenType.NAME,
"TEMP": TokenType.TEMPORARY,
"CSTRING": TokenType.PSEUDO_TYPE,
"OID": TokenType.OBJECT_IDENTIFIER,
@@ -356,6 +357,16 @@ class Postgres(Dialect):
JSON_ARROWS_REQUIRE_JSON_TYPE = True
+ COLUMN_OPERATORS = {
+ **parser.Parser.COLUMN_OPERATORS,
+ TokenType.ARROW: lambda self, this, path: build_json_extract_path(
+ exp.JSONExtract, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
+ )([this, path]),
+ TokenType.DARROW: lambda self, this, path: build_json_extract_path(
+ exp.JSONExtractScalar, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
+ )([this, path]),
+ }
+
def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
while True:
if not self._match(TokenType.L_PAREN):
@@ -484,6 +495,7 @@ class Postgres(Dialect):
]
),
exp.StrPosition: str_position_sql,
+ exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,
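
An illustrative sketch of the new StrToDate mapping (assuming sqlglot >= 23.7.0; output is approximate):

    import sqlglot

    # MySQL's STR_TO_DATE becomes TO_DATE, with the format string converted
    # to Postgres conventions.
    print(sqlglot.transpile("SELECT STR_TO_DATE(s, '%Y-%m-%d')", read="mysql", write="postgres")[0])
    # SELECT TO_DATE(s, 'YYYY-MM-DD')
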
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 3649bd2..25bba96 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -22,10 +22,13 @@ from sqlglot.dialects.dialect import (
rename_func,
right_to_substring_sql,
struct_extract_sql,
+ str_position_sql,
timestamptrunc_sql,
timestrtotime_sql,
ts_or_ds_add_cast,
+ unit_to_str,
)
+from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
@@ -93,14 +96,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate)
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
expression = ts_or_ds_add_cast(expression)
- unit = exp.Literal.string(expression.text("unit") or "DAY")
+ unit = unit_to_str(expression)
return self.func("DATE_ADD", unit, expression.expression, expression.this)
def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
this = exp.cast(expression.this, "TIMESTAMP")
expr = exp.cast(expression.expression, "TIMESTAMP")
- unit = exp.Literal.string(expression.text("unit") or "DAY")
+ unit = unit_to_str(expression)
return self.func("DATE_DIFF", unit, expr, this)
@@ -196,6 +199,7 @@ class Presto(Dialect):
SUPPORTS_SEMI_ANTI_JOIN = False
TYPED_DIVISION = True
TABLESAMPLE_SIZE_IS_PERCENT = True
+ LOG_BASE_FIRST: t.Optional[bool] = None
# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
@@ -289,6 +293,7 @@ class Presto(Dialect):
SUPPORTS_SINGLE_ARG_CONCAT = False
LIKE_PROPERTY_INSIDE_SCHEMA = True
MULTI_ARG_DISTINCT = False
+ SUPPORTS_TO_NUMBER = False
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
@@ -323,6 +328,7 @@ class Presto(Dialect):
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
+ exp.ArrayToString: rename_func("ARRAY_JOIN"),
exp.ArrayUniqueAgg: rename_func("SET_AGG"),
exp.AtTimeZone: rename_func("AT_TIMEZONE"),
exp.BitwiseAnd: lambda self, e: self.func("BITWISE_AND", e.this, e.expression),
@@ -339,19 +345,19 @@ class Presto(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: lambda self, e: self.func(
"DATE_ADD",
- exp.Literal.string(e.text("unit") or "DAY"),
+ unit_to_str(e),
_to_int(e.expression),
e.this,
),
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
+ "DATE_DIFF", unit_to_str(e), e.expression, e.this
),
exp.DateStrToDate: datestrtodate_sql,
exp.DateToDi: lambda self,
e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
exp.DateSub: lambda self, e: self.func(
"DATE_ADD",
- exp.Literal.string(e.text("unit") or "DAY"),
+ unit_to_str(e),
_to_int(e.expression * -1),
e.this,
),
@@ -397,13 +403,10 @@ class Presto(Dialect):
]
),
exp.SortArray: _no_sort_array,
- exp.StrPosition: rename_func("STRPOS"),
+ exp.StrPosition: lambda self, e: str_position_sql(self, e, generate_instance=True),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToMap: rename_func("SPLIT_TO_MAP"),
exp.StrToTime: _str_to_time_sql,
- exp.StrToUnix: lambda self, e: self.func(
- "TO_UNIXTIME", self.func("DATE_PARSE", e.this, self.format_time(e))
- ),
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.Timestamp: no_timestamp_sql,
@@ -436,6 +439,22 @@ class Presto(Dialect):
exp.Xor: bool_xor_sql,
}
+ def strtounix_sql(self, expression: exp.StrToUnix) -> str:
+ # Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one.
+ # To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a
+ # timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
+ # which seems to be using the same time mapping as Hive, as per:
+ # https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
+ value_as_text = exp.cast(expression.this, "text")
+ parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
+ parse_with_tz = self.func(
+ "PARSE_DATETIME",
+ value_as_text,
+ self.format_time(expression, Hive.INVERSE_TIME_MAPPING, Hive.INVERSE_TIME_TRIE),
+ )
+ coalesced = self.func("COALESCE", self.func("TRY", parse_without_tz), parse_with_tz)
+ return self.func("TO_UNIXTIME", coalesced)
+
def bracket_sql(self, expression: exp.Bracket) -> str:
if expression.args.get("safe"):
return self.func(
@@ -481,8 +500,7 @@ class Presto(Dialect):
return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))"
def interval_sql(self, expression: exp.Interval) -> str:
- unit = self.sql(expression, "unit")
- if expression.this and unit.startswith("WEEK"):
+ if expression.this and expression.text("unit").upper().startswith("WEEK"):
return f"({expression.this.name} * INTERVAL '7' DAY)"
return super().interval_sql(expression)
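
An illustrative sketch of the new strtounix_sql fallback (assuming sqlglot >= 23.7.0; output is approximate):

    import sqlglot

    # DATE_PARSE is wrapped in TRY, with PARSE_DATETIME (Joda-style format)
    # as the timezone-tolerant fallback, per the comments above.
    print(sqlglot.transpile("SELECT STR_TO_UNIX(s, '%Y-%m-%d')", write="presto")[0])
    # SELECT TO_UNIXTIME(COALESCE(TRY(DATE_PARSE(CAST(s AS VARCHAR), '%Y-%m-%d')),
    #                             PARSE_DATETIME(CAST(s AS VARCHAR), 'yyyy-MM-dd')))
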
diff --git a/sqlglot/dialects/prql.py b/sqlglot/dialects/prql.py
new file mode 100644
index 0000000..3005753
--- /dev/null
+++ b/sqlglot/dialects/prql.py
@@ -0,0 +1,109 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp, parser, tokens
+from sqlglot.dialects.dialect import Dialect
+from sqlglot.tokens import TokenType
+
+
+class PRQL(Dialect):
+ class Tokenizer(tokens.Tokenizer):
+ IDENTIFIERS = ["`"]
+ QUOTES = ["'", '"']
+
+ SINGLE_TOKENS = {
+ **tokens.Tokenizer.SINGLE_TOKENS,
+ "=": TokenType.ALIAS,
+ "'": TokenType.QUOTE,
+ '"': TokenType.QUOTE,
+ "`": TokenType.IDENTIFIER,
+ "#": TokenType.COMMENT,
+ }
+
+ KEYWORDS = {
+ **tokens.Tokenizer.KEYWORDS,
+ }
+
+ class Parser(parser.Parser):
+ TRANSFORM_PARSERS = {
+ "DERIVE": lambda self, query: self._parse_selection(query),
+ "SELECT": lambda self, query: self._parse_selection(query, append=False),
+ "TAKE": lambda self, query: self._parse_take(query),
+ }
+
+ def _parse_statement(self) -> t.Optional[exp.Expression]:
+ expression = self._parse_expression()
+ expression = expression if expression else self._parse_query()
+ return expression
+
+ def _parse_query(self) -> t.Optional[exp.Query]:
+ from_ = self._parse_from()
+
+ if not from_:
+ return None
+
+ query = exp.select("*").from_(from_, copy=False)
+
+ while self._match_texts(self.TRANSFORM_PARSERS):
+ query = self.TRANSFORM_PARSERS[self._prev.text.upper()](self, query)
+
+ return query
+
+ def _parse_selection(self, query: exp.Query, append: bool = True) -> exp.Query:
+ if self._match(TokenType.L_BRACE):
+ selects = self._parse_csv(self._parse_expression)
+
+ if not self._match(TokenType.R_BRACE, expression=query):
+ self.raise_error("Expecting }")
+ else:
+ expression = self._parse_expression()
+ selects = [expression] if expression else []
+
+ projections = {
+ select.alias_or_name: select.this if isinstance(select, exp.Alias) else select
+ for select in query.selects
+ }
+
+ selects = [
+ select.transform(
+ lambda s: (projections[s.name].copy() if s.name in projections else s)
+ if isinstance(s, exp.Column)
+ else s,
+ copy=False,
+ )
+ for select in selects
+ ]
+
+ return query.select(*selects, append=append, copy=False)
+
+ def _parse_take(self, query: exp.Query) -> t.Optional[exp.Query]:
+ num = self._parse_number() # TODO: TAKE for ranges a..b
+ return query.limit(num) if num else None
+
+ def _parse_expression(self) -> t.Optional[exp.Expression]:
+ if self._next and self._next.token_type == TokenType.ALIAS:
+ alias = self._parse_id_var(True)
+ self._match(TokenType.ALIAS)
+ return self.expression(exp.Alias, this=self._parse_conjunction(), alias=alias)
+ return self._parse_conjunction()
+
+ def _parse_table(
+ self,
+ schema: bool = False,
+ joins: bool = False,
+ alias_tokens: t.Optional[t.Collection[TokenType]] = None,
+ parse_bracket: bool = False,
+ is_db_reference: bool = False,
+ ) -> t.Optional[exp.Expression]:
+ return self._parse_table_parts()
+
+ def _parse_from(
+ self, joins: bool = False, skip_from_token: bool = False
+ ) -> t.Optional[exp.From]:
+ if not skip_from_token and not self._match(TokenType.FROM):
+ return None
+
+ return self.expression(
+ exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins)
+ )
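
An illustrative sketch of the new PRQL reader (assuming sqlglot >= 23.7.0; output is approximate):

    import sqlglot

    # Pipeline transforms (FROM, DERIVE, SELECT, TAKE) are folded onto one SELECT.
    prql = "from employees\nderive gross = salary + bonus\ntake 10"
    print(sqlglot.transpile(prql, read="prql", write="duckdb")[0])
    # SELECT *, salary + bonus AS gross FROM employees LIMIT 10
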
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 0db87ec..1f0c411 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -92,23 +92,6 @@ class Redshift(Postgres):
return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table
- def _parse_types(
- self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
- ) -> t.Optional[exp.Expression]:
- this = super()._parse_types(
- check_func=check_func, schema=schema, allow_identifiers=allow_identifiers
- )
-
- if (
- isinstance(this, exp.DataType)
- and this.is_type("varchar")
- and this.expressions
- and this.expressions[0].this == exp.column("MAX")
- ):
- this.set("expressions", [exp.var("MAX")])
-
- return this
-
def _parse_convert(
self, strict: bool, safe: t.Optional[bool] = None
) -> t.Optional[exp.Expression]:
@@ -153,6 +136,7 @@ class Redshift(Postgres):
NVL2_SUPPORTED = True
LAST_DAY_SUPPORTS_DATE_PART = False
CAN_IMPLEMENT_ARRAY_ANY = False
+ MULTI_ARG_DISTINCT = True
TYPE_MAPPING = {
**Postgres.Generator.TYPE_MAPPING,
@@ -187,9 +171,13 @@ class Redshift(Postgres):
),
exp.SortKeyProperty: lambda self,
e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
+ exp.StartsWith: lambda self,
+ e: f"{self.sql(e.this)} LIKE {self.sql(e.expression)} || '%'",
exp.TableSample: no_tablesample_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
+ exp.UnixToTime: lambda self,
+ e: f"(TIMESTAMP 'epoch' + {self.sql(e.this)} * INTERVAL '1 SECOND')",
}
# Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots
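
An illustrative sketch tying the new MySQL FROM_UNIXTIME builder to the Redshift transform above (assuming sqlglot >= 23.7.0; output is approximate):

    import sqlglot

    # Redshift has no FROM_UNIXTIME, so exp.UnixToTime is emitted as epoch arithmetic.
    print(sqlglot.transpile("SELECT FROM_UNIXTIME(ts)", read="mysql", write="redshift")[0])
    # SELECT (TIMESTAMP 'epoch' + ts * INTERVAL '1 SECOND')
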
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 20fdfb7..73a9166 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -20,8 +20,7 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
var_map_sql,
)
-from sqlglot.expressions import Literal
-from sqlglot.helper import flatten, is_int, seq_get
+from sqlglot.helper import flatten, is_float, is_int, seq_get
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
@@ -29,33 +28,35 @@ if t.TYPE_CHECKING:
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]:
- if len(args) == 2:
- first_arg, second_arg = args
- if second_arg.is_string:
- # case: <string_expr> [ , <format> ]
- return build_formatted_time(exp.StrToTime, "snowflake")(args)
- return exp.UnixToTime(this=first_arg, scale=second_arg)
+def _build_datetime(
+ name: str, kind: exp.DataType.Type, safe: bool = False
+) -> t.Callable[[t.List], exp.Func]:
+ def _builder(args: t.List) -> exp.Func:
+ value = seq_get(args, 0)
+
+ if isinstance(value, exp.Literal):
+ int_value = is_int(value.this)
- from sqlglot.optimizer.simplify import simplify_literals
+ # Converts calls like `TO_TIME('01:02:03')` into casts
+ if len(args) == 1 and value.is_string and not int_value:
+ return exp.cast(value, kind)
- # The first argument might be an expression like 40 * 365 * 86400, so we try to
- # reduce it using `simplify_literals` first and then check if it's a Literal.
- first_arg = seq_get(args, 0)
- if not isinstance(simplify_literals(first_arg, root=True), Literal):
- # case: <variant_expr> or other expressions such as columns
- return exp.TimeStrToTime.from_arg_list(args)
+ # Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special
+ # cases so we can transpile them, since they're relatively common
+ if kind == exp.DataType.Type.TIMESTAMP:
+ if int_value:
+ return exp.UnixToTime(this=value, scale=seq_get(args, 1))
+ if not is_float(value.this):
+ return build_formatted_time(exp.StrToTime, "snowflake")(args)
- if first_arg.is_string:
- if is_int(first_arg.this):
- # case: <integer>
- return exp.UnixToTime.from_arg_list(args)
+ if len(args) == 2 and kind == exp.DataType.Type.DATE:
+ formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args)
+ formatted_exp.set("safe", safe)
+ return formatted_exp
- # case: <date_expr>
- return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args)
+ return exp.Anonymous(this=name, expressions=args)
- # case: <numeric_expr>
- return exp.UnixToTime.from_arg_list(args)
+ return _builder
def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
@@ -77,6 +78,17 @@ def _build_datediff(args: t.List) -> exp.DateDiff:
)
+def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
+ def _builder(args: t.List) -> E:
+ return expr_type(
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=_map_date_part(seq_get(args, 0)),
+ )
+
+ return _builder
+
+
# https://docs.snowflake.com/en/sql-reference/functions/div0
def _build_if_from_div0(args: t.List) -> exp.If:
cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
@@ -97,14 +109,6 @@ def _build_if_from_nullifzero(args: t.List) -> exp.If:
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
-def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str:
- if expression.is_type("array"):
- return "ARRAY"
- elif expression.is_type("map"):
- return "OBJECT"
- return self.datatype_sql(expression)
-
-
def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str:
flag = expression.text("flag")
@@ -258,6 +262,25 @@ def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression:
return expression
+def _flatten_structured_types_unless_iceberg(expression: exp.Expression) -> exp.Expression:
+ assert isinstance(expression, exp.Create)
+
+ def _flatten_structured_type(expression: exp.DataType) -> exp.DataType:
+ if expression.this in exp.DataType.NESTED_TYPES:
+ expression.set("expressions", None)
+ return expression
+
+ props = expression.args.get("properties")
+ if isinstance(expression.this, exp.Schema) and not (props and props.find(exp.IcebergProperty)):
+ for schema_expression in expression.this.expressions:
+ if isinstance(schema_expression, exp.ColumnDef):
+ column_type = schema_expression.kind
+ if isinstance(column_type, exp.DataType):
+ column_type.transform(_flatten_structured_type, copy=False)
+
+ return expression
+
+
class Snowflake(Dialect):
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
@@ -312,7 +335,13 @@ class Snowflake(Dialect):
class Parser(parser.Parser):
IDENTIFY_PIVOT_STRINGS = True
+ ID_VAR_TOKENS = {
+ *parser.Parser.ID_VAR_TOKENS,
+ TokenType.MATCH_CONDITION,
+ }
+
TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW}
+ TABLE_ALIAS_TOKENS.discard(TokenType.MATCH_CONDITION)
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -327,17 +356,13 @@ class Snowflake(Dialect):
end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)),
step=seq_get(args, 2),
),
- "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
"BITXOR": binary_from_function(exp.BitwiseXor),
"BIT_XOR": binary_from_function(exp.BitwiseXor),
"BOOLXOR": binary_from_function(exp.Xor),
"CONVERT_TIMEZONE": _build_convert_timezone,
+ "DATE": _build_datetime("DATE", exp.DataType.Type.DATE),
"DATE_TRUNC": _date_trunc_to_time,
- "DATEADD": lambda args: exp.DateAdd(
- this=seq_get(args, 2),
- expression=seq_get(args, 1),
- unit=_map_date_part(seq_get(args, 0)),
- ),
+ "DATEADD": _build_date_time_add(exp.DateAdd),
"DATEDIFF": _build_datediff,
"DIV0": _build_if_from_div0,
"FLATTEN": exp.Explode.from_arg_list,
@@ -349,17 +374,34 @@ class Snowflake(Dialect):
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
),
"LISTAGG": exp.GroupConcat.from_arg_list,
+ "MEDIAN": lambda args: exp.PercentileCont(
+ this=seq_get(args, 0), expression=exp.Literal.number(0.5)
+ ),
"NULLIFZERO": _build_if_from_nullifzero,
"OBJECT_CONSTRUCT": _build_object_construct,
"REGEXP_REPLACE": _build_regexp_replace,
"REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
+ "TIMEADD": _build_date_time_add(exp.TimeAdd),
"TIMEDIFF": _build_datediff,
+ "TIMESTAMPADD": _build_date_time_add(exp.DateAdd),
"TIMESTAMPDIFF": _build_datediff,
"TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
"TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
- "TO_TIMESTAMP": _build_to_timestamp,
+ "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True),
+ "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE),
+ "TO_NUMBER": lambda args: exp.ToNumber(
+ this=seq_get(args, 0),
+ format=seq_get(args, 1),
+ precision=seq_get(args, 2),
+ scale=seq_get(args, 3),
+ ),
+ "TO_TIME": _build_datetime("TO_TIME", exp.DataType.Type.TIME),
+ "TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP),
+ "TO_TIMESTAMP_LTZ": _build_datetime("TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ),
+ "TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMP),
+ "TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ),
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _build_if_from_zeroifnull,
}
@@ -377,7 +419,6 @@ class Snowflake(Dialect):
**parser.Parser.RANGE_PARSERS,
TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny),
TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny),
- TokenType.COLON: lambda self, this: self._parse_colon_get_path(this),
}
ALTER_PARSERS = {
@@ -434,35 +475,35 @@ class Snowflake(Dialect):
SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"}
- def _parse_colon_get_path(
- self: parser.Parser, this: t.Optional[exp.Expression]
- ) -> t.Optional[exp.Expression]:
- while True:
- path = self._parse_bitwise() or self._parse_var(any_token=True)
+ def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ this = super()._parse_column_ops(this)
+
+ casts = []
+ json_path = []
+
+ while self._match(TokenType.COLON):
+ path = super()._parse_column_ops(self._parse_field(any_token=True))
# The cast :: operator has a lower precedence than the extraction operator :, so
# we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH
- if isinstance(path, exp.Cast):
- target_type = path.to
+ while isinstance(path, exp.Cast):
+ casts.append(path.to)
path = path.this
- else:
- target_type = None
- if isinstance(path, exp.Expression):
- path = exp.Literal.string(path.sql(dialect="snowflake"))
+ if path:
+ json_path.append(path.sql(dialect="snowflake", copy=False))
- # The extraction operator : is left-associative
+ if json_path:
this = self.expression(
- exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path)
+ exp.JSONExtract,
+ this=this,
+ expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))),
)
- if target_type:
- this = exp.cast(this, target_type)
+ while casts:
+ this = self.expression(exp.Cast, this=this, to=casts.pop())
- if not self._match(TokenType.COLON):
- break
-
- return self._parse_range(this)
+ return this
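A quick sketch of how the reworked parsing combines paths and casts (assumed output shape for sqlglot 23.x): chained `:` steps collapse into a single GET_PATH call, and the trailing cast wraps the whole extraction rather than the last path segment.

    import sqlglot
    sql = "SELECT v:a.b::INT FROM t"
    print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])
    # Expected: SELECT CAST(GET_PATH(v, 'a.b') AS INT) FROM t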
# https://docs.snowflake.com/en/sql-reference/functions/date_part.html
# https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts
@@ -663,6 +704,7 @@ class Snowflake(Dialect):
"EXCLUDE": TokenType.EXCEPT,
"ILIKE ANY": TokenType.ILIKE_ANY,
"LIKE ANY": TokenType.LIKE_ANY,
+ "MATCH_CONDITION": TokenType.MATCH_CONDITION,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
"NCHAR VARYING": TokenType.VARCHAR,
@@ -703,6 +745,7 @@ class Snowflake(Dialect):
LIMIT_ONLY_LITERALS = True
JSON_KEY_VALUE_PAIR_SEP = ","
INSERT_OVERWRITE = " OVERWRITE INTO"
+ STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -711,15 +754,14 @@ class Snowflake(Dialect):
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
- exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
exp.AtTimeZone: lambda self, e: self.func(
"CONVERT_TIMEZONE", e.args.get("zone"), e.this
),
exp.BitwiseXor: rename_func("BITXOR"),
+ exp.Create: transforms.preprocess([_flatten_structured_types_unless_iceberg]),
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.DateStrToDate: datestrtodate_sql,
- exp.DataType: _datatype_sql,
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
@@ -769,6 +811,7 @@ class Snowflake(Dialect):
),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.Stuff: rename_func("INSERT"),
+ exp.TimeAdd: date_delta_sql("TIMEADD"),
exp.TimestampDiff: lambda self, e: self.func(
"TIMESTAMPDIFF", e.unit, e.expression, e.this
),
@@ -783,6 +826,9 @@ class Snowflake(Dialect):
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
+ exp.TsOrDsToDate: lambda self, e: self.func(
+ "TRY_TO_DATE" if e.args.get("safe") else "TO_DATE", e.this, self.format_time(e)
+ ),
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
@@ -797,6 +843,8 @@ class Snowflake(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
+ exp.DataType.Type.NESTED: "OBJECT",
+ exp.DataType.Type.STRUCT: "OBJECT",
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
@@ -811,6 +859,37 @@ class Snowflake(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ UNSUPPORTED_VALUES_EXPRESSIONS = {
+ exp.Struct,
+ }
+
+ def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str:
+ if expression.find(*self.UNSUPPORTED_VALUES_EXPRESSIONS):
+ values_as_table = False
+
+ return super().values_sql(expression, values_as_table=values_as_table)
+
+ def datatype_sql(self, expression: exp.DataType) -> str:
+ expressions = expression.expressions
+ if (
+ expressions
+ and expression.is_type(*exp.DataType.STRUCT_TYPES)
+ and any(isinstance(field_type, exp.DataType) for field_type in expressions)
+ ):
+ # The correct syntax is OBJECT [ (<key> <value_type> [NOT NULL] [, ...]) ]
+ return "OBJECT"
+
+ return super().datatype_sql(expression)
+
+ def tonumber_sql(self, expression: exp.ToNumber) -> str:
+ return self.func(
+ "TO_NUMBER",
+ expression.this,
+ expression.args.get("format"),
+ expression.args.get("precision"),
+ expression.args.get("scale"),
+ )
+
def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
milli = expression.args.get("milli")
if milli is not None:
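The new TO_NUMBER support round-trips all four optional arguments; a minimal sketch (assuming the builder/generator pair above):

    import sqlglot
    sql = "TO_NUMBER(x, '$999.0', 9, 2)"
    # format, precision and scale are all preserved on the way back out.
    print(sqlglot.parse_one(sql, read="snowflake").sql(dialect="snowflake"))
    # Expected: TO_NUMBER(x, '$999.0', 9, 2)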
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 20c0fce..88b5ddc 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import exp
-from sqlglot.dialects.dialect import rename_func
+from sqlglot.dialects.dialect import rename_func, unit_to_var
from sqlglot.dialects.hive import _build_with_ignore_nulls
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
from sqlglot.helper import seq_get
@@ -78,6 +78,8 @@ class Spark(Spark2):
return this
class Generator(Spark2.Generator):
+ SUPPORTS_TO_NUMBER = True
+
TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
@@ -100,7 +102,7 @@ class Spark(Spark2):
e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}",
exp.StartsWith: rename_func("STARTSWITH"),
exp.TimestampAdd: lambda self, e: self.func(
- "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
+ "DATEADD", unit_to_var(e), e.expression, e.this
),
exp.TryCast: lambda self, e: (
self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
@@ -117,11 +119,10 @@ class Spark(Spark2):
return self.function_fallback_sql(expression)
def datediff_sql(self, expression: exp.DateDiff) -> str:
- unit = self.sql(expression, "unit")
end = self.sql(expression, "this")
start = self.sql(expression, "expression")
- if unit:
- return self.func("DATEDIFF", unit, start, end)
+ if expression.unit:
+ return self.func("DATEDIFF", unit_to_var(expression), start, end)
return self.func("DATEDIFF", end, start)
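A sketch of the unit-aware branch (assuming sqlglot 23.x dialect wiring): a DateDiff carrying a unit now goes through unit_to_var and Spark's three-argument form, while a unitless one keeps the two-argument form.

    import sqlglot
    print(sqlglot.transpile("DATEDIFF(month, s, e)", read="snowflake", write="spark")[0])
    # Expected: DATEDIFF(MONTH, s, e)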
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 63eae6e..069916f 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
pivot_column_names,
rename_func,
trim_sql,
+ unit_to_str,
)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get
@@ -203,6 +204,7 @@ class Spark2(Hive):
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self,
e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
+ exp.ArrayToString: rename_func("ARRAY_JOIN"),
exp.AtTimeZone: lambda self, e: self.func(
"FROM_UTC_TIMESTAMP", e.this, e.args.get("zone")
),
@@ -218,7 +220,7 @@ class Spark2(Hive):
]
),
exp.DateFromParts: rename_func("MAKE_DATE"),
- exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
+ exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, unit_to_str(e)),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
@@ -241,9 +243,7 @@ class Spark2(Hive):
),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
- exp.TimestampTrunc: lambda self, e: self.func(
- "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
- ),
+ exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
exp.Trim: trim_sql,
exp.UnixToTime: _unix_to_time_sql,
exp.VariancePop: rename_func("VAR_POP"),
@@ -252,7 +252,6 @@ class Spark2(Hive):
[transforms.remove_within_group_for_percentiles]
),
}
- TRANSFORMS.pop(exp.ArrayJoin)
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
TRANSFORMS.pop(exp.Left)
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 2b17ff9..ef7d9aa 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -33,6 +33,14 @@ def _json_extract_sql(self: SQLite.Generator, expression: exp.JSONExtract) -> st
return arrow_json_extract_sql(self, expression)
+def _build_strftime(args: t.List) -> exp.Anonymous | exp.TimeToStr:
+ if len(args) == 1:
+ args.append(exp.CurrentTimestamp())
+ if len(args) == 2:
+ return exp.TimeToStr(this=exp.TsOrDsToTimestamp(this=args[1]), format=args[0])
+ return exp.Anonymous(this="STRFTIME", expressions=args)
+
+
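Round-trip sketch for the new STRFTIME builder (assumed 23.x behavior): the parsed TimeToStr regenerates through the TimeToStr transform added below, and the TsOrDsToTimestamp wrapper renders as its inner expression.

    import sqlglot
    sql = "SELECT STRFTIME('%Y-%m-%d', dt) FROM t"
    print(sqlglot.transpile(sql, read="sqlite", write="sqlite")[0])
    # Expected: SELECT STRFTIME('%Y-%m-%d', dt) FROM t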
def _transform_create(expression: exp.Expression) -> exp.Expression:
"""Move primary key to a column and enforce auto_increment on primary keys."""
schema = expression.this
@@ -82,6 +90,7 @@ class SQLite(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
+ "STRFTIME": _build_strftime,
}
STRING_ALIASES = True
@@ -93,6 +102,7 @@ class SQLite(Dialect):
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False
SUPPORTS_TABLE_ALIAS_COLUMNS = False
+ SUPPORTS_TO_NUMBER = False
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
@@ -151,7 +161,9 @@ class SQLite(Dialect):
),
exp.TableSample: no_tablesample_sql,
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
+ exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.args.get("format"), e.this),
exp.TryCast: no_trycast_sql,
+ exp.TsOrDsToTimestamp: lambda self, e: self.sql(e, "this"),
}
# SQLite doesn't generally support CREATE TABLE .. properties
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 12ac600..5691f58 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_sql,
build_timestamp_trunc,
rename_func,
+ unit_to_str,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get
@@ -39,15 +40,13 @@ class StarRocks(MySQL):
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.DateDiff: lambda self, e: self.func(
- "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.this, e.expression
+ "DATE_DIFF", unit_to_str(e), e.this, e.expression
),
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.RegexpLike: rename_func("REGEXP"),
exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)),
- exp.TimestampTrunc: lambda self, e: self.func(
- "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
- ),
+ exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.UnixToStr: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index b736918..40feb67 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -5,6 +5,8 @@ from sqlglot.dialects.dialect import Dialect, rename_func
class Tableau(Dialect):
+ LOG_BASE_FIRST = False
+
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = [("[", "]")]
QUOTES = ["'", '"']
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 0663a1d..a65e10e 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -3,7 +3,13 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least, rename_func
+from sqlglot.dialects.dialect import (
+ Dialect,
+ max_or_greatest,
+ min_or_least,
+ rename_func,
+ to_number_with_nls_param,
+)
from sqlglot.tokens import TokenType
@@ -206,6 +212,7 @@ class Teradata(Dialect):
exp.StrToDate: lambda self,
e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+ exp.ToNumber: to_number_with_nls_param,
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index 1bbed67..457e2f0 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -7,6 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
SUPPORTS_USER_DEFINED_TYPES = False
+ LOG_BASE_FIRST = True
class Generator(Presto.Generator):
TRANSFORMS = {
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index b6f491f..8e06be6 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
NormalizationStrategy,
any_value_to_max_sql,
date_delta_sql,
+ datestrtodate_sql,
generatedasidentitycolumnconstraint_sql,
max_or_greatest,
min_or_least,
@@ -724,6 +725,7 @@ class TSQL(Dialect):
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
+ SUPPORTS_TO_NUMBER = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
@@ -760,12 +762,14 @@ class TSQL(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
+ exp.ArrayToString: rename_func("STRING_AGG"),
exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
exp.CTE: transforms.preprocess([qualify_derived_table_outputs]),
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
+ exp.DateStrToDate: datestrtodate_sql,
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GroupConcat: _string_agg_sql,
@@ -808,6 +812,22 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def select_sql(self, expression: exp.Select) -> str:
+ if expression.args.get("offset"):
+ if not expression.args.get("order"):
+ # ORDER BY is required to use OFFSET in a query, so we add
+ # a noop ORDER BY, since we don't really care about the order.
+ # See: https://www.microsoftpressstore.com/articles/article.aspx?p=2314819
+ expression.order_by(exp.select(exp.null()).subquery(), copy=False)
+
+ limit = expression.args.get("limit")
+ if isinstance(limit, exp.Limit):
+ # TOP and OFFSET can't be combined, so we need to use FETCH instead of TOP;
+ # we replace the limit here because otherwise TOP would be generated in select_sql
+ limit.replace(exp.Fetch(direction="FIRST", count=limit.expression))
+
+ return super().select_sql(expression)
+
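A sketch of the resulting OFFSET handling (approximate output, assuming 23.x generator defaults): a noop ORDER BY is injected and the LIMIT becomes a FETCH clause so it can legally combine with OFFSET.

    import sqlglot
    sql = "SELECT a FROM t LIMIT 10 OFFSET 5"
    print(sqlglot.transpile(sql, write="tsql")[0])
    # Roughly: SELECT a FROM t ORDER BY (SELECT NULL) OFFSET 5 ROWS FETCH FIRST 10 ROWS ONLY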
def convert_sql(self, expression: exp.Convert) -> str:
name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT"
return self.func(
@@ -862,12 +882,12 @@ class TSQL(Dialect):
return rename_func("DATETIMEFROMPARTS")(self, expression)
- def set_operation(self, expression: exp.Union, op: str) -> str:
+ def set_operations(self, expression: exp.Union) -> str:
limit = expression.args.get("limit")
if limit:
return self.sql(expression.limit(limit.pop(), copy=False))
- return super().set_operation(expression, op)
+ return super().set_operations(expression)
def setitem_sql(self, expression: exp.SetItem) -> str:
this = expression.this