Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--   sqlglot/dialects/bigquery.py   |  2
-rw-r--r--   sqlglot/dialects/clickhouse.py | 40
-rw-r--r--   sqlglot/dialects/databricks.py |  1
-rw-r--r--   sqlglot/dialects/dialect.py    | 34
-rw-r--r--   sqlglot/dialects/drill.py      |  1
-rw-r--r--   sqlglot/dialects/duckdb.py     |  7
-rw-r--r--   sqlglot/dialects/hive.py       | 43
-rw-r--r--   sqlglot/dialects/mysql.py      | 72
-rw-r--r--   sqlglot/dialects/postgres.py   | 11
-rw-r--r--   sqlglot/dialects/presto.py     | 28
-rw-r--r--   sqlglot/dialects/redshift.py   | 12
-rw-r--r--   sqlglot/dialects/snowflake.py  | 10
-rw-r--r--   sqlglot/dialects/spark.py      |  1
-rw-r--r--   sqlglot/dialects/sqlite.py     | 58
-rw-r--r--   sqlglot/dialects/starrocks.py  | 12
-rw-r--r--   sqlglot/dialects/tsql.py       |  6
16 files changed, 200 insertions(+), 138 deletions(-)
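
Several of the files below converge on one DATE_TRUNC treatment: dialect.py gains a date_trunc_to_time parser helper and a timestamptrunc_sql generator helper, and DuckDB, Postgres, Presto, Snowflake and StarRocks pick up TimestampTrunc parsing and/or generation. A minimal sketch of the intended round-trip, assuming the helpers behave as written in the patch; the column name and the exact output text are illustrative, not taken from the patch:

import sqlglot

# DATE_TRUNC(unit, expr) now parses into exp.TimestampTrunc (or exp.DateTrunc
# when the argument is an explicit CAST(... AS DATE)), and timestamptrunc_sql
# renders it back as DATE_TRUNC('<unit>', expr) for the target dialect.
print(
    sqlglot.transpile(
        "SELECT DATE_TRUNC('month', created_at)",
        read="snowflake",
        write="postgres",
    )[0]
)
# expected, roughly: SELECT DATE_TRUNC('month', created_at)
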
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 0c2105b..6a43846 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -144,7 +144,6 @@ class BigQuery(Dialect): "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, - "CURRENT_TIME": TokenType.CURRENT_TIME, "DECLARE": TokenType.COMMAND, "GEOGRAPHY": TokenType.GEOGRAPHY, "FLOAT64": TokenType.DOUBLE, @@ -194,7 +193,6 @@ class BigQuery(Dialect): NO_PAREN_FUNCTIONS = { **parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore TokenType.CURRENT_DATETIME: exp.CurrentDatetime, - TokenType.CURRENT_TIME: exp.CurrentTime, } NESTED_TYPE_TOKENS = { diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index b553df2..b54a77d 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql from sqlglot.errors import ParseError +from sqlglot.helper import ensure_list, seq_get from sqlglot.parser import parse_var_map from sqlglot.tokens import TokenType @@ -40,7 +41,18 @@ class ClickHouse(Dialect): class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore + "EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg( + this=seq_get(args, 0), + time=seq_get(args, 1), + decay=seq_get(params, 0), + ), "MAP": parse_var_map, + "HISTOGRAM": lambda params, args: exp.Histogram( + this=seq_get(args, 0), bins=seq_get(params, 0) + ), + "GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray( + this=seq_get(args, 0), size=seq_get(params, 0) + ), "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params), "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args), "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args), @@ -113,22 +125,40 @@ class ClickHouse(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, - exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", + exp.ExponentialTimeDecayedAvg: lambda self, e: f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}", exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", + exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}", + exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}", exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), - exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}", exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}", exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}", + exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", + exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), } EXPLICIT_UNION = True def _param_args_sql( - self, expression: exp.Expression, params_name: str, args_name: str + self, + expression: exp.Expression, + param_names: str | t.List[str], + arg_names: str | t.List[str], ) -> str: - params = self.format_args(self.expressions(expression, params_name)) - args = 
self.format_args(self.expressions(expression, args_name)) + params = self.format_args( + *( + arg + for name in ensure_list(param_names) + for arg in ensure_list(expression.args.get(name)) + ) + ) + args = self.format_args( + *( + arg + for name in ensure_list(arg_names) + for arg in ensure_list(expression.args.get(name)) + ) + ) return f"({params})({args})" def cte_sql(self, expression: exp.CTE) -> str: diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 4ff3594..4268f1b 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -23,6 +23,7 @@ class Databricks(Spark): exp.DateDiff: generate_date_delta_with_unit_sql, exp.ToChar: lambda self, e: self.function_fallback_sql(e), } + TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation PARAMETER_TOKEN = "$" diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 25490cb..b267521 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -8,7 +8,7 @@ from sqlglot.generator import Generator from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser from sqlglot.time import format_time -from sqlglot.tokens import Tokenizer +from sqlglot.tokens import Token, Tokenizer from sqlglot.trie import new_trie E = t.TypeVar("E", bound=exp.Expression) @@ -160,12 +160,12 @@ class Dialect(metaclass=_Dialect): return expression def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: - return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql) + return self.parser(**opts).parse(self.tokenize(sql), sql) def parse_into( self, expression_type: exp.IntoType, sql: str, **opts ) -> t.List[t.Optional[exp.Expression]]: - return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql) + return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: return self.generator(**opts).generate(expression) @@ -173,6 +173,9 @@ class Dialect(metaclass=_Dialect): def transpile(self, sql: str, **opts) -> t.List[str]: return [self.generate(expression, **opts) for expression in self.parse(sql)] + def tokenize(self, sql: str) -> t.List[Token]: + return self.tokenizer.tokenize(sql) + @property def tokenizer(self) -> Tokenizer: if not hasattr(self, "_tokenizer"): @@ -385,6 +388,21 @@ def parse_date_delta( return inner_func +def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc: + unit = seq_get(args, 0) + this = seq_get(args, 1) + + if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE): + return exp.DateTrunc(unit=unit, this=this) + return exp.TimestampTrunc(this=this, unit=unit) + + +def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: + return self.func( + "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this + ) + + def locate_to_strposition(args: t.Sequence) -> exp.Expression: return exp.StrPosition( this=seq_get(args, 1), @@ -412,6 +430,16 @@ def min_or_least(self: Generator, expression: exp.Min) -> str: return rename_func(name)(self, expression) +def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: + cond = expression.this + + if isinstance(expression.this, exp.Distinct): + cond = expression.this.expressions[0] + self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") + + return self.func("sum", exp.func("if", cond, 1, 0)) + + def trim_sql(self: Generator, expression: 
exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 208e2ab..dc0e519 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -97,6 +97,7 @@ class Drill(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore + "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"), "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), } diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 43f538c..f1d2266 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import ( rename_func, str_position_sql, str_to_time_sql, + timestamptrunc_sql, timestrtotime_sql, ts_or_ds_to_date_sql, ) @@ -148,6 +149,9 @@ class DuckDB(Dialect): exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), + exp.DayOfMonth: rename_func("DAYOFMONTH"), + exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.DayOfYear: rename_func("DAYOFYEAR"), exp.DataType: _datatype_sql, exp.DateAdd: _date_add, exp.DateDiff: lambda self, e: self.func( @@ -162,6 +166,7 @@ class DuckDB(Dialect): exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.LogicalOr: rename_func("BOOL_OR"), + exp.LogicalAnd: rename_func("BOOL_AND"), exp.Pivot: no_pivot_sql, exp.Properties: no_properties_sql, exp.RegexpExtract: _regexp_extract_sql, @@ -175,6 +180,7 @@ class DuckDB(Dialect): exp.StrToTime: str_to_time_sql, exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", exp.Struct: _struct_sql, + exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", @@ -186,6 +192,7 @@ class DuckDB(Dialect): exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: rename_func("TO_TIMESTAMP"), exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", + exp.WeekOfYear: rename_func("WEEKOFYEAR"), } TYPE_MAPPING = { diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index c4b8fa9..0110eee 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -35,7 +37,7 @@ DATE_DELTA_INTERVAL = { DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") -def _add_date_sql(self, expression): +def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str: unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) modified_increment = ( @@ -47,7 +49,7 @@ def _add_date_sql(self, expression): return self.func(func, expression.this, modified_increment.this) -def _date_diff_sql(self, expression): +def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF" _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) @@ -56,21 +58,21 @@ def _date_diff_sql(self, expression): return f"{diff_sql}{multiplier_sql}" -def _array_sort(self, 
expression): +def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("Hive SORT_ARRAY does not support a comparator") return f"SORT_ARRAY({self.sql(expression, 'this')})" -def _property_sql(self, expression): +def _property_sql(self: generator.Generator, expression: exp.Property) -> str: return f"'{expression.name}'={self.sql(expression, 'value')}" -def _str_to_unix(self, expression): +def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str: return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression)) -def _str_to_date(self, expression): +def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.time_format, Hive.date_format): @@ -78,7 +80,7 @@ def _str_to_date(self, expression): return f"CAST({this} AS DATE)" -def _str_to_time(self, expression): +def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.time_format, Hive.date_format): @@ -86,20 +88,22 @@ def _str_to_time(self, expression): return f"CAST({this} AS TIMESTAMP)" -def _time_format(self, expression): +def _time_format( + self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix +) -> t.Optional[str]: time_format = self.format_time(expression) if time_format == Hive.time_format: return None return time_format -def _time_to_str(self, expression): +def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) return f"DATE_FORMAT({this}, {time_format})" -def _to_date_sql(self, expression): +def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format and time_format not in (Hive.time_format, Hive.date_format): @@ -107,7 +111,7 @@ def _to_date_sql(self, expression): return f"TO_DATE({this})" -def _unnest_to_explode_sql(self, expression): +def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str: unnest = expression.this if isinstance(unnest, exp.Unnest): alias = unnest.args.get("alias") @@ -117,7 +121,7 @@ def _unnest_to_explode_sql(self, expression): exp.Lateral( this=udtf(this=expression), view=True, - alias=exp.TableAlias(this=alias.this, columns=[column]), + alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore ) ) for expression, column in zip(unnest.expressions, alias.columns if alias else []) @@ -125,7 +129,7 @@ def _unnest_to_explode_sql(self, expression): return self.join_sql(expression) -def _index_sql(self, expression): +def _index_sql(self: generator.Generator, expression: exp.Index) -> str: this = self.sql(expression, "this") table = self.sql(expression, "table") columns = self.sql(expression, "columns") @@ -263,14 +267,15 @@ class Hive(Dialect): exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.DATETIME: "TIMESTAMP", exp.DataType.Type.VARBINARY: "BINARY", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", } TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore + **transforms.ELIMINATE_QUALIFY, # type: ignore exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, - exp.ArrayAgg: rename_func("COLLECT_LIST"), exp.ArrayConcat: 
rename_func("CONCAT"), exp.ArraySize: rename_func("SIZE"), exp.ArraySort: _array_sort, @@ -333,13 +338,19 @@ class Hive(Dialect): exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, } - def with_properties(self, properties): + def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: + return self.func( + "COLLECT_LIST", + expression.this.this if isinstance(expression.this, exp.Order) else expression.this, + ) + + def with_properties(self, properties: exp.Properties) -> str: return self.properties( properties, prefix=self.seg("TBLPROPERTIES"), ) - def datatype_sql(self, expression): + def datatype_sql(self, expression: exp.DataType) -> str: if ( expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR) and not expression.expressions diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index a831235..1e2cfa3 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -177,7 +177,7 @@ class MySQL(Dialect): "@@": TokenType.SESSION_PARAMETER, } - COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} + COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} class Parser(parser.Parser): FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore @@ -211,7 +211,6 @@ class MySQL(Dialect): STATEMENT_PARSERS = { **parser.Parser.STATEMENT_PARSERS, # type: ignore TokenType.SHOW: lambda self: self._parse_show(), - TokenType.SET: lambda self: self._parse_set(), } SHOW_PARSERS = { @@ -269,15 +268,12 @@ class MySQL(Dialect): } SET_PARSERS = { - "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"), + **parser.Parser.SET_PARSERS, "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"), "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"), - "SESSION": lambda self: self._parse_set_item_assignment("SESSION"), - "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"), "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"), "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"), "NAMES": lambda self: self._parse_set_item_names(), - "TRANSACTION": lambda self: self._parse_set_transaction(), } PROFILE_TYPES = { @@ -292,15 +288,6 @@ class MySQL(Dialect): "SWAPS", } - TRANSACTION_CHARACTERISTICS = { - "ISOLATION LEVEL REPEATABLE READ", - "ISOLATION LEVEL READ COMMITTED", - "ISOLATION LEVEL READ UNCOMMITTED", - "ISOLATION LEVEL SERIALIZABLE", - "READ WRITE", - "READ ONLY", - } - def _parse_show_mysql(self, this, target=False, full=None, global_=None): if target: if isinstance(target, str): @@ -354,12 +341,6 @@ class MySQL(Dialect): **{"global": global_}, ) - def _parse_var_from_options(self, options): - for option in options: - if self._match_text_seq(*option.split(" ")): - return exp.Var(this=option) - return None - def _parse_oldstyle_limit(self): limit = None offset = None @@ -372,30 +353,6 @@ class MySQL(Dialect): offset = parts[0] return offset, limit - def _default_parse_set_item(self): - return self._parse_set_item_assignment(kind=None) - - def _parse_set_item_assignment(self, kind): - if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"): - return self._parse_set_transaction(global_=kind == "GLOBAL") - - left = self._parse_primary() or self._parse_id_var() - if not self._match(TokenType.EQ): - self.raise_error("Expected =") - right = self._parse_statement() or self._parse_id_var() - - this = self.expression( - exp.EQ, - this=left, - expression=right, - ) - - return self.expression( - exp.SetItem, 
- this=this, - kind=kind, - ) - def _parse_set_item_charset(self, kind): this = self._parse_string() or self._parse_id_var() @@ -418,18 +375,6 @@ class MySQL(Dialect): kind="NAMES", ) - def _parse_set_transaction(self, global_=False): - self._match_text_seq("TRANSACTION") - characteristics = self._parse_csv( - lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS) - ) - return self.expression( - exp.SetItem, - expressions=characteristics, - kind="TRANSACTION", - **{"global": global_}, - ) - class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False @@ -523,16 +468,3 @@ class MySQL(Dialect): limit_offset = f"{offset}, {limit}" if offset else limit return f" LIMIT {limit_offset}" return "" - - def setitem_sql(self, expression): - kind = self.sql(expression, "kind") - kind = f"{kind} " if kind else "" - this = self.sql(expression, "this") - expressions = self.expressions(expression) - collate = self.sql(expression, "collate") - collate = f" COLLATE {collate}" if collate else "" - global_ = "GLOBAL " if expression.args.get("global") else "" - return f"{global_}{kind}{this}{expressions}{collate}" - - def set_sql(self, expression): - return f"SET {self.expressions(expression)}" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index d7cbac4..5f556a5 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, rename_func, str_position_sql, + timestamptrunc_sql, trim_sql, ) from sqlglot.helper import seq_get @@ -34,7 +35,7 @@ def _date_add_sql(kind): from sqlglot.optimizer.simplify import simplify this = self.sql(expression, "this") - unit = self.sql(expression, "unit") + unit = expression.args.get("unit") expression = simplify(expression.args["expression"]) if not isinstance(expression, exp.Literal): @@ -92,8 +93,7 @@ def _string_agg_sql(self, expression): this = expression.this if isinstance(this, exp.Order): if this.this: - this = this.this - this.pop() + this = this.this.pop() order = self.sql(expression.this) # Order has a leading space return f"STRING_AGG({self.format_args(this, separator)}{order})" @@ -256,6 +256,9 @@ class Postgres(Dialect): "TO_TIMESTAMP": _to_timestamp, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), "GENERATE_SERIES": _generate_series, + "DATE_TRUNC": lambda args: exp.TimestampTrunc( + this=seq_get(args, 1), unit=seq_get(args, 0) + ), } BITWISE = { @@ -311,6 +314,7 @@ class Postgres(Dialect): exp.DateSub: _date_add_sql("-"), exp.DateDiff: _date_diff_sql, exp.LogicalOr: rename_func("BOOL_OR"), + exp.LogicalAnd: rename_func("BOOL_AND"), exp.Min: min_or_least, exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), @@ -320,6 +324,7 @@ class Postgres(Dialect): exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Substring: _substring_sql, + exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TableSample: no_tablesample_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index aef9de3..07e8f43 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -3,12 +3,14 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens, transforms 
from sqlglot.dialects.dialect import ( Dialect, + date_trunc_to_time, format_time_lambda, if_sql, no_ilike_sql, no_safe_divide_sql, rename_func, struct_extract_sql, + timestamptrunc_sql, timestrtotime_sql, ) from sqlglot.dialects.mysql import MySQL @@ -98,10 +100,16 @@ def _ts_or_ds_to_date_sql(self, expression): def _ts_or_ds_add_sql(self, expression): - this = self.sql(expression, "this") - e = self.sql(expression, "expression") - unit = self.sql(expression, "unit") or "'day'" - return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))" + return self.func( + "DATE_ADD", + exp.Literal.string(expression.text("unit") or "day"), + expression.expression, + self.func( + "DATE_PARSE", + self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)), + Presto.date_format, + ), + ) def _sequence_sql(self, expression): @@ -195,6 +203,7 @@ class Presto(Dialect): ), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), + "DATE_TRUNC": date_trunc_to_time, "FROM_UNIXTIME": _from_unixtime, "NOW": exp.CurrentTimestamp.from_arg_list, "STRPOS": lambda args: exp.StrPosition( @@ -237,6 +246,7 @@ class Presto(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore + **transforms.ELIMINATE_QUALIFY, # type: ignore exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), @@ -250,8 +260,12 @@ class Presto(Dialect): exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DataType: _datatype_sql, - exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", - exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateAdd: lambda self, e: self.func( + "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + ), + exp.DateDiff: lambda self, e: self.func( + "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + ), exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)", exp.Decode: _decode_sql, @@ -265,6 +279,7 @@ class Presto(Dialect): exp.Lateral: _explode_to_unnest_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), exp.LogicalOr: rename_func("BOOL_OR"), + exp.LogicalAnd: rename_func("BOOL_AND"), exp.Quantile: _quantile_sql, exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, @@ -277,6 +292,7 @@ class Presto(Dialect): exp.StructExtract: struct_extract_sql, exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'", exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", + exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: timestrtotime_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index dc881b9..ebd5216 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -20,6 +20,11 @@ class Redshift(Postgres): class 
Parser(Postgres.Parser): FUNCTIONS = { **Postgres.Parser.FUNCTIONS, # type: ignore + "DATEADD": lambda args: exp.DateAdd( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), + ), "DATEDIFF": lambda args: exp.DateDiff( this=seq_get(args, 2), expression=seq_get(args, 1), @@ -76,13 +81,16 @@ class Redshift(Postgres): TRANSFORMS = { **Postgres.Generator.TRANSFORMS, # type: ignore **transforms.ELIMINATE_DISTINCT_ON, # type: ignore + exp.DateAdd: lambda self, e: self.func( + "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this + ), exp.DateDiff: lambda self, e: self.func( - "DATEDIFF", e.args.get("unit") or "day", e.expression, e.this + "DATEDIFF", exp.var(e.text("unit") or "day"), e.expression, e.this ), exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", - exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), exp.Matches: rename_func("DECODE"), + exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", } # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 9b159a4..799e9a6 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -5,11 +5,13 @@ import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + date_trunc_to_time, datestrtodate_sql, format_time_lambda, inline_array_sql, min_or_least, rename_func, + timestamptrunc_sql, timestrtotime_sql, ts_or_ds_to_date_sql, var_map_sql, @@ -176,6 +178,7 @@ class Snowflake(Dialect): "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, + "DATE_TRUNC": date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), expression=seq_get(args, 1), @@ -186,10 +189,6 @@ class Snowflake(Dialect): expression=seq_get(args, 1), unit=seq_get(args, 0), ), - "DATE_TRUNC": lambda args: exp.DateTrunc( - unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore - this=seq_get(args, 1), - ), "DECODE": exp.Matches.from_arg_list, "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, @@ -280,6 +279,8 @@ class Snowflake(Dialect): exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.LogicalOr: rename_func("BOOLOR_AGG"), + exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Matches: rename_func("DECODE"), @@ -287,6 +288,7 @@ class Snowflake(Dialect): "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 05ee53f..c271f6f 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -157,6 +157,7 @@ class Spark(Hive): exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), 
exp.LogicalOr: rename_func("BOOL_OR"), + exp.LogicalAnd: rename_func("BOOL_AND"), exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfMonth: rename_func("DAYOFMONTH"), exp.DayOfYear: rename_func("DAYOFYEAR"), diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index ed7c741..ab78b6e 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -1,10 +1,11 @@ from __future__ import annotations -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, arrow_json_extract_sql, + count_if_to_sum, no_ilike_sql, no_tablesample_sql, no_trycast_sql, @@ -13,23 +14,6 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType -# https://www.sqlite.org/lang_aggfunc.html#group_concat -def _group_concat_sql(self, expression): - this = expression.this - distinct = expression.find(exp.Distinct) - if distinct: - this = distinct.expressions[0] - distinct = "DISTINCT " - - if isinstance(expression.this, exp.Order): - self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.") - if expression.this.this and not distinct: - this = expression.this.this - - separator = expression.args.get("separator") - return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})" - - def _date_add_sql(self, expression): modifier = expression.expression modifier = expression.name if modifier.is_string else self.sql(modifier) @@ -78,20 +62,32 @@ class SQLite(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore + **transforms.ELIMINATE_QUALIFY, # type: ignore + exp.CountIf: count_if_to_sum, + exp.CurrentDate: lambda *_: "CURRENT_DATE", + exp.CurrentTime: lambda *_: "CURRENT_TIME", + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql, + exp.DateStrToDate: lambda self, e: self.sql(e, "this"), exp.ILike: no_ilike_sql, exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.Levenshtein: rename_func("EDITDIST3"), + exp.LogicalOr: rename_func("MAX"), + exp.LogicalAnd: rename_func("MIN"), exp.TableSample: no_tablesample_sql, - exp.DateStrToDate: lambda self, e: self.sql(e, "this"), exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), exp.TryCast: no_trycast_sql, - exp.GroupConcat: _group_concat_sql, } + def cast_sql(self, expression: exp.Cast) -> str: + if expression.to.this == exp.DataType.Type.DATE: + return self.func("DATE", expression.this) + + return super().cast_sql(expression) + def datediff_sql(self, expression: exp.DateDiff) -> str: unit = expression.args.get("unit") unit = unit.name.upper() if unit else "DAY" @@ -119,16 +115,32 @@ class SQLite(Dialect): return f"CAST({sql} AS INTEGER)" - def fetch_sql(self, expression): + def fetch_sql(self, expression: exp.Fetch) -> str: return self.limit_sql(exp.Limit(expression=expression.args.get("count"))) - def least_sql(self, expression): + # https://www.sqlite.org/lang_aggfunc.html#group_concat + def groupconcat_sql(self, expression): + this = expression.this + distinct = expression.find(exp.Distinct) + if distinct: + this = distinct.expressions[0] + distinct = "DISTINCT " + + if isinstance(expression.this, exp.Order): + self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.") + if expression.this.this and not distinct: + this = expression.this.this + + separator = 
expression.args.get("separator") + return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})" + + def least_sql(self, expression: exp.Least) -> str: if len(expression.expressions) > 1: return rename_func("MIN")(self, expression) return self.expressions(expression) - def transaction_sql(self, expression): + def transaction_sql(self, expression: exp.Transaction) -> str: this = expression.this this = f" {this}" if this else "" return f"BEGIN{this} TRANSACTION" diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 01e6357..2ba1a92 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -3,9 +3,18 @@ from __future__ import annotations from sqlglot import exp from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func from sqlglot.dialects.mysql import MySQL +from sqlglot.helper import seq_get class StarRocks(MySQL): + class Parser(MySQL.Parser): # type: ignore + FUNCTIONS = { + **MySQL.Parser.FUNCTIONS, + "DATE_TRUNC": lambda args: exp.TimestampTrunc( + this=seq_get(args, 1), unit=seq_get(args, 0) + ), + } + class Generator(MySQL.Generator): # type: ignore TYPE_MAPPING = { **MySQL.Generator.TYPE_MAPPING, # type: ignore @@ -20,6 +29,9 @@ class StarRocks(MySQL): exp.JSONExtract: arrow_json_extract_sql, exp.DateDiff: rename_func("DATEDIFF"), exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimestampTrunc: lambda self, e: self.func( + "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this + ), exp.TimeStrToDate: rename_func("TO_DATE"), exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 371e888..7b52047 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -117,14 +117,12 @@ def _string_agg_sql(self, e): if distinct: # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.") - this = distinct.expressions[0] - distinct.pop() + this = distinct.pop().expressions[0] order = "" if isinstance(e.this, exp.Order): if e.this.this: - this = e.this.this - e.this.this.pop() + this = e.this.this.pop() order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space separator = e.args.get("separator") or exp.Literal.string(",") |