From ebec59cc5cb6c6856705bf82ced7fe8d9f75b0d0 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 7 Mar 2023 19:09:31 +0100 Subject: Merging upstream version 11.3.0. Signed-off-by: Daniel Baumann --- sqlglot/dialects/bigquery.py | 4 ++++ sqlglot/dialects/clickhouse.py | 3 +++ sqlglot/dialects/dialect.py | 5 ++--- sqlglot/dialects/duckdb.py | 6 ++++++ sqlglot/dialects/hive.py | 4 ++++ sqlglot/dialects/mysql.py | 3 +++ sqlglot/dialects/oracle.py | 17 +++++++++++++++-- sqlglot/dialects/postgres.py | 7 +------ sqlglot/dialects/snowflake.py | 41 ++++++++++++++++++++++++++++++++++++++++- sqlglot/dialects/teradata.py | 19 ++++++++++++++++++- 10 files changed, 96 insertions(+), 13 deletions(-) (limited to 'sqlglot/dialects') diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 32b5147..a3869c6 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -213,7 +213,11 @@ class BigQuery(Dialect): ), } + INTEGER_DIVISION = False + class Generator(generator.Generator): + INTEGER_DIVISION = False + TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index b553df2..a78d4db 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -56,6 +56,8 @@ class ClickHouse(Dialect): TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore + INTEGER_DIVISION = False + def _parse_in( self, this: t.Optional[exp.Expression], is_global: bool = False ) -> exp.Expression: @@ -94,6 +96,7 @@ class ClickHouse(Dialect): class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") + INTEGER_DIVISION = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index af36256..6939705 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -360,10 +360,9 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: if has_schema and is_partitionable: expression = expression.copy() prop = expression.find(exp.PartitionedByProperty) - this = prop and prop.this - if prop and not isinstance(this, exp.Schema): + if prop and prop.this and not isinstance(prop.this, exp.Schema): schema = expression.this - columns = {v.name.upper() for v in this.expressions} + columns = {v.name.upper() for v in prop.this.expressions} partitions = [col for col in schema.expressions if col.name.upper() in columns] schema.set("expressions", [e for e in schema.expressions if e not in partitions]) prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 6144101..c2755cd 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -83,6 +83,7 @@ class DuckDB(Dialect): ":=": TokenType.EQ, "ATTACH": TokenType.COMMAND, "CHARACTER VARYING": TokenType.VARCHAR, + "EXCLUDE": TokenType.EXCEPT, } class Parser(parser.Parser): @@ -173,3 +174,8 @@ class DuckDB(Dialect): exp.DataType.Type.VARCHAR: "TEXT", exp.DataType.Type.NVARCHAR: "TEXT", } + + STAR_MAPPING = { + **generator.Generator.STAR_MAPPING, + "except": "EXCLUDE", + } diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index ea1191e..44cd875 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -256,7 +256,11 @@ class Hive(Dialect): ), } + INTEGER_DIVISION = False + class Generator(generator.Generator): + INTEGER_DIVISION = False + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TEXT: "STRING", diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 836bf3c..b1e20bd 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -300,6 +300,8 @@ class MySQL(Dialect): "READ ONLY", } + INTEGER_DIVISION = False + def _parse_show_mysql(self, this, target=False, full=None, global_=None): if target: if isinstance(target, str): @@ -432,6 +434,7 @@ class MySQL(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False + INTEGER_DIVISION = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 74baa8a..795bbeb 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -82,8 +82,17 @@ class Oracle(Dialect): "XMLTABLE": _parse_xml_table, } + INTEGER_DIVISION = False + + def _parse_column(self) -> t.Optional[exp.Expression]: + column = super()._parse_column() + if column: + column.set("join_mark", self._match(TokenType.JOIN_MARKER)) + return column + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True + INTEGER_DIVISION = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -108,6 +117,8 @@ class Oracle(Dialect): exp.Trim: trim_sql, exp.Matches: rename_func("DECODE"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), + exp.Table: lambda self, e: self.table_sql(e, sep=" "), exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", exp.Substring: rename_func("SUBSTR"), @@ -139,8 +150,9 @@ class Oracle(Dialect): def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" - def table_sql(self, expression: exp.Table, sep: str = " ") -> str: - return super().table_sql(expression, sep=sep) + def column_sql(self, expression: exp.Column) -> str: + column = super().column_sql(expression) + return f"{column} (+)" if expression.args.get("join_mark") else column def xmltable_sql(self, expression: exp.XMLTable) -> str: this = self.sql(expression, "this") @@ -156,6 +168,7 @@ class Oracle(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "(+)": TokenType.JOIN_MARKER, "COLUMNS": TokenType.COLUMN, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "MINUS": TokenType.EXCEPT, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 3507cb5..35076db 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -222,10 +222,8 @@ class Postgres(Dialect): "BEGIN TRANSACTION": TokenType.BEGIN, "BIGSERIAL": TokenType.BIGSERIAL, "CHARACTER VARYING": TokenType.VARCHAR, - "COMMENT ON": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, - "GRANT": TokenType.COMMAND, "HSTORE": TokenType.HSTORE, "JSONB": TokenType.JSONB, "REFRESH": TokenType.COMMAND, @@ -260,10 +258,7 @@ class Postgres(Dialect): TokenType.HASH: exp.BitwiseXor, } - FACTOR = { - **parser.Parser.FACTOR, # type: ignore - TokenType.CARET: exp.Pow, - } + FACTOR = {**parser.Parser.FACTOR, TokenType.CARET: exp.Pow} class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 5931364..4a090c2 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, @@ -104,6 +106,20 @@ def _parse_date_part(self): return self.expression(exp.Extract, this=this, expression=expression) +# https://docs.snowflake.com/en/sql-reference/functions/div0 +def _div0_to_if(args): + cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) + true = exp.Literal.number(0) + false = exp.FloatDiv(this=seq_get(args, 0), expression=seq_get(args, 1)) + return exp.If(this=cond, true=true, false=false) + + +# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull +def _zeroifnull_to_if(args): + cond = exp.EQ(this=seq_get(args, 0), expression=exp.Null()) + return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) + + def _datatype_sql(self, expression): if expression.this == exp.DataType.Type.ARRAY: return "ARRAY" @@ -150,16 +166,20 @@ class Snowflake(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, + "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore this=seq_get(args, 1), ), + "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, + "TO_ARRAY": exp.Array.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "DECODE": exp.Matches.from_arg_list, "OBJECT_CONSTRUCT": parser.parse_var_map, + "ZEROIFNULL": _zeroifnull_to_if, } FUNCTION_PARSERS = { @@ -193,6 +213,19 @@ class Snowflake(Dialect): ), } + ALTER_PARSERS = { + **parser.Parser.ALTER_PARSERS, # type: ignore + "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True), + "SET": lambda self: self._parse_alter_table_set_tag(), + } + + INTEGER_DIVISION = False + + def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression: + self._match_text_seq("TAG") + parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction) + return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset) + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] STRING_ESCAPES = ["\\", "'"] @@ -220,12 +253,14 @@ class Snowflake(Dialect): class Generator(generator.Generator): PARAMETER_TOKEN = "$" + INTEGER_DIVISION = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), - exp.DateAdd: rename_func("DATEADD"), + exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), + exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), @@ -294,6 +329,10 @@ class Snowflake(Dialect): return self.no_identify(lambda: super(self.__class__, self).values_sql(expression)) return super().values_sql(expression) + def settag_sql(self, expression: exp.SetTag) -> str: + action = "UNSET" if expression.args.get("unset") else "SET" + return f"{action} TAG {self.expressions(expression)}" + def select_sql(self, expression: exp.Select) -> str: """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 7953bc5..415681c 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -74,6 +74,7 @@ class Teradata(Dialect): FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, # type: ignore + "RANGE_N": lambda self: self._parse_rangen(), "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST), } @@ -105,6 +106,15 @@ class Teradata(Dialect): }, ) + def _parse_rangen(self): + this = self._parse_id_var() + self._match(TokenType.BETWEEN) + + expressions = self._parse_csv(self._parse_conjunction) + each = self._match_text_seq("EACH") and self._parse_conjunction() + + return self.expression(exp.RangeN, this=this, expressions=expressions, each=each) + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -114,7 +124,6 @@ class Teradata(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX, - exp.VolatilityProperty: exp.Properties.Location.POST_CREATE, } def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: @@ -137,3 +146,11 @@ class Teradata(Dialect): type_sql = super().datatype_sql(expression) prefix_sql = expression.args.get("prefix") return f"SYSUDTLIB.{type_sql}" if prefix_sql else type_sql + + def rangen_sql(self, expression: exp.RangeN) -> str: + this = self.sql(expression, "this") + expressions_sql = self.expressions(expression) + each_sql = self.sql(expression, "each") + each_sql = f" EACH {each_sql}" if each_sql else "" + + return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})" -- cgit v1.2.3