diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 61 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 7 | ||||
-rw-r--r-- | sqlglot/dialects/databricks.py | 8 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 37 | ||||
-rw-r--r-- | sqlglot/dialects/drill.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 64 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 63 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 32 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 36 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 85 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 17 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 59 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 54 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 10 | ||||
-rw-r--r-- | sqlglot/dialects/starrocks.py | 8 | ||||
-rw-r--r-- | sqlglot/dialects/tableau.py | 8 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 21 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 37 |
19 files changed, 439 insertions, 183 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 701377b..1a88654 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import ( max_or_greatest, min_or_least, no_ilike_sql, + parse_date_delta_with_interval, rename_func, timestrtotime_sql, ts_or_ds_to_date_sql, @@ -23,18 +24,6 @@ from sqlglot.tokens import TokenType E = t.TypeVar("E", bound=exp.Expression) -def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]: - def func(args): - interval = seq_get(args, 1) - return expression_class( - this=seq_get(args, 0), - expression=interval.this, - unit=interval.args.get("unit"), - ) - - return func - - def _date_add_sql( data_type: str, kind: str ) -> t.Callable[[generator.Generator, exp.Expression], str]: @@ -142,6 +131,7 @@ class BigQuery(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "ANY TYPE": TokenType.VARIANT, "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, @@ -155,14 +145,19 @@ class BigQuery(Dialect): KEYWORDS.pop("DIV") class Parser(parser.Parser): + PREFIXED_PIVOT_COLUMNS = True + + LOG_BASE_FIRST = False + LOG_DEFAULTS_TO_LN = True + FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore this=seq_get(args, 0), ), - "DATE_ADD": _date_add(exp.DateAdd), - "DATETIME_ADD": _date_add(exp.DatetimeAdd), + "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), + "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd), "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)), "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( @@ -174,12 +169,12 @@ class BigQuery(Dialect): if re.compile(str(seq_get(args, 1))).groups == 1 else None, ), - "TIME_ADD": _date_add(exp.TimeAdd), - "TIMESTAMP_ADD": _date_add(exp.TimestampAdd), - "DATE_SUB": _date_add(exp.DateSub), - "DATETIME_SUB": _date_add(exp.DatetimeSub), - "TIME_SUB": _date_add(exp.TimeSub), - "TIMESTAMP_SUB": _date_add(exp.TimestampSub), + "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd), + "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd), + "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), + "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub), + "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub), + "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub), "PARSE_TIMESTAMP": lambda args: exp.StrToTime( this=seq_get(args, 1), format=seq_get(args, 0) ), @@ -209,14 +204,17 @@ class BigQuery(Dialect): PROPERTY_PARSERS = { **parser.Parser.PROPERTY_PARSERS, # type: ignore "NOT DETERMINISTIC": lambda self: self.expression( - exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") + exp.StabilityProperty, this=exp.Literal.string("VOLATILE") ), } - LOG_BASE_FIRST = False - LOG_DEFAULTS_TO_LN = True - class Generator(generator.Generator): + EXPLICIT_UNION = True + INTERVAL_ALLOWS_PLURAL_FORM = False + JOIN_HINTS = False + TABLE_HINTS = False + LIMIT_FETCH = "LIMIT" + TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore @@ -236,9 +234,7 @@ class BigQuery(Dialect): exp.IntDiv: rename_func("DIV"), exp.Max: max_or_greatest, exp.Min: min_or_least, - exp.Select: transforms.preprocess( - [_unqualify_unnest], transforms.delegate("select_sql") - ), + exp.Select: transforms.preprocess([_unqualify_unnest]), exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeSub: _date_add_sql("TIME", "SUB"), @@ -253,7 +249,7 @@ class BigQuery(Dialect): exp.ReturnsProperty: _returnsproperty_sql, exp.Create: _create_sql, exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression), - exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" + exp.StabilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", exp.RegexpLike: rename_func("REGEXP_CONTAINS"), @@ -261,6 +257,7 @@ class BigQuery(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore + exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC", exp.DataType.Type.BIGINT: "INT64", exp.DataType.Type.BOOLEAN: "BOOL", exp.DataType.Type.CHAR: "STRING", @@ -272,17 +269,19 @@ class BigQuery(Dialect): exp.DataType.Type.NVARCHAR: "STRING", exp.DataType.Type.SMALLINT: "INT64", exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.TIMESTAMP: "DATETIME", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.VARCHAR: "STRING", + exp.DataType.Type.VARIANT: "ANY TYPE", } + PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - EXPLICIT_UNION = True - LIMIT_FETCH = "LIMIT" - def array_sql(self, expression: exp.Array) -> str: first_arg = seq_get(expression.expressions, 0) if isinstance(first_arg, exp.Subqueryable): diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index b06462c..e91b0bf 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -144,6 +144,13 @@ class ClickHouse(Dialect): exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), } + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + } + + JOIN_HINTS = False + TABLE_HINTS = False EXPLICIT_UNION = True def _param_args_sql( diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 2f93ee7..138f26c 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -9,6 +9,8 @@ from sqlglot.tokens import TokenType class Databricks(Spark): class Parser(Spark.Parser): + LOG_DEFAULTS_TO_LN = True + FUNCTIONS = { **Spark.Parser.FUNCTIONS, "DATEADD": parse_date_delta(exp.DateAdd), @@ -16,13 +18,17 @@ class Databricks(Spark): "DATEDIFF": parse_date_delta(exp.DateDiff), } - LOG_DEFAULTS_TO_LN = True + FACTOR = { + **Spark.Parser.FACTOR, + TokenType.COLON: exp.JSONExtract, + } class Generator(Spark.Generator): TRANSFORMS = { **Spark.Generator.TRANSFORMS, # type: ignore exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, + exp.JSONExtract: lambda self, e: self.binary(e, ":"), exp.ToChar: lambda self, e: self.function_fallback_sql(e), } TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 839589d..19c6f73 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -293,6 +293,13 @@ def no_properties_sql(self: Generator, expression: exp.Properties) -> str: return "" +def no_comment_column_constraint_sql( + self: Generator, expression: exp.CommentColumnConstraint +) -> str: + self.unsupported("CommentColumnConstraint unsupported") + return "" + + def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: this = self.sql(expression, "this") substr = self.sql(expression, "substr") @@ -379,15 +386,35 @@ def parse_date_delta( ) -> t.Callable[[t.Sequence], E]: def inner_func(args: t.Sequence) -> E: unit_based = len(args) == 3 - this = seq_get(args, 2) if unit_based else seq_get(args, 0) - expression = seq_get(args, 1) if unit_based else seq_get(args, 1) - unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY") - unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore - return exp_class(this=this, expression=expression, unit=unit) + this = args[2] if unit_based else seq_get(args, 0) + unit = args[0] if unit_based else exp.Literal.string("DAY") + unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit + return exp_class(this=this, expression=seq_get(args, 1), unit=unit) return inner_func +def parse_date_delta_with_interval( + expression_class: t.Type[E], +) -> t.Callable[[t.Sequence], t.Optional[E]]: + def func(args: t.Sequence) -> t.Optional[E]: + if len(args) < 2: + return None + + interval = args[1] + expression = interval.this + if expression and expression.is_string: + expression = exp.Literal.number(expression.this) + + return expression_class( + this=args[0], + expression=expression, + unit=exp.Literal.string(interval.text("unit")), + ) + + return func + + def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc: unit = seq_get(args, 0) this = seq_get(args, 1) diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index a33aadc..d7e2d88 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -104,6 +104,9 @@ class Drill(Dialect): LOG_DEFAULTS_TO_LN = True class Generator(generator.Generator): + JOIN_HINTS = False + TABLE_HINTS = False + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.INT: "INTEGER", @@ -120,6 +123,7 @@ class Drill(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } TRANSFORMS = { diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index c034208..9454db6 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_sql, datestrtodate_sql, format_time_lambda, + no_comment_column_constraint_sql, no_pivot_sql, no_properties_sql, no_safe_divide_sql, @@ -23,7 +24,7 @@ from sqlglot.tokens import TokenType def _ts_or_ds_add(self, expression): - this = expression.args.get("this") + this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" @@ -139,6 +140,8 @@ class DuckDB(Dialect): } class Generator(generator.Generator): + JOIN_HINTS = False + TABLE_HINTS = False STRUCT_DELIMITER = ("(", ")") TRANSFORMS = { @@ -150,6 +153,7 @@ class DuckDB(Dialect): exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), + exp.CommentColumnConstraint: no_comment_column_constraint_sql, exp.DayOfMonth: rename_func("DAYOFMONTH"), exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), @@ -213,6 +217,11 @@ class DuckDB(Dialect): "except": "EXCLUDE", } + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + } + LIMIT_FETCH = "LIMIT" def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str: diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index c39656e..6746fcf 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -45,16 +45,23 @@ TIME_DIFF_FACTOR = { DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") -def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str: +def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) - modified_increment = ( - int(expression.text("expression")) * multiplier - if expression.expression.is_number - else expression.expression - ) - modified_increment = exp.Literal.number(modified_increment) - return self.func(func, expression.this, modified_increment.this) + + if isinstance(expression, exp.DateSub): + multiplier *= -1 + + if expression.expression.is_number: + modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier) + else: + modified_increment = expression.expression + if multiplier != 1: + modified_increment = exp.Mul( # type: ignore + this=modified_increment, expression=exp.Literal.number(multiplier) + ) + + return self.func(func, expression.this, modified_increment) def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: @@ -127,24 +134,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str return f"TO_DATE({this})" -def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str: - unnest = expression.this - if isinstance(unnest, exp.Unnest): - alias = unnest.args.get("alias") - udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode - return "".join( - self.sql( - exp.Lateral( - this=udtf(this=expression), - view=True, - alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore - ) - ) - for expression, column in zip(unnest.expressions, alias.columns if alias else []) - ) - return self.join_sql(expression) - - def _index_sql(self: generator.Generator, expression: exp.Index) -> str: this = self.sql(expression, "this") table = self.sql(expression, "table") @@ -195,6 +184,7 @@ class Hive(Dialect): IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] ENCODE = "utf-8" + IDENTIFIER_CAN_START_WITH_DIGIT = True KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -217,9 +207,8 @@ class Hive(Dialect): "BD": "DECIMAL", } - IDENTIFIER_CAN_START_WITH_DIGIT = True - class Parser(parser.Parser): + LOG_DEFAULTS_TO_LN = True STRICT_CAST = False FUNCTIONS = { @@ -273,9 +262,13 @@ class Hive(Dialect): ), } - LOG_DEFAULTS_TO_LN = True - class Generator(generator.Generator): + LIMIT_FETCH = "LIMIT" + TABLESAMPLE_WITH_METHOD = False + TABLESAMPLE_SIZE_IS_PERCENT = True + JOIN_HINTS = False + TABLE_HINTS = False + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TEXT: "STRING", @@ -289,6 +282,9 @@ class Hive(Dialect): **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore **transforms.ELIMINATE_QUALIFY, # type: ignore + exp.Select: transforms.preprocess( + [transforms.eliminate_qualify, transforms.unnest_to_explode] + ), exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayConcat: rename_func("CONCAT"), @@ -298,13 +294,13 @@ class Hive(Dialect): exp.DateAdd: _add_date_sql, exp.DateDiff: _date_diff_sql, exp.DateStrToDate: rename_func("TO_DATE"), + exp.DateSub: _add_date_sql, exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", - exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}", + exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", exp.If: if_sql, exp.Index: _index_sql, exp.ILike: no_ilike_sql, - exp.Join: _unnest_to_explode_sql, exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.JSONFormat: rename_func("TO_JSON"), @@ -354,10 +350,9 @@ class Hive(Dialect): exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - LIMIT_FETCH = "LIMIT" - def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: return self.func( "COLLECT_LIST", @@ -378,4 +373,5 @@ class Hive(Dialect): expression = exp.DataType.build("text") elif expression.this in exp.DataType.TEMPORAL_TYPES: expression = exp.DataType.build(expression.this) + return super().datatype_sql(expression) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index d64efbf..666e740 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -4,6 +4,8 @@ from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, + datestrtodate_sql, + format_time_lambda, locate_to_strposition, max_or_greatest, min_or_least, @@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import ( no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, + parse_date_delta_with_interval, rename_func, strposition_to_locate_sql, ) @@ -76,18 +79,6 @@ def _trim_sql(self, expression): return f"TRIM({trim_type}{remove_chars}{from_part}{target})" -def _date_add(expression_class): - def func(args): - interval = seq_get(args, 1) - return expression_class( - this=seq_get(args, 0), - expression=interval.this, - unit=exp.Literal.string(interval.text("unit").lower()), - ) - - return func - - def _date_add_sql(kind): def func(self, expression): this = self.sql(expression, "this") @@ -115,6 +106,7 @@ class MySQL(Dialect): "%k": "%-H", "%l": "%-I", "%T": "%H:%M:%S", + "%W": "%a", } class Tokenizer(tokens.Tokenizer): @@ -127,12 +119,13 @@ class MySQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - "MEDIUMTEXT": TokenType.MEDIUMTEXT, + "CHARSET": TokenType.CHARACTER_SET, + "LONGBLOB": TokenType.LONGBLOB, "LONGTEXT": TokenType.LONGTEXT, "MEDIUMBLOB": TokenType.MEDIUMBLOB, - "LONGBLOB": TokenType.LONGBLOB, - "START": TokenType.BEGIN, + "MEDIUMTEXT": TokenType.MEDIUMTEXT, "SEPARATOR": TokenType.SEPARATOR, + "START": TokenType.BEGIN, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, "_BIG5": TokenType.INTRODUCER, @@ -186,14 +179,15 @@ class MySQL(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "DATE_ADD": _date_add(exp.DateAdd), - "DATE_SUB": _date_add(exp.DateSub), - "STR_TO_DATE": _str_to_date, - "LOCATE": locate_to_strposition, + "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), + "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), + "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)), "LEFT": lambda args: exp.Substring( this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1) ), + "LOCATE": locate_to_strposition, + "STR_TO_DATE": _str_to_date, } FUNCTION_PARSERS = { @@ -388,32 +382,36 @@ class MySQL(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False + JOIN_HINTS = False + TABLE_HINTS = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.CurrentDate: no_paren_current_date_sql, - exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.ILike: no_ilike_sql, - exp.JSONExtractScalar: arrow_json_extract_scalar_sql, - exp.Max: max_or_greatest, - exp.Min: min_or_least, - exp.TableSample: no_tablesample_sql, - exp.TryCast: no_trycast_sql, + exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), exp.DateAdd: _date_add_sql("ADD"), - exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})", + exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _date_add_sql("SUB"), exp.DateTrunc: _date_trunc_sql, - exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfMonth: rename_func("DAYOFMONTH"), + exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), - exp.WeekOfYear: rename_func("WEEKOFYEAR"), exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", - exp.StrToDate: _str_to_date_sql, - exp.StrToTime: _str_to_date_sql, - exp.Trim: _trim_sql, + exp.ILike: no_ilike_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.Max: max_or_greatest, + exp.Min: min_or_least, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), exp.StrPosition: strposition_to_locate_sql, + exp.StrToDate: _str_to_date_sql, + exp.StrToTime: _str_to_date_sql, + exp.TableSample: no_tablesample_sql, + exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), + exp.Trim: _trim_sql, + exp.TryCast: no_trycast_sql, + exp.WeekOfYear: rename_func("WEEKOFYEAR"), } TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy() @@ -425,6 +423,7 @@ class MySQL(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } LIMIT_FETCH = "LIMIT" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 3819b76..9ccd02e 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -7,11 +7,6 @@ from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sq from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { - TokenType.COLUMN, - TokenType.RETURNING, -} - def _parse_xml_table(self) -> exp.XMLTable: this = self._parse_string() @@ -22,9 +17,7 @@ def _parse_xml_table(self) -> exp.XMLTable: if self._match_text_seq("PASSING"): # The BY VALUE keywords are optional and are provided for semantic clarity self._match_text_seq("BY", "VALUE") - passing = self._parse_csv( - lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS) - ) + passing = self._parse_csv(self._parse_column) by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF") @@ -68,6 +61,8 @@ class Oracle(Dialect): } class Parser(parser.Parser): + WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP} + FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), @@ -78,6 +73,12 @@ class Oracle(Dialect): "XMLTABLE": _parse_xml_table, } + TYPE_LITERAL_PARSERS = { + exp.DataType.Type.DATE: lambda self, this, _: self.expression( + exp.DateStrToDate, this=this + ) + } + def _parse_column(self) -> t.Optional[exp.Expression]: column = super()._parse_column() if column: @@ -100,6 +101,8 @@ class Oracle(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True + JOIN_HINTS = False + TABLE_HINTS = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -119,6 +122,9 @@ class Oracle(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore + exp.DateStrToDate: lambda self, e: self.func( + "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD") + ), exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.ILike: no_ilike_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", @@ -129,6 +135,12 @@ class Oracle(Dialect): exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", + exp.IfNull: rename_func("NVL"), + } + + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } LIMIT_FETCH = "FETCH" @@ -142,9 +154,9 @@ class Oracle(Dialect): def xmltable_sql(self, expression: exp.XMLTable) -> str: this = self.sql(expression, "this") - passing = self.expressions(expression, "passing") + passing = self.expressions(expression, key="passing") passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else "" - columns = self.expressions(expression, "columns") + columns = self.expressions(expression, key="columns") columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else "" by_ref = ( f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else "" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 31b7e45..c47ff51 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, arrow_json_extract_sql, + datestrtodate_sql, format_time_lambda, max_or_greatest, min_or_least, @@ -19,7 +20,7 @@ from sqlglot.dialects.dialect import ( from sqlglot.helper import seq_get from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType -from sqlglot.transforms import delegate, preprocess +from sqlglot.transforms import preprocess, remove_target_from_merge DATE_DIFF_FACTOR = { "MICROSECOND": " * 1000000", @@ -239,7 +240,6 @@ class Postgres(Dialect): "SERIAL": TokenType.SERIAL, "SMALLSERIAL": TokenType.SMALLSERIAL, "TEMP": TokenType.TEMPORARY, - "UUID": TokenType.UUID, "CSTRING": TokenType.PSEUDO_TYPE, } @@ -248,18 +248,25 @@ class Postgres(Dialect): "$": TokenType.PARAMETER, } + VAR_SINGLE_TOKENS = {"$"} + class Parser(parser.Parser): STRICT_CAST = False FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "NOW": exp.CurrentTimestamp.from_arg_list, - "TO_TIMESTAMP": _to_timestamp, - "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), - "GENERATE_SERIES": _generate_series, "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=seq_get(args, 0) ), + "GENERATE_SERIES": _generate_series, + "NOW": exp.CurrentTimestamp.from_arg_list, + "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), + "TO_TIMESTAMP": _to_timestamp, + } + + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, + "DATE_PART": lambda self: self._parse_date_part(), } BITWISE = { @@ -279,8 +286,21 @@ class Postgres(Dialect): TokenType.LT_AT: binary_range_parser(exp.ArrayContained), } + def _parse_date_part(self) -> exp.Expression: + part = self._parse_type() + self._match(TokenType.COMMA) + value = self._parse_bitwise() + + if part and part.is_string: + part = exp.Var(this=part.name) + + return self.expression(exp.Extract, this=part, expression=value) + class Generator(generator.Generator): + INTERVAL_ALLOWS_PLURAL_FORM = False LOCKING_READS_SUPPORTED = True + JOIN_HINTS = False + TABLE_HINTS = False PARAMETER_TOKEN = "$" TYPE_MAPPING = { @@ -301,7 +321,6 @@ class Postgres(Dialect): _auto_increment_to_serial, _serial_to_generated, ], - delegate("columndef_sql"), ), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, @@ -312,6 +331,7 @@ class Postgres(Dialect): exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql("+"), + exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _date_add_sql("-"), exp.DateDiff: _date_diff_sql, exp.LogicalOr: rename_func("BOOL_OR"), @@ -321,6 +341,7 @@ class Postgres(Dialect): exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), exp.ArrayContained: lambda self, e: self.binary(e, "<@"), + exp.Merge: preprocess([remove_target_from_merge]), exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, @@ -344,4 +365,5 @@ class Postgres(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 07e8f43..489d439 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -19,20 +21,20 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _approx_distinct_sql(self, expression): +def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str: accuracy = expression.args.get("accuracy") accuracy = ", " + self.sql(accuracy) if accuracy else "" return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" -def _datatype_sql(self, expression): +def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: sql = self.datatype_sql(expression) if expression.this == exp.DataType.Type.TIMESTAMPTZ: sql = f"{sql} WITH TIME ZONE" return sql -def _explode_to_unnest_sql(self, expression): +def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, (exp.Explode, exp.Posexplode)): return self.sql( exp.Join( @@ -47,22 +49,22 @@ def _explode_to_unnest_sql(self, expression): return self.lateral_sql(expression) -def _initcap_sql(self, expression): +def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str: regex = r"(\w)(\w*)" return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" -def _decode_sql(self, expression): - _ensure_utf8(expression.args.get("charset")) +def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str: + _ensure_utf8(expression.args["charset"]) return self.func("FROM_UTF8", expression.this, expression.args.get("replace")) -def _encode_sql(self, expression): - _ensure_utf8(expression.args.get("charset")) +def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str: + _ensure_utf8(expression.args["charset"]) return f"TO_UTF8({self.sql(expression, 'this')})" -def _no_sort_array(self, expression): +def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: @@ -70,49 +72,62 @@ def _no_sort_array(self, expression): return self.func("ARRAY_SORT", expression.this, comparator) -def _schema_sql(self, expression): +def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str: if isinstance(expression.parent, exp.Property): columns = ", ".join(f"'{c.name}'" for c in expression.expressions) return f"ARRAY[{columns}]" - for schema in expression.parent.find_all(exp.Schema): - if isinstance(schema.parent, exp.Property): - expression = expression.copy() - expression.expressions.extend(schema.expressions) + if expression.parent: + for schema in expression.parent.find_all(exp.Schema): + if isinstance(schema.parent, exp.Property): + expression = expression.copy() + expression.expressions.extend(schema.expressions) return self.schema_sql(expression) -def _quantile_sql(self, expression): +def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str: self.unsupported("Presto does not support exact quantiles") return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})" -def _str_to_time_sql(self, expression): +def _str_to_time_sql( + self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate +) -> str: return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})" -def _ts_or_ds_to_date_sql(self, expression): +def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Presto.time_format, Presto.date_format): return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)" -def _ts_or_ds_add_sql(self, expression): +def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: + this = expression.this + + if not isinstance(this, exp.CurrentDate): + this = self.func( + "DATE_PARSE", + self.func( + "SUBSTR", + this if this.is_string else exp.cast(this, "VARCHAR"), + exp.Literal.number(1), + exp.Literal.number(10), + ), + Presto.date_format, + ) + return self.func( "DATE_ADD", exp.Literal.string(expression.text("unit") or "day"), expression.expression, - self.func( - "DATE_PARSE", - self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)), - Presto.date_format, - ), + this, ) -def _sequence_sql(self, expression): +def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str: start = expression.args["start"] end = expression.args["end"] step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series @@ -135,12 +150,12 @@ def _sequence_sql(self, expression): return self.func("SEQUENCE", start, end, step) -def _ensure_utf8(charset): +def _ensure_utf8(charset: exp.Literal) -> None: if charset.name.lower() != "utf-8": raise UnsupportedError(f"Unsupported charset {charset}") -def _approx_percentile(args): +def _approx_percentile(args: t.Sequence) -> exp.Expression: if len(args) == 4: return exp.ApproxQuantile( this=seq_get(args, 0), @@ -157,7 +172,7 @@ def _approx_percentile(args): return exp.ApproxQuantile.from_arg_list(args) -def _from_unixtime(args): +def _from_unixtime(args: t.Sequence) -> exp.Expression: if len(args) == 3: return exp.UnixToTime( this=seq_get(args, 0), @@ -226,11 +241,15 @@ class Presto(Dialect): FUNCTION_PARSERS.pop("TRIM") class Generator(generator.Generator): + INTERVAL_ALLOWS_PLURAL_FORM = False + JOIN_HINTS = False + TABLE_HINTS = False STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.LocationProperty: exp.Properties.Location.UNSUPPORTED, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } TYPE_MAPPING = { @@ -246,7 +265,6 @@ class Presto(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore - **transforms.ELIMINATE_QUALIFY, # type: ignore exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), @@ -284,6 +302,9 @@ class Presto(Dialect): exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, + exp.Select: transforms.preprocess( + [transforms.eliminate_qualify, transforms.explode_to_unnest] + ), exp.SortArray: _no_sort_array, exp.StrPosition: rename_func("STRPOS"), exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", @@ -308,7 +329,13 @@ class Presto(Dialect): exp.VariancePop: rename_func("VAR_POP"), } - def transaction_sql(self, expression): + def interval_sql(self, expression: exp.Interval) -> str: + unit = self.sql(expression, "unit") + if expression.this and unit.lower().startswith("week"): + return f"({expression.this.name} * INTERVAL '7' day)" + return super().interval_sql(expression) + + def transaction_sql(self, expression: exp.Transaction) -> str: modes = expression.args.get("modes") modes = f" {', '.join(modes)}" if modes else "" return f"START TRANSACTION{modes}" diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 63c14f4..a9c4f62 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -8,6 +8,10 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType +def _json_sql(self, e) -> str: + return f'{self.sql(e, "this")}."{e.expression.name}"' + + class Redshift(Postgres): time_format = "'YYYY-MM-DD HH:MI:SS'" time_mapping = { @@ -56,6 +60,7 @@ class Redshift(Postgres): "GEOGRAPHY": TokenType.GEOGRAPHY, "HLLSKETCH": TokenType.HLLSKETCH, "SUPER": TokenType.SUPER, + "SYSDATE": TokenType.CURRENT_TIMESTAMP, "TIME": TokenType.TIMESTAMP, "TIMETZ": TokenType.TIMESTAMPTZ, "TOP": TokenType.TOP, @@ -63,7 +68,14 @@ class Redshift(Postgres): "VARBYTE": TokenType.VARBINARY, } + # Redshift allows # to appear as a table identifier prefix + SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy() + SINGLE_TOKENS.pop("#") + class Generator(Postgres.Generator): + LOCKING_READS_SUPPORTED = False + SINGLE_STRING_INTERVAL = True + TYPE_MAPPING = { **Postgres.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BINARY: "VARBYTE", @@ -79,6 +91,7 @@ class Redshift(Postgres): TRANSFORMS = { **Postgres.Generator.TRANSFORMS, # type: ignore **transforms.ELIMINATE_DISTINCT_ON, # type: ignore + exp.CurrentTimestamp: lambda self, e: "SYSDATE", exp.DateAdd: lambda self, e: self.func( "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this ), @@ -87,12 +100,16 @@ class Redshift(Postgres): ), exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), + exp.JSONExtract: _json_sql, + exp.JSONExtractScalar: _json_sql, exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", } # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres) TRANSFORMS.pop(exp.Pow) + RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"} + def values_sql(self, expression: exp.Values) -> str: """ Converts `VALUES...` expression into a series of unions. diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 34bc3bd..0829669 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -23,14 +23,14 @@ from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType -def _check_int(s): +def _check_int(s: str) -> bool: if s[0] in ("-", "+"): return s[1:].isdigit() return s.isdigit() # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _snowflake_to_timestamp(args): +def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]: if len(args) == 2: first_arg, second_arg = args if second_arg.is_string: @@ -69,7 +69,7 @@ def _snowflake_to_timestamp(args): return exp.UnixToTime.from_arg_list(args) -def _unix_to_time_sql(self, expression): +def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale in [None, exp.UnixToTime.SECONDS]: @@ -84,8 +84,12 @@ def _unix_to_time_sql(self, expression): # https://docs.snowflake.com/en/sql-reference/functions/date_part.html # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts -def _parse_date_part(self): +def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]: this = self._parse_var() or self._parse_type() + + if not this: + return None + self._match(TokenType.COMMA) expression = self._parse_bitwise() @@ -101,7 +105,7 @@ def _parse_date_part(self): scale = None ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP")) - to_unix = self.expression(exp.TimeToUnix, this=ts) + to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts) if scale: to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale)) @@ -112,7 +116,7 @@ def _parse_date_part(self): # https://docs.snowflake.com/en/sql-reference/functions/div0 -def _div0_to_if(args): +def _div0_to_if(args: t.Sequence) -> exp.Expression: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) true = exp.Literal.number(0) false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) @@ -120,18 +124,18 @@ def _div0_to_if(args): # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _zeroifnull_to_if(args): +def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression: cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _nullifzero_to_if(args): +def _nullifzero_to_if(args: t.Sequence) -> exp.Expression: cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) -def _datatype_sql(self, expression): +def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: if expression.this == exp.DataType.Type.ARRAY: return "ARRAY" elif expression.this == exp.DataType.Type.MAP: @@ -155,9 +159,8 @@ class Snowflake(Dialect): "MM": "%m", "mm": "%m", "DD": "%d", - "dd": "%d", - "d": "%-d", - "DY": "%w", + "dd": "%-d", + "DY": "%a", "dy": "%w", "HH24": "%H", "hh24": "%H", @@ -174,6 +177,8 @@ class Snowflake(Dialect): } class Parser(parser.Parser): + QUOTED_PIVOT_COLUMNS = True + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, @@ -269,9 +274,14 @@ class Snowflake(Dialect): "$": TokenType.PARAMETER, } + VAR_SINGLE_TOKENS = {"$"} + class Generator(generator.Generator): PARAMETER_TOKEN = "$" MATCHED_BY_SOURCE = False + SINGLE_STRING_INTERVAL = True + JOIN_HINTS = False + TABLE_HINTS = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore @@ -287,26 +297,30 @@ class Snowflake(Dialect): ), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, + exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.If: rename_func("IFF"), - exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), - exp.LogicalOr: rename_func("BOOLOR_AGG"), exp.LogicalAnd: rename_func("BOOLAND_AGG"), - exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.LogicalOr: rename_func("BOOLOR_AGG"), + exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.Max: max_or_greatest, + exp.Min: min_or_least, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.StarMap: rename_func("OBJECT_CONSTRUCT"), exp.StrPosition: lambda self, e: self.func( "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", - exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), + exp.TimeToStr: lambda self, e: self.func( + "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e) + ), + exp.TimestampTrunc: timestamptrunc_sql, exp.ToChar: lambda self, e: self.function_fallback_sql(e), + exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, - exp.DayOfWeek: rename_func("DAYOFWEEK"), - exp.Max: max_or_greatest, - exp.Min: min_or_least, + exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), } TYPE_MAPPING = { @@ -322,14 +336,15 @@ class Snowflake(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.SetProperty: exp.Properties.Location.UNSUPPORTED, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - def except_op(self, expression): + def except_op(self, expression: exp.Except) -> str: if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") return super().except_op(expression) - def intersect_op(self, expression): + def intersect_op(self, expression: exp.Intersect) -> str: if not expression.args.get("distinct", False): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index c271f6f..a3e4cce 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -1,13 +1,15 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, parser from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get -def _create_sql(self, e): - kind = e.args.get("kind") +def _create_sql(self: Hive.Generator, e: exp.Create) -> str: + kind = e.args["kind"] properties = e.args.get("properties") if kind.upper() == "TABLE" and any( @@ -18,13 +20,13 @@ def _create_sql(self, e): return create_with_partitions_sql(self, e) -def _map_sql(self, expression): +def _map_sql(self: Hive.Generator, expression: exp.Map) -> str: keys = self.sql(expression.args["keys"]) values = self.sql(expression.args["values"]) return f"MAP_FROM_ARRAYS({keys}, {values})" -def _str_to_date(self, expression): +def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format == Hive.date_format: @@ -32,7 +34,7 @@ def _str_to_date(self, expression): return f"TO_DATE({this}, {time_format})" -def _unix_to_time(self, expression): +def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale is None: @@ -75,7 +77,11 @@ class Spark(Hive): length=seq_get(args, 1), ), "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "BOOLEAN": lambda args: exp.Cast( + this=seq_get(args, 0), to=exp.DataType.build("boolean") + ), "IIF": exp.If.from_arg_list, + "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")), "AGGREGATE": exp.Reduce.from_arg_list, "DAYOFWEEK": lambda args: exp.DayOfWeek( this=exp.TsOrDsToDate(this=seq_get(args, 0)), @@ -89,11 +95,16 @@ class Spark(Hive): "WEEKOFYEAR": lambda args: exp.WeekOfYear( this=exp.TsOrDsToDate(this=seq_get(args, 0)), ), + "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")), "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=exp.var(seq_get(args, 0)), ), + "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")), "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), + "TIMESTAMP": lambda args: exp.Cast( + this=seq_get(args, 0), to=exp.DataType.build("timestamp") + ), } FUNCTION_PARSERS = { @@ -108,16 +119,43 @@ class Spark(Hive): "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"), } - def _parse_add_column(self): + def _parse_add_column(self) -> t.Optional[exp.Expression]: return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema() - def _parse_drop_column(self): + def _parse_drop_column(self) -> t.Optional[exp.Expression]: return self._match_text_seq("DROP", "COLUMNS") and self.expression( exp.Drop, this=self._parse_schema(), kind="COLUMNS", ) + def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]: + # Spark doesn't add a suffix to the pivot columns when there's a single aggregation + if len(pivot_columns) == 1: + return [""] + + names = [] + for agg in pivot_columns: + if isinstance(agg, exp.Alias): + names.append(agg.alias) + else: + """ + This case corresponds to aggregations without aliases being used as suffixes + (e.g. col_avg(foo)). We need to unquote identifiers because they're going to + be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. + Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). + + Moreover, function names are lowercased in order to mimic Spark's naming scheme. + """ + agg_all_unquoted = agg.transform( + lambda node: exp.Identifier(this=node.name, quoted=False) + if isinstance(node, exp.Identifier) + else node + ) + names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower")) + + return names + class Generator(Hive.Generator): TYPE_MAPPING = { **Hive.Generator.TYPE_MAPPING, # type: ignore @@ -145,7 +183,7 @@ class Spark(Hive): exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.UnixToTime: _unix_to_time, + exp.UnixToTime: _unix_to_time_sql, exp.Create: _create_sql, exp.Map: _map_sql, exp.Reduce: rename_func("AGGREGATE"), diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 4091dbb..4437f82 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -16,7 +16,7 @@ from sqlglot.tokens import TokenType def _date_add_sql(self, expression): modifier = expression.expression - modifier = expression.name if modifier.is_string else self.sql(modifier) + modifier = modifier.name if modifier.is_string else self.sql(modifier) unit = expression.args.get("unit") modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'" return self.func("DATE", expression.this, modifier) @@ -38,6 +38,9 @@ class SQLite(Dialect): } class Generator(generator.Generator): + JOIN_HINTS = False + TABLE_HINTS = False + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BOOLEAN: "INTEGER", @@ -82,6 +85,11 @@ class SQLite(Dialect): exp.TryCast: no_trycast_sql, } + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + } + LIMIT_FETCH = "LIMIT" def cast_sql(self, expression: exp.Cast) -> str: diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 2ba1a92..ff19dab 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -1,7 +1,11 @@ from __future__ import annotations from sqlglot import exp -from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func +from sqlglot.dialects.dialect import ( + approx_count_distinct_sql, + arrow_json_extract_sql, + rename_func, +) from sqlglot.dialects.mysql import MySQL from sqlglot.helper import seq_get @@ -10,6 +14,7 @@ class StarRocks(MySQL): class Parser(MySQL.Parser): # type: ignore FUNCTIONS = { **MySQL.Parser.FUNCTIONS, + "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=seq_get(args, 0) ), @@ -25,6 +30,7 @@ class StarRocks(MySQL): TRANSFORMS = { **MySQL.Generator.TRANSFORMS, # type: ignore + exp.ApproxDistinct: approx_count_distinct_sql, exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, exp.DateDiff: rename_func("DATEDIFF"), diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 31b1c8d..792c2b4 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -21,6 +21,9 @@ def _count_sql(self, expression): class Tableau(Dialect): class Generator(generator.Generator): + JOIN_HINTS = False + TABLE_HINTS = False + TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.If: _if_sql, @@ -28,6 +31,11 @@ class Tableau(Dialect): exp.Count: _count_sql, } + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + } + class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 3d43793..331e105 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,7 +1,14 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least +from sqlglot.dialects.dialect import ( + Dialect, + format_time_lambda, + max_or_greatest, + min_or_least, +) from sqlglot.tokens import TokenType @@ -115,7 +122,18 @@ class Teradata(Dialect): return self.expression(exp.RangeN, this=this, expressions=expressions, each=each) + def _parse_cast(self, strict: bool) -> exp.Expression: + cast = t.cast(exp.Cast, super()._parse_cast(strict)) + if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT): + return format_time_lambda(exp.TimeToStr, "teradata")( + [cast.this, self._parse_string()] + ) + return cast + class Generator(generator.Generator): + JOIN_HINTS = False + TABLE_HINTS = False + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.GEOMETRY: "ST_GEOMETRY", @@ -130,6 +148,7 @@ class Teradata(Dialect): **generator.Generator.TRANSFORMS, exp.Max: max_or_greatest, exp.Min: min_or_least, + exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b8a227b..9cf56e1 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -96,6 +96,23 @@ def _parse_eomonth(args): return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit)) +def _parse_hashbytes(args): + kind, data = args + kind = kind.name.upper() if kind.is_string else "" + + if kind == "MD5": + args.pop(0) + return exp.MD5(this=data) + if kind in ("SHA", "SHA1"): + args.pop(0) + return exp.SHA(this=data) + if kind == "SHA2_256": + return exp.SHA2(this=data, length=exp.Literal.number(256)) + if kind == "SHA2_512": + return exp.SHA2(this=data, length=exp.Literal.number(512)) + return exp.func("HASHBYTES", *args) + + def generate_date_delta_with_unit_sql(self, e): func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" return self.func(func, e.text("unit"), e.expression, e.this) @@ -266,6 +283,7 @@ class TSQL(Dialect): "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "VARCHAR(MAX)": TokenType.TEXT, "XML": TokenType.XML, + "SYSTEM_USER": TokenType.CURRENT_USER, } # TSQL allows @, # to appear as a variable/identifier prefix @@ -287,6 +305,7 @@ class TSQL(Dialect): "EOMONTH": _parse_eomonth, "FORMAT": _parse_format, "GETDATE": exp.CurrentTimestamp.from_arg_list, + "HASHBYTES": _parse_hashbytes, "IIF": exp.If.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, @@ -296,6 +315,14 @@ class TSQL(Dialect): "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, "SUSER_NAME": exp.CurrentUser.from_arg_list, "SUSER_SNAME": exp.CurrentUser.from_arg_list, + "SYSTEM_USER": exp.CurrentUser.from_arg_list, + } + + JOIN_HINTS = { + "LOOP", + "HASH", + "MERGE", + "REMOTE", } VAR_LENGTH_DATATYPES = { @@ -441,11 +468,21 @@ class TSQL(Dialect): exp.TimeToStr: _format_sql, exp.GroupConcat: _string_agg_sql, exp.Max: max_or_greatest, + exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, + exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), + exp.SHA2: lambda self, e: self.func( + "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this + ), } TRANSFORMS.pop(exp.ReturnsProperty) + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + } + LIMIT_FETCH = "FETCH" def offset_sql(self, expression: exp.Offset) -> str: |