Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--   sqlglot/dialects/bigquery.py     2
-rw-r--r--   sqlglot/dialects/clickhouse.py  40
-rw-r--r--   sqlglot/dialects/databricks.py   1
-rw-r--r--   sqlglot/dialects/dialect.py     34
-rw-r--r--   sqlglot/dialects/drill.py        1
-rw-r--r--   sqlglot/dialects/duckdb.py       7
-rw-r--r--   sqlglot/dialects/hive.py        43
-rw-r--r--   sqlglot/dialects/mysql.py       72
-rw-r--r--   sqlglot/dialects/postgres.py    11
-rw-r--r--   sqlglot/dialects/presto.py      28
-rw-r--r--   sqlglot/dialects/redshift.py    12
-rw-r--r--   sqlglot/dialects/snowflake.py   10
-rw-r--r--   sqlglot/dialects/spark.py        1
-rw-r--r--   sqlglot/dialects/sqlite.py      58
-rw-r--r--   sqlglot/dialects/starrocks.py   12
-rw-r--r--   sqlglot/dialects/tsql.py         6
16 files changed, 200 insertions, 138 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 0c2105b..6a43846 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -144,7 +144,6 @@ class BigQuery(Dialect):
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
- "CURRENT_TIME": TokenType.CURRENT_TIME,
"DECLARE": TokenType.COMMAND,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE,
@@ -194,7 +193,6 @@ class BigQuery(Dialect):
NO_PAREN_FUNCTIONS = {
**parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
- TokenType.CURRENT_TIME: exp.CurrentTime,
}
NESTED_TYPE_TOKENS = {
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index b553df2..b54a77d 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -5,6 +5,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.errors import ParseError
+from sqlglot.helper import ensure_list, seq_get
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
@@ -40,7 +41,18 @@ class ClickHouse(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
+ "EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg(
+ this=seq_get(args, 0),
+ time=seq_get(args, 1),
+ decay=seq_get(params, 0),
+ ),
"MAP": parse_var_map,
+ "HISTOGRAM": lambda params, args: exp.Histogram(
+ this=seq_get(args, 0), bins=seq_get(params, 0)
+ ),
+ "GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray(
+ this=seq_get(args, 0), size=seq_get(params, 0)
+ ),
"QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
"QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
"QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
@@ -113,22 +125,40 @@ class ClickHouse(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
- exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
+ exp.ExponentialTimeDecayedAvg: lambda self, e: f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
+ exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}",
+ exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
- exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
+ exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
+ exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
}
EXPLICIT_UNION = True
def _param_args_sql(
- self, expression: exp.Expression, params_name: str, args_name: str
+ self,
+ expression: exp.Expression,
+ param_names: str | t.List[str],
+ arg_names: str | t.List[str],
) -> str:
- params = self.format_args(self.expressions(expression, params_name))
- args = self.format_args(self.expressions(expression, args_name))
+ params = self.format_args(
+ *(
+ arg
+ for name in ensure_list(param_names)
+ for arg in ensure_list(expression.args.get(name))
+ )
+ )
+ args = self.format_args(
+ *(
+ arg
+ for name in ensure_list(arg_names)
+ for arg in ensure_list(expression.args.get(name))
+ )
+ )
return f"({params})({args})"
def cte_sql(self, expression: exp.CTE) -> str:
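Note: a minimal round-trip sketch of the new ClickHouse parametric-aggregate handling. The table and column names are placeholders, and the printed output is the expected shape rather than a verified result:

    import sqlglot

    # histogram(bins)(col) and groupUniqArray(size)(col) should now parse into
    # exp.Histogram / exp.GroupUniqArray and regenerate through _param_args_sql,
    # which emits the ClickHouse "(params)(args)" call form.
    for sql in (
        "SELECT histogram(5)(x) FROM t",
        "SELECT groupUniqArray(3)(x) FROM t",
    ):
        print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
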
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 4ff3594..4268f1b 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -23,6 +23,7 @@ class Databricks(Spark):
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
}
+ TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation
PARAMETER_TOKEN = "$"
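Note: popping exp.Select drops the ELIMINATE_QUALIFY rewrite that Databricks would otherwise inherit via Spark/Hive, since Databricks supports QUALIFY natively. A hedged sketch of the expected difference (placeholder table/column names):

    import sqlglot

    sql = "SELECT a, b FROM t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) = 1"

    # Databricks should keep QUALIFY as-is ...
    print(sqlglot.transpile(sql, read="databricks", write="databricks")[0])
    # ... while a Spark/Hive target is expected to rewrite it into a filtered subquery.
    print(sqlglot.transpile(sql, read="databricks", write="spark")[0])
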
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 25490cb..b267521 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -8,7 +8,7 @@ from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
-from sqlglot.tokens import Tokenizer
+from sqlglot.tokens import Token, Tokenizer
from sqlglot.trie import new_trie
E = t.TypeVar("E", bound=exp.Expression)
@@ -160,12 +160,12 @@ class Dialect(metaclass=_Dialect):
return expression
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
- return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
+ return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into(
self, expression_type: exp.IntoType, sql: str, **opts
) -> t.List[t.Optional[exp.Expression]]:
- return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
+ return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
return self.generator(**opts).generate(expression)
@@ -173,6 +173,9 @@ class Dialect(metaclass=_Dialect):
def transpile(self, sql: str, **opts) -> t.List[str]:
return [self.generate(expression, **opts) for expression in self.parse(sql)]
+ def tokenize(self, sql: str) -> t.List[Token]:
+ return self.tokenizer.tokenize(sql)
+
@property
def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
@@ -385,6 +388,21 @@ def parse_date_delta(
return inner_func
+def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
+ unit = seq_get(args, 0)
+ this = seq_get(args, 1)
+
+ if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE):
+ return exp.DateTrunc(unit=unit, this=this)
+ return exp.TimestampTrunc(this=this, unit=unit)
+
+
+def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
+ return self.func(
+ "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
+ )
+
+
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1),
@@ -412,6 +430,16 @@ def min_or_least(self: Generator, expression: exp.Min) -> str:
return rename_func(name)(self, expression)
+def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
+ cond = expression.this
+
+ if isinstance(expression.this, exp.Distinct):
+ cond = expression.this.expressions[0]
+ self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
+
+ return self.func("sum", exp.func("if", cond, 1, 0))
+
+
def trim_sql(self: Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
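Note: a quick sanity sketch of the new date_trunc_to_time helper (count_if_to_sum and timestamptrunc_sql follow the same pattern). It uses sqlglot's public expression builders; the column name is a placeholder:

    from sqlglot import exp
    from sqlglot.dialects.dialect import date_trunc_to_time

    # A DATE_TRUNC whose operand is an explicit DATE cast becomes exp.DateTrunc ...
    date_node = date_trunc_to_time([exp.Literal.string("day"), exp.cast(exp.column("x"), "date")])
    # ... anything else becomes exp.TimestampTrunc.
    ts_node = date_trunc_to_time([exp.Literal.string("day"), exp.column("x")])
    print(type(date_node).__name__, type(ts_node).__name__)  # DateTrunc TimestampTrunc
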
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 208e2ab..dc0e519 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -97,6 +97,7 @@ class Drill(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
+ "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"),
"TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
}
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 43f538c..f1d2266 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
str_to_time_sql,
+ timestamptrunc_sql,
timestrtotime_sql,
ts_or_ds_to_date_sql,
)
@@ -148,6 +149,9 @@ class DuckDB(Dialect):
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArraySort: _array_sort_sql,
exp.ArraySum: rename_func("LIST_SUM"),
+ exp.DayOfMonth: rename_func("DAYOFMONTH"),
+ exp.DayOfWeek: rename_func("DAYOFWEEK"),
+ exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.DataType: _datatype_sql,
exp.DateAdd: _date_add,
exp.DateDiff: lambda self, e: self.func(
@@ -162,6 +166,7 @@ class DuckDB(Dialect):
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
+ exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Pivot: no_pivot_sql,
exp.Properties: no_properties_sql,
exp.RegexpExtract: _regexp_extract_sql,
@@ -175,6 +180,7 @@ class DuckDB(Dialect):
exp.StrToTime: str_to_time_sql,
exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
exp.Struct: _struct_sql,
+ exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
@@ -186,6 +192,7 @@ class DuckDB(Dialect):
exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("TO_TIMESTAMP"),
exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)",
+ exp.WeekOfYear: rename_func("WEEKOFYEAR"),
}
TYPE_MAPPING = {
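Note: with exp.TimestampTrunc now mapped to DATE_TRUNC, timestamp truncation survives transpilation into DuckDB. A hedged sketch (placeholder names, expected output in the comment):

    import sqlglot

    print(sqlglot.transpile("SELECT DATE_TRUNC('day', ts) FROM t", read="postgres", write="duckdb")[0])
    # expected roughly: SELECT DATE_TRUNC('day', ts) FROM t
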
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index c4b8fa9..0110eee 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
@@ -35,7 +37,7 @@ DATE_DELTA_INTERVAL = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
-def _add_date_sql(self, expression):
+def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
modified_increment = (
@@ -47,7 +49,7 @@ def _add_date_sql(self, expression):
return self.func(func, expression.this, modified_increment.this)
-def _date_diff_sql(self, expression):
+def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF"
_, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1))
@@ -56,21 +58,21 @@ def _date_diff_sql(self, expression):
return f"{diff_sql}{multiplier_sql}"
-def _array_sort(self, expression):
+def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str:
if expression.expression:
self.unsupported("Hive SORT_ARRAY does not support a comparator")
return f"SORT_ARRAY({self.sql(expression, 'this')})"
-def _property_sql(self, expression):
+def _property_sql(self: generator.Generator, expression: exp.Property) -> str:
return f"'{expression.name}'={self.sql(expression, 'value')}"
-def _str_to_unix(self, expression):
+def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str:
return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression))
-def _str_to_date(self, expression):
+def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -78,7 +80,7 @@ def _str_to_date(self, expression):
return f"CAST({this} AS DATE)"
-def _str_to_time(self, expression):
+def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format not in (Hive.time_format, Hive.date_format):
@@ -86,20 +88,22 @@ def _str_to_time(self, expression):
return f"CAST({this} AS TIMESTAMP)"
-def _time_format(self, expression):
+def _time_format(
+ self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix
+) -> t.Optional[str]:
time_format = self.format_time(expression)
if time_format == Hive.time_format:
return None
return time_format
-def _time_to_str(self, expression):
+def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
return f"DATE_FORMAT({this}, {time_format})"
-def _to_date_sql(self, expression):
+def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format and time_format not in (Hive.time_format, Hive.date_format):
@@ -107,7 +111,7 @@ def _to_date_sql(self, expression):
return f"TO_DATE({this})"
-def _unnest_to_explode_sql(self, expression):
+def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str:
unnest = expression.this
if isinstance(unnest, exp.Unnest):
alias = unnest.args.get("alias")
@@ -117,7 +121,7 @@ def _unnest_to_explode_sql(self, expression):
exp.Lateral(
this=udtf(this=expression),
view=True,
- alias=exp.TableAlias(this=alias.this, columns=[column]),
+ alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
)
)
for expression, column in zip(unnest.expressions, alias.columns if alias else [])
@@ -125,7 +129,7 @@ def _unnest_to_explode_sql(self, expression):
return self.join_sql(expression)
-def _index_sql(self, expression):
+def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
columns = self.sql(expression, "columns")
@@ -263,14 +267,15 @@ class Hive(Dialect):
exp.DataType.Type.TEXT: "STRING",
exp.DataType.Type.DATETIME: "TIMESTAMP",
exp.DataType.Type.VARBINARY: "BINARY",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
}
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
+ **transforms.ELIMINATE_QUALIFY, # type: ignore
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
- exp.ArrayAgg: rename_func("COLLECT_LIST"),
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort,
@@ -333,13 +338,19 @@ class Hive(Dialect):
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
}
- def with_properties(self, properties):
+ def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
+ return self.func(
+ "COLLECT_LIST",
+ expression.this.this if isinstance(expression.this, exp.Order) else expression.this,
+ )
+
+ def with_properties(self, properties: exp.Properties) -> str:
return self.properties(
properties,
prefix=self.seg("TBLPROPERTIES"),
)
- def datatype_sql(self, expression):
+ def datatype_sql(self, expression: exp.DataType) -> str:
if (
expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
and not expression.expressions
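Note: the new arrayagg_sql unwraps an ORDER BY nested inside the aggregate, since COLLECT_LIST takes a single argument. A hedged sketch, assuming the base parser accepts ORDER BY inside ARRAY_AGG (placeholder names, output unverified):

    import sqlglot

    print(sqlglot.transpile("SELECT ARRAY_AGG(x ORDER BY y) FROM t", write="hive")[0])
    # expected roughly: SELECT COLLECT_LIST(x) FROM t
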
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index a831235..1e2cfa3 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -177,7 +177,7 @@ class MySQL(Dialect):
"@@": TokenType.SESSION_PARAMETER,
}
- COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
+ COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW}
class Parser(parser.Parser):
FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore
@@ -211,7 +211,6 @@ class MySQL(Dialect):
STATEMENT_PARSERS = {
**parser.Parser.STATEMENT_PARSERS, # type: ignore
TokenType.SHOW: lambda self: self._parse_show(),
- TokenType.SET: lambda self: self._parse_set(),
}
SHOW_PARSERS = {
@@ -269,15 +268,12 @@ class MySQL(Dialect):
}
SET_PARSERS = {
- "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
+ **parser.Parser.SET_PARSERS,
"PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
"PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"),
- "SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
- "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
"CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"NAMES": lambda self: self._parse_set_item_names(),
- "TRANSACTION": lambda self: self._parse_set_transaction(),
}
PROFILE_TYPES = {
@@ -292,15 +288,6 @@ class MySQL(Dialect):
"SWAPS",
}
- TRANSACTION_CHARACTERISTICS = {
- "ISOLATION LEVEL REPEATABLE READ",
- "ISOLATION LEVEL READ COMMITTED",
- "ISOLATION LEVEL READ UNCOMMITTED",
- "ISOLATION LEVEL SERIALIZABLE",
- "READ WRITE",
- "READ ONLY",
- }
-
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
@@ -354,12 +341,6 @@ class MySQL(Dialect):
**{"global": global_},
)
- def _parse_var_from_options(self, options):
- for option in options:
- if self._match_text_seq(*option.split(" ")):
- return exp.Var(this=option)
- return None
-
def _parse_oldstyle_limit(self):
limit = None
offset = None
@@ -372,30 +353,6 @@ class MySQL(Dialect):
offset = parts[0]
return offset, limit
- def _default_parse_set_item(self):
- return self._parse_set_item_assignment(kind=None)
-
- def _parse_set_item_assignment(self, kind):
- if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
- return self._parse_set_transaction(global_=kind == "GLOBAL")
-
- left = self._parse_primary() or self._parse_id_var()
- if not self._match(TokenType.EQ):
- self.raise_error("Expected =")
- right = self._parse_statement() or self._parse_id_var()
-
- this = self.expression(
- exp.EQ,
- this=left,
- expression=right,
- )
-
- return self.expression(
- exp.SetItem,
- this=this,
- kind=kind,
- )
-
def _parse_set_item_charset(self, kind):
this = self._parse_string() or self._parse_id_var()
@@ -418,18 +375,6 @@ class MySQL(Dialect):
kind="NAMES",
)
- def _parse_set_transaction(self, global_=False):
- self._match_text_seq("TRANSACTION")
- characteristics = self._parse_csv(
- lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
- )
- return self.expression(
- exp.SetItem,
- expressions=characteristics,
- kind="TRANSACTION",
- **{"global": global_},
- )
-
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
@@ -523,16 +468,3 @@ class MySQL(Dialect):
limit_offset = f"{offset}, {limit}" if offset else limit
return f" LIMIT {limit_offset}"
return ""
-
- def setitem_sql(self, expression):
- kind = self.sql(expression, "kind")
- kind = f"{kind} " if kind else ""
- this = self.sql(expression, "this")
- expressions = self.expressions(expression)
- collate = self.sql(expression, "collate")
- collate = f" COLLATE {collate}" if collate else ""
- global_ = "GLOBAL " if expression.args.get("global") else ""
- return f"{global_}{kind}{this}{expressions}{collate}"
-
- def set_sql(self, expression):
- return f"SET {self.expressions(expression)}"
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index d7cbac4..5f556a5 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
rename_func,
str_position_sql,
+ timestamptrunc_sql,
trim_sql,
)
from sqlglot.helper import seq_get
@@ -34,7 +35,7 @@ def _date_add_sql(kind):
from sqlglot.optimizer.simplify import simplify
this = self.sql(expression, "this")
- unit = self.sql(expression, "unit")
+ unit = expression.args.get("unit")
expression = simplify(expression.args["expression"])
if not isinstance(expression, exp.Literal):
@@ -92,8 +93,7 @@ def _string_agg_sql(self, expression):
this = expression.this
if isinstance(this, exp.Order):
if this.this:
- this = this.this
- this.pop()
+ this = this.this.pop()
order = self.sql(expression.this) # Order has a leading space
return f"STRING_AGG({self.format_args(this, separator)}{order})"
@@ -256,6 +256,9 @@ class Postgres(Dialect):
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
"GENERATE_SERIES": _generate_series,
+ "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+ this=seq_get(args, 1), unit=seq_get(args, 0)
+ ),
}
BITWISE = {
@@ -311,6 +314,7 @@ class Postgres(Dialect):
exp.DateSub: _date_add_sql("-"),
exp.DateDiff: _date_diff_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
+ exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Min: min_or_least,
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
@@ -320,6 +324,7 @@ class Postgres(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
+ exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
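Note: Postgres now parses DATE_TRUNC into exp.TimestampTrunc and generates it back through timestamptrunc_sql, so the call can be carried across dialects. A hedged sketch (placeholder names, expected output in the comment):

    import sqlglot

    print(sqlglot.transpile("SELECT DATE_TRUNC('month', created_at) FROM x", read="postgres", write="snowflake")[0])
    # expected roughly: SELECT DATE_TRUNC('month', created_at) FROM x
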
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index aef9de3..07e8f43 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -3,12 +3,14 @@ from __future__ import annotations
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ date_trunc_to_time,
format_time_lambda,
if_sql,
no_ilike_sql,
no_safe_divide_sql,
rename_func,
struct_extract_sql,
+ timestamptrunc_sql,
timestrtotime_sql,
)
from sqlglot.dialects.mysql import MySQL
@@ -98,10 +100,16 @@ def _ts_or_ds_to_date_sql(self, expression):
def _ts_or_ds_add_sql(self, expression):
- this = self.sql(expression, "this")
- e = self.sql(expression, "expression")
- unit = self.sql(expression, "unit") or "'day'"
- return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
+ return self.func(
+ "DATE_ADD",
+ exp.Literal.string(expression.text("unit") or "day"),
+ expression.expression,
+ self.func(
+ "DATE_PARSE",
+ self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)),
+ Presto.date_format,
+ ),
+ )
def _sequence_sql(self, expression):
@@ -195,6 +203,7 @@ class Presto(Dialect):
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
+ "DATE_TRUNC": date_trunc_to_time,
"FROM_UNIXTIME": _from_unixtime,
"NOW": exp.CurrentTimestamp.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
@@ -237,6 +246,7 @@ class Presto(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
+ **transforms.ELIMINATE_QUALIFY, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
@@ -250,8 +260,12 @@ class Presto(Dialect):
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DataType: _datatype_sql,
- exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
- exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
+ exp.DateAdd: lambda self, e: self.func(
+ "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+ ),
+ exp.DateDiff: lambda self, e: self.func(
+ "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this
+ ),
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
exp.Decode: _decode_sql,
@@ -265,6 +279,7 @@ class Presto(Dialect):
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
exp.LogicalOr: rename_func("BOOL_OR"),
+ exp.LogicalAnd: rename_func("BOOL_AND"),
exp.Quantile: _quantile_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
@@ -277,6 +292,7 @@ class Presto(Dialect):
exp.StructExtract: struct_extract_sql,
exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'",
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
+ exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: timestrtotime_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
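Note: Presto's DateAdd/DateDiff generation now goes through self.func with the unit rendered as a string literal (defaulting to 'day'), and DATE_TRUNC is parsed via date_trunc_to_time. A hedged round-trip sketch (placeholder names):

    import sqlglot

    for sql in (
        "SELECT DATE_ADD('week', 1, d) FROM t",
        "SELECT DATE_DIFF('day', a, b) FROM t",
        "SELECT DATE_TRUNC('hour', ts) FROM t",
    ):
        print(sqlglot.transpile(sql, read="presto", write="presto")[0])
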
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index dc881b9..ebd5216 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -20,6 +20,11 @@ class Redshift(Postgres):
class Parser(Postgres.Parser):
FUNCTIONS = {
**Postgres.Parser.FUNCTIONS, # type: ignore
+ "DATEADD": lambda args: exp.DateAdd(
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=seq_get(args, 0),
+ ),
"DATEDIFF": lambda args: exp.DateDiff(
this=seq_get(args, 2),
expression=seq_get(args, 1),
@@ -76,13 +81,16 @@ class Redshift(Postgres):
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
+ exp.DateAdd: lambda self, e: self.func(
+ "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
+ ),
exp.DateDiff: lambda self, e: self.func(
- "DATEDIFF", e.args.get("unit") or "day", e.expression, e.this
+ "DATEDIFF", exp.var(e.text("unit") or "day"), e.expression, e.this
),
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
- exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.Matches: rename_func("DECODE"),
+ exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
}
# Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
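Note: DATEADD is now parsed (unit first, amount second, target last) and both DateAdd and DateDiff are generated with the unit as a bare keyword via exp.var, matching Redshift syntax. A hedged round-trip sketch (placeholder names):

    import sqlglot

    print(sqlglot.transpile("SELECT DATEADD(month, 18, caldate) FROM dates", read="redshift", write="redshift")[0])
    # expected roughly: SELECT DATEADD(month, 18, caldate) FROM dates
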
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 9b159a4..799e9a6 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -5,11 +5,13 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
+ date_trunc_to_time,
datestrtodate_sql,
format_time_lambda,
inline_array_sql,
min_or_least,
rename_func,
+ timestamptrunc_sql,
timestrtotime_sql,
ts_or_ds_to_date_sql,
var_map_sql,
@@ -176,6 +178,7 @@ class Snowflake(Dialect):
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
+ "DATE_TRUNC": date_trunc_to_time,
"DATEADD": lambda args: exp.DateAdd(
this=seq_get(args, 2),
expression=seq_get(args, 1),
@@ -186,10 +189,6 @@ class Snowflake(Dialect):
expression=seq_get(args, 1),
unit=seq_get(args, 0),
),
- "DATE_TRUNC": lambda args: exp.DateTrunc(
- unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore
- this=seq_get(args, 1),
- ),
"DECODE": exp.Matches.from_arg_list,
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
@@ -280,6 +279,8 @@ class Snowflake(Dialect):
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
+ exp.LogicalOr: rename_func("BOOLOR_AGG"),
+ exp.LogicalAnd: rename_func("BOOLAND_AGG"),
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
@@ -287,6 +288,7 @@ class Snowflake(Dialect):
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
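Note: Snowflake's bespoke DATE_TRUNC lambda is replaced by the shared date_trunc_to_time helper, and exp.LogicalOr/exp.LogicalAnd now map to BOOLOR_AGG/BOOLAND_AGG. A hedged sketch of the latter, assuming the default reader registers LOGICAL_AND/LOGICAL_OR from the expression class names (placeholder names):

    import sqlglot

    print(sqlglot.transpile("SELECT LOGICAL_AND(b), LOGICAL_OR(b) FROM t", write="snowflake")[0])
    # expected roughly: SELECT BOOLAND_AGG(b), BOOLOR_AGG(b) FROM t
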
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 05ee53f..c271f6f 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -157,6 +157,7 @@ class Spark(Hive):
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
exp.LogicalOr: rename_func("BOOL_OR"),
+ exp.LogicalAnd: rename_func("BOOL_AND"),
exp.DayOfWeek: rename_func("DAYOFWEEK"),
exp.DayOfMonth: rename_func("DAYOFMONTH"),
exp.DayOfYear: rename_func("DAYOFYEAR"),
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index ed7c741..ab78b6e 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -1,10 +1,11 @@
from __future__ import annotations
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
+ count_if_to_sum,
no_ilike_sql,
no_tablesample_sql,
no_trycast_sql,
@@ -13,23 +14,6 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
-# https://www.sqlite.org/lang_aggfunc.html#group_concat
-def _group_concat_sql(self, expression):
- this = expression.this
- distinct = expression.find(exp.Distinct)
- if distinct:
- this = distinct.expressions[0]
- distinct = "DISTINCT "
-
- if isinstance(expression.this, exp.Order):
- self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
- if expression.this.this and not distinct:
- this = expression.this.this
-
- separator = expression.args.get("separator")
- return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
-
-
def _date_add_sql(self, expression):
modifier = expression.expression
modifier = expression.name if modifier.is_string else self.sql(modifier)
@@ -78,20 +62,32 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
+ **transforms.ELIMINATE_QUALIFY, # type: ignore
+ exp.CountIf: count_if_to_sum,
+ exp.CurrentDate: lambda *_: "CURRENT_DATE",
+ exp.CurrentTime: lambda *_: "CURRENT_TIME",
+ exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql,
+ exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.Levenshtein: rename_func("EDITDIST3"),
+ exp.LogicalOr: rename_func("MAX"),
+ exp.LogicalAnd: rename_func("MIN"),
exp.TableSample: no_tablesample_sql,
- exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
exp.TryCast: no_trycast_sql,
- exp.GroupConcat: _group_concat_sql,
}
+ def cast_sql(self, expression: exp.Cast) -> str:
+ if expression.to.this == exp.DataType.Type.DATE:
+ return self.func("DATE", expression.this)
+
+ return super().cast_sql(expression)
+
def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = expression.args.get("unit")
unit = unit.name.upper() if unit else "DAY"
@@ -119,16 +115,32 @@ class SQLite(Dialect):
return f"CAST({sql} AS INTEGER)"
- def fetch_sql(self, expression):
+ def fetch_sql(self, expression: exp.Fetch) -> str:
return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
- def least_sql(self, expression):
+ # https://www.sqlite.org/lang_aggfunc.html#group_concat
+ def groupconcat_sql(self, expression):
+ this = expression.this
+ distinct = expression.find(exp.Distinct)
+ if distinct:
+ this = distinct.expressions[0]
+ distinct = "DISTINCT "
+
+ if isinstance(expression.this, exp.Order):
+ self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.")
+ if expression.this.this and not distinct:
+ this = expression.this.this
+
+ separator = expression.args.get("separator")
+ return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
+
+ def least_sql(self, expression: exp.Least) -> str:
if len(expression.expressions) > 1:
return rename_func("MIN")(self, expression)
return self.expressions(expression)
- def transaction_sql(self, expression):
+ def transaction_sql(self, expression: exp.Transaction) -> str:
this = expression.this
this = f" {this}" if this else ""
return f"BEGIN{this} TRANSACTION"
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 01e6357..2ba1a92 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -3,9 +3,18 @@ from __future__ import annotations
from sqlglot import exp
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.mysql import MySQL
+from sqlglot.helper import seq_get
class StarRocks(MySQL):
+ class Parser(MySQL.Parser): # type: ignore
+ FUNCTIONS = {
+ **MySQL.Parser.FUNCTIONS,
+ "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+ this=seq_get(args, 1), unit=seq_get(args, 0)
+ ),
+ }
+
class Generator(MySQL.Generator): # type: ignore
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING, # type: ignore
@@ -20,6 +29,9 @@ class StarRocks(MySQL):
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),
exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimestampTrunc: lambda self, e: self.func(
+ "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
+ ),
exp.TimeStrToDate: rename_func("TO_DATE"),
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
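Note: StarRocks gains a DATE_TRUNC parse into exp.TimestampTrunc plus the matching generation, so the call round-trips. A hedged sketch (placeholder names):

    import sqlglot

    print(sqlglot.transpile("SELECT DATE_TRUNC('month', ts) FROM t", read="starrocks", write="starrocks")[0])
    # expected roughly: SELECT DATE_TRUNC('month', ts) FROM t
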
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 371e888..7b52047 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -117,14 +117,12 @@ def _string_agg_sql(self, e):
if distinct:
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
- this = distinct.expressions[0]
- distinct.pop()
+ this = distinct.pop().expressions[0]
order = ""
if isinstance(e.this, exp.Order):
if e.this.this:
- this = e.this.this
- e.this.this.pop()
+ this = e.this.this.pop()
order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
separator = e.args.get("separator") or exp.Literal.string(",")