diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 22 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 14 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 7 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 8 |
11 files changed, 59 insertions, 14 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index a3869c6..ccdd1c9 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( Dialect, datestrtodate_sql, inline_array_sql, + min_or_least, no_ilike_sql, rename_func, timestrtotime_sql, @@ -232,6 +233,7 @@ class BigQuery(Dialect): exp.GroupConcat: rename_func("STRING_AGG"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), + exp.Min: min_or_least, exp.Select: transforms.preprocess( [_unqualify_unnest], transforms.delegate("select_sql") ), diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 6939705..25490cb 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -407,6 +407,11 @@ def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: return f"CAST({self.sql(expression, 'this')} AS DATE)" +def min_or_least(self: Generator, expression: exp.Min) -> str: + name = "LEAST" if expression.expressions else "MIN" + return rename_func(name)(self, expression) + + def trim_sql(self: Generator, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index c2755cd..db79d86 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import ( no_pivot_sql, no_properties_sql, no_safe_divide_sql, - no_tablesample_sql, rename_func, str_position_sql, str_to_time_sql, @@ -155,7 +154,6 @@ class DuckDB(Dialect): exp.StrToTime: str_to_time_sql, exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", exp.Struct: _struct_sql, - exp.TableSample: no_tablesample_sql, exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", @@ -179,3 +177,6 @@ class DuckDB(Dialect): **generator.Generator.STAR_MAPPING, "except": "EXCLUDE", } + + def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str: + return super().tablesample_sql(expression, seed_prefix="REPEATABLE") diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 44cd875..faed1cf 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, if_sql, locate_to_strposition, + min_or_least, no_ilike_sql, no_recursive_cte_sql, no_safe_divide_sql, @@ -291,6 +292,7 @@ class Hive(Dialect): exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.Map: var_map_sql, + exp.Min: min_or_least, exp.VarMap: var_map_sql, exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index b1e20bd..3531f59 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -4,6 +4,7 @@ from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, locate_to_strposition, + min_or_least, no_ilike_sql, no_paren_current_date_sql, no_tablesample_sql, @@ -179,7 +180,7 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} class Parser(parser.Parser): - FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} # type: ignore + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore @@ -441,6 +442,7 @@ class MySQL(Dialect): exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ILike: no_ilike_sql, + exp.Min: min_or_least, exp.TableSample: no_tablesample_sql, exp.TryCast: no_trycast_sql, exp.DateAdd: _date_add_sql("ADD"), diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 35076db..7e7902c 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_scalar_sql, arrow_json_extract_sql, format_time_lambda, + min_or_least, no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, @@ -229,6 +230,7 @@ class Postgres(Dialect): "REFRESH": TokenType.COMMAND, "REINDEX": TokenType.COMMAND, "RESET": TokenType.COMMAND, + "RETURNING": TokenType.RETURNING, "REVOKE": TokenType.COMMAND, "SERIAL": TokenType.SERIAL, "SMALLSERIAL": TokenType.SMALLSERIAL, @@ -296,6 +298,7 @@ class Postgres(Dialect): exp.DateSub: _date_add_sql("-"), exp.DateDiff: _date_diff_sql, exp.LogicalOr: rename_func("BOOL_OR"), + exp.Min: min_or_least, exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index b4268e6..22ef51c 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -53,6 +53,7 @@ class Redshift(Postgres): "SUPER": TokenType.SUPER, "TIME": TokenType.TIMESTAMP, "TIMETZ": TokenType.TIMESTAMPTZ, + "TOP": TokenType.TOP, "UNLOAD": TokenType.COMMAND, "VARBYTE": TokenType.VARBINARY, } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 4a090c2..6413f6d 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, inline_array_sql, + min_or_least, rename_func, timestrtotime_sql, ts_or_ds_to_date_sql, @@ -116,10 +117,16 @@ def _div0_to_if(args): # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull def _zeroifnull_to_if(args): - cond = exp.EQ(this=seq_get(args, 0), expression=exp.Null()) + cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) +# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull +def _nullifzero_to_if(args): + cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) + return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) + + def _datatype_sql(self, expression): if expression.this == exp.DataType.Type.ARRAY: return "ARRAY" @@ -167,6 +174,11 @@ class Snowflake(Dialect): **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, + "DATEADD": lambda args: exp.DateAdd( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), + ), "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore this=seq_get(args, 1), @@ -180,6 +192,7 @@ class Snowflake(Dialect): "DECODE": exp.Matches.from_arg_list, "OBJECT_CONSTRUCT": parser.parse_var_map, "ZEROIFNULL": _zeroifnull_to_if, + "NULLIFZERO": _nullifzero_to_if, } FUNCTION_PARSERS = { @@ -254,6 +267,7 @@ class Snowflake(Dialect): class Generator(generator.Generator): PARAMETER_TOKEN = "$" INTEGER_DIVISION = False + MATCHED_BY_SOURCE = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore @@ -278,6 +292,7 @@ class Snowflake(Dialect): exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.Min: min_or_least, } TYPE_MAPPING = { @@ -343,11 +358,10 @@ class Snowflake(Dialect): expression. This might not be true in a case where the same column name can be sourced from another table that can properly quote but should be true in most cases. """ - values_expressions = expression.find_all(exp.Values) values_identifiers = set( flatten( - v.args.get("alias", exp.Alias()).args.get("columns", []) - for v in values_expressions + (v.args.get("alias") or exp.Alias()).args.get("columns", []) + for v in expression.find_all(exp.Values) ) ) if values_identifiers: diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 86603b5..fb99d49 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -13,10 +13,6 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType -def _fetch_sql(self, expression): - return self.limit_sql(exp.Limit(expression=expression.args.get("count"))) - - # https://www.sqlite.org/lang_aggfunc.html#group_concat def _group_concat_sql(self, expression): this = expression.this @@ -94,9 +90,17 @@ class SQLite(Dialect): exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), exp.TryCast: no_trycast_sql, exp.GroupConcat: _group_concat_sql, - exp.Fetch: _fetch_sql, } + def fetch_sql(self, expression): + return self.limit_sql(exp.Limit(expression=expression.args.get("count"))) + + def least_sql(self, expression): + if len(expression.expressions) > 1: + return rename_func("MIN")(self, expression) + + return self.expressions(expression) + def transaction_sql(self, expression): this = expression.this this = f" {this}" if this else "" diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 415681c..06d5c6c 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,7 +1,7 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect +from sqlglot.dialects.dialect import Dialect, min_or_least from sqlglot.tokens import TokenType @@ -126,6 +126,11 @@ class Teradata(Dialect): exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX, } + TRANSFORMS = { + **generator.Generator.TRANSFORMS, + exp.Min: min_or_least, + } + def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: return f"PARTITION BY {self.sql(expression, 'this')}" diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b9f932b..d07b083 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -4,7 +4,12 @@ import re import typing as t from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func +from sqlglot.dialects.dialect import ( + Dialect, + min_or_least, + parse_date_delta, + rename_func, +) from sqlglot.expressions import DataType from sqlglot.helper import seq_get from sqlglot.time import format_time @@ -433,6 +438,7 @@ class TSQL(Dialect): exp.NumberToStr: _format_sql, exp.TimeToStr: _format_sql, exp.GroupConcat: _string_agg_sql, + exp.Min: min_or_least, } TRANSFORMS.pop(exp.ReturnsProperty) |