From 7457677bc603569692329e39a59ccb018306e2a6 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 12 Mar 2023 11:17:16 +0100 Subject: Merging upstream version 11.3.6. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/dataframe/sql/column.py | 8 +++---- sqlglot/dialects/bigquery.py | 4 ---- sqlglot/dialects/clickhouse.py | 3 --- sqlglot/dialects/databricks.py | 1 + sqlglot/dialects/drill.py | 4 ++-- sqlglot/dialects/duckdb.py | 31 ++++++++++++++++++++++-- sqlglot/dialects/hive.py | 4 ---- sqlglot/dialects/mysql.py | 4 +--- sqlglot/dialects/oracle.py | 7 +++--- sqlglot/dialects/postgres.py | 19 ++++++++++++++- sqlglot/dialects/redshift.py | 5 ++++ sqlglot/dialects/snowflake.py | 32 +++++++++++++------------ sqlglot/dialects/sqlite.py | 27 +++++++++++++++++++++ sqlglot/dialects/teradata.py | 1 + sqlglot/dialects/tsql.py | 13 +++++------ sqlglot/expressions.py | 35 +++++++++++++++++++-------- sqlglot/generator.py | 32 +++++++++++-------------- sqlglot/parser.py | 52 +++++++++++++++++++---------------------- sqlglot/tokens.py | 10 ++++++++ 20 files changed, 187 insertions(+), 107 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index a9a220c..4a30008 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -47,7 +47,7 @@ if t.TYPE_CHECKING: T = t.TypeVar("T", bound=Expression) -__version__ = "11.3.3" +__version__ = "11.3.6" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index f45d467..609b2a4 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -67,10 +67,10 @@ class Column: return self.binary_op(exp.Mul, other) def __truediv__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.FloatDiv, other) + return self.binary_op(exp.Div, other) def __div__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.FloatDiv, other) + return self.binary_op(exp.Div, other) def __neg__(self) -> Column: return self.unary_op(exp.Neg) @@ -85,10 +85,10 @@ class Column: return self.inverse_binary_op(exp.Mul, other) def __rdiv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.FloatDiv, other) + return self.inverse_binary_op(exp.Div, other) def __rtruediv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.FloatDiv, other) + return self.inverse_binary_op(exp.Div, other) def __rmod__(self, other: ColumnOrLiteral) -> Column: return self.inverse_binary_op(exp.Mod, other) diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index ccdd1c9..0c2105b 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -214,11 +214,7 @@ class BigQuery(Dialect): ), } - INTEGER_DIVISION = False - class Generator(generator.Generator): - INTEGER_DIVISION = False - TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index a78d4db..b553df2 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -56,8 +56,6 @@ class ClickHouse(Dialect): TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore - INTEGER_DIVISION = False - def _parse_in( self, this: t.Optional[exp.Expression], is_global: bool = False ) -> exp.Expression: @@ -96,7 +94,6 @@ class ClickHouse(Dialect): class Generator(generator.Generator): 
STRUCT_DELIMITER = ("(", ")") - INTEGER_DIVISION = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 2e058e8..4ff3594 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -21,6 +21,7 @@ class Databricks(Spark): **Spark.Generator.TRANSFORMS, # type: ignore exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, + exp.ToChar: lambda self, e: self.function_fallback_sql(e), } PARAMETER_TOKEN = "$" diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index afcf4d0..208e2ab 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -9,7 +9,6 @@ from sqlglot.dialects.dialect import ( create_with_partitions_sql, datestrtodate_sql, format_time_lambda, - no_pivot_sql, no_trycast_sql, rename_func, str_position_sql, @@ -136,16 +135,17 @@ class Drill(Dialect): exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", - exp.Pivot: no_pivot_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.StrPosition: str_position_sql, exp.StrToDate: _str_to_date, + exp.Pow: rename_func("POW"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), + exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.TryCast: no_trycast_sql, exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.Var(this='DAY')))})", exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"), diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index db79d86..43f538c 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -81,8 +81,21 @@ class DuckDB(Dialect): **tokens.Tokenizer.KEYWORDS, ":=": TokenType.EQ, "ATTACH": TokenType.COMMAND, - "CHARACTER VARYING": TokenType.VARCHAR, + "BINARY": TokenType.VARBINARY, + "BPCHAR": TokenType.TEXT, + "BITSTRING": TokenType.BIT, + "CHAR": TokenType.TEXT, + "CHARACTER VARYING": TokenType.TEXT, "EXCLUDE": TokenType.EXCEPT, + "INT1": TokenType.TINYINT, + "LOGICAL": TokenType.BOOLEAN, + "NUMERIC": TokenType.DOUBLE, + "SIGNED": TokenType.INT, + "STRING": TokenType.VARCHAR, + "UBIGINT": TokenType.UBIGINT, + "UINTEGER": TokenType.UINT, + "USMALLINT": TokenType.USMALLINT, + "UTINYINT": TokenType.UTINYINT, } class Parser(parser.Parser): @@ -115,6 +128,14 @@ class DuckDB(Dialect): "UNNEST": exp.Explode.from_arg_list, } + TYPE_TOKENS = { + *parser.Parser.TYPE_TOKENS, + TokenType.UBIGINT, + TokenType.UINT, + TokenType.USMALLINT, + TokenType.UTINYINT, + } + class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") @@ -169,8 +190,14 @@ class DuckDB(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore - exp.DataType.Type.VARCHAR: "TEXT", + exp.DataType.Type.BINARY: "BLOB", + exp.DataType.Type.CHAR: "TEXT", + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.NCHAR: "TEXT", exp.DataType.Type.NVARCHAR: "TEXT", + exp.DataType.Type.UINT: "UINTEGER", + exp.DataType.Type.VARBINARY: 
"BLOB", + exp.DataType.Type.VARCHAR: "TEXT", } STAR_MAPPING = { diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index faed1cf..c4b8fa9 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -257,11 +257,7 @@ class Hive(Dialect): ), } - INTEGER_DIVISION = False - class Generator(generator.Generator): - INTEGER_DIVISION = False - TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TEXT: "STRING", diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 3531f59..a831235 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -301,8 +301,6 @@ class MySQL(Dialect): "READ ONLY", } - INTEGER_DIVISION = False - def _parse_show_mysql(self, this, target=False, full=None, global_=None): if target: if isinstance(target, str): @@ -435,7 +433,6 @@ class MySQL(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False - INTEGER_DIVISION = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore @@ -446,6 +443,7 @@ class MySQL(Dialect): exp.TableSample: no_tablesample_sql, exp.TryCast: no_trycast_sql, exp.DateAdd: _date_add_sql("ADD"), + exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})", exp.DateSub: _date_add_sql("SUB"), exp.DateTrunc: _date_trunc_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 795bbeb..7028a04 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -4,7 +4,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql -from sqlglot.helper import csv +from sqlglot.helper import csv, seq_get from sqlglot.tokens import TokenType PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { @@ -75,6 +75,7 @@ class Oracle(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "DECODE": exp.Matches.from_arg_list, + "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), } FUNCTION_PARSERS: t.Dict[str, t.Callable] = { @@ -82,8 +83,6 @@ class Oracle(Dialect): "XMLTABLE": _parse_xml_table, } - INTEGER_DIVISION = False - def _parse_column(self) -> t.Optional[exp.Expression]: column = super()._parse_column() if column: @@ -92,7 +91,6 @@ class Oracle(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True - INTEGER_DIVISION = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -122,6 +120,7 @@ class Oracle(Dialect): exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", exp.Substring: rename_func("SUBSTR"), + exp.ToChar: lambda self, e: self.function_fallback_sql(e), } def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 7e7902c..d7cbac4 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import ( trim_sql, ) from sqlglot.helper import seq_get +from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType from sqlglot.transforms import delegate, preprocess @@ -219,6 +220,8 @@ class Postgres(Dialect): "~~*": TokenType.ILIKE, "~*": TokenType.IRLIKE, "~": TokenType.RLIKE, + "@>": 
TokenType.AT_GT, + "<@": TokenType.LT_AT, "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "BIGSERIAL": TokenType.BIGSERIAL, @@ -260,7 +263,17 @@ class Postgres(Dialect): TokenType.HASH: exp.BitwiseXor, } - FACTOR = {**parser.Parser.FACTOR, TokenType.CARET: exp.Pow} + FACTOR = { + **parser.Parser.FACTOR, + TokenType.CARET: exp.Pow, + } + + RANGE_PARSERS = { + **parser.Parser.RANGE_PARSERS, # type: ignore + TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps), + TokenType.AT_GT: binary_range_parser(exp.ArrayContains), + TokenType.LT_AT: binary_range_parser(exp.ArrayContained), + } class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True @@ -299,6 +312,9 @@ class Postgres(Dialect): exp.DateDiff: _date_diff_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.Min: min_or_least, + exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), + exp.ArrayContains: lambda self, e: self.binary(e, "@>"), + exp.ArrayContained: lambda self, e: self.binary(e, "<@"), exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, @@ -307,6 +323,7 @@ class Postgres(Dialect): exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TableSample: no_tablesample_sql, + exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, exp.TryCast: no_trycast_sql, exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 22ef51c..dc881b9 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -29,6 +29,8 @@ class Redshift(Postgres): "NVL": exp.Coalesce.from_arg_list, } + CONVERT_TYPE_FIRST = True + def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]: this = super()._parse_types(check_func=check_func) @@ -83,6 +85,9 @@ class Redshift(Postgres): exp.Matches: rename_func("DECODE"), } + # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres) + TRANSFORMS.pop(exp.Pow) + def values_sql(self, expression: exp.Values) -> str: """ Converts `VALUES...` expression into a series of unions. 
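For context on the Postgres changes above: the new "@>", "<@" and "&&" tokens are routed through the binary_range_parser helper into the new exp.ArrayContains, exp.ArrayContained and exp.ArrayOverlaps nodes, and the Postgres generator maps them straight back to the same operators. A minimal sketch of the resulting behaviour, assuming sqlglot 11.3.6 with this patch applied (the printed SQL is illustrative, not verified output):

import sqlglot
from sqlglot import exp

sql = "SELECT tags @> ARRAY['a'], tags <@ ARRAY['b'], tags && ARRAY['c']"
ast = sqlglot.parse_one(sql, read="postgres")

# The operators now land on dedicated expression nodes rather than generic functions.
assert ast.find(exp.ArrayContains) is not None   # @>
assert ast.find(exp.ArrayContained) is not None  # <@
assert ast.find(exp.ArrayOverlaps) is not None   # &&

# The Postgres generator maps each node straight back to its operator.
print(ast.sql(dialect="postgres"))
# e.g. SELECT tags @> ARRAY['a'], tags <@ ARRAY['b'], tags && ARRAY['c']

Redshift inherits these parsers from Postgres; the only divergence in the hunk above is popping exp.Pow from its TRANSFORMS, since Redshift spells exponentiation as POW/POWER(expr1, expr2) rather than Postgres's expr1 ^ expr2.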
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 6413f6d..9b159a4 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -16,6 +16,7 @@ from sqlglot.dialects.dialect import ( ) from sqlglot.expressions import Literal from sqlglot.helper import flatten, seq_get +from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType @@ -111,7 +112,7 @@ def _parse_date_part(self): def _div0_to_if(args): cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) true = exp.Literal.number(0) - false = exp.FloatDiv(this=seq_get(args, 0), expression=seq_get(args, 1)) + false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) return exp.If(this=cond, true=true, false=false) @@ -173,26 +174,33 @@ class Snowflake(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, + "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0), ), + "DATEDIFF": lambda args: exp.DateDiff( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), + ), "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore this=seq_get(args, 1), ), + "DECODE": exp.Matches.from_arg_list, "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, + "NULLIFZERO": _nullifzero_to_if, + "OBJECT_CONSTRUCT": parser.parse_var_map, + "RLIKE": exp.RegexpLike.from_arg_list, + "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TO_ARRAY": exp.Array.from_arg_list, + "TO_VARCHAR": exp.ToChar.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, - "ARRAY_CONSTRUCT": exp.Array.from_arg_list, - "RLIKE": exp.RegexpLike.from_arg_list, - "DECODE": exp.Matches.from_arg_list, - "OBJECT_CONSTRUCT": parser.parse_var_map, "ZEROIFNULL": _zeroifnull_to_if, - "NULLIFZERO": _nullifzero_to_if, } FUNCTION_PARSERS = { @@ -218,12 +226,8 @@ class Snowflake(Dialect): RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, # type: ignore - TokenType.LIKE_ANY: lambda self, this: self._parse_escape( - self.expression(exp.LikeAny, this=this, expression=self._parse_bitwise()) - ), - TokenType.ILIKE_ANY: lambda self, this: self._parse_escape( - self.expression(exp.ILikeAny, this=this, expression=self._parse_bitwise()) - ), + TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny), + TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny), } ALTER_PARSERS = { @@ -232,8 +236,6 @@ class Snowflake(Dialect): "SET": lambda self: self._parse_alter_table_set_tag(), } - INTEGER_DIVISION = False - def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression: self._match_text_seq("TAG") parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction) @@ -266,7 +268,6 @@ class Snowflake(Dialect): class Generator(generator.Generator): PARAMETER_TOKEN = "$" - INTEGER_DIVISION = False MATCHED_BY_SOURCE = False TRANSFORMS = { @@ -289,6 +290,7 @@ class Snowflake(Dialect): exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), + exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), diff --git a/sqlglot/dialects/sqlite.py 
b/sqlglot/dialects/sqlite.py index fb99d49..ed7c741 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -92,6 +92,33 @@ class SQLite(Dialect): exp.GroupConcat: _group_concat_sql, } + def datediff_sql(self, expression: exp.DateDiff) -> str: + unit = expression.args.get("unit") + unit = unit.name.upper() if unit else "DAY" + + sql = f"(JULIANDAY({self.sql(expression, 'this')}) - JULIANDAY({self.sql(expression, 'expression')}))" + + if unit == "MONTH": + sql = f"{sql} / 30.0" + elif unit == "YEAR": + sql = f"{sql} / 365.0" + elif unit == "HOUR": + sql = f"{sql} * 24.0" + elif unit == "MINUTE": + sql = f"{sql} * 1440.0" + elif unit == "SECOND": + sql = f"{sql} * 86400.0" + elif unit == "MILLISECOND": + sql = f"{sql} * 86400000.0" + elif unit == "MICROSECOND": + sql = f"{sql} * 86400000000.0" + elif unit == "NANOSECOND": + sql = f"{sql} * 8640000000000.0" + else: + self.unsupported("DATEDIFF unsupported for '{unit}'.") + + return f"CAST({sql} AS INTEGER)" + def fetch_sql(self, expression): return self.limit_sql(exp.Limit(expression=expression.args.get("count"))) diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 06d5c6c..8bd0a0c 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -129,6 +129,7 @@ class Teradata(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.Min: min_or_least, + exp.ToChar: lambda self, e: self.function_fallback_sql(e), } def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index d07b083..371e888 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -248,7 +248,6 @@ class TSQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - "BIT": TokenType.BOOLEAN, "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, "DECLARE": TokenType.COMMAND, @@ -283,19 +282,20 @@ class TSQL(Dialect): substr=seq_get(args, 0), position=seq_get(args, 2), ), - "ISNULL": exp.Coalesce.from_arg_list, "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATEPART": _format_time_lambda(exp.TimeToStr), + "EOMONTH": _parse_eomonth, + "FORMAT": _parse_format, "GETDATE": exp.CurrentTimestamp.from_arg_list, - "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, "IIF": exp.If.from_arg_list, + "ISNULL": exp.Coalesce.from_arg_list, + "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, "LEN": exp.Length.from_arg_list, "REPLICATE": exp.Repeat.from_arg_list, - "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, - "FORMAT": _parse_format, - "EOMONTH": _parse_eomonth, + "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), + "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, } VAR_LENGTH_DATATYPES = { @@ -421,7 +421,6 @@ class TSQL(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore - exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DATETIME: "DATETIME2", diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 085871e..0c345b3 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2656,12 +2656,17 @@ class DataType(Expression): BINARY = auto() VARBINARY = auto() INT = auto() + UINT = auto() TINYINT = auto() + UTINYINT = auto() SMALLINT = auto() + USMALLINT = auto() 
BIGINT = auto() + UBIGINT = auto() FLOAT = auto() DOUBLE = auto() DECIMAL = auto() + BIT = auto() BOOLEAN = auto() JSON = auto() JSONB = auto() @@ -2861,10 +2866,6 @@ class Div(Binary): pass -class FloatDiv(Binary): - pass - - class Overlaps(Binary): pass @@ -2971,6 +2972,10 @@ class Sub(Binary): pass +class ArrayOverlaps(Binary): + pass + + # Unary Expressions # (NOT a) class Unary(Expression): @@ -3135,6 +3140,11 @@ class Array(Func): is_var_len_args = True +# https://docs.snowflake.com/en/sql-reference/functions/to_char +class ToChar(Func): + arg_types = {"this": True, "format": False} + + class GenerateSeries(Func): arg_types = {"start": True, "end": True, "step": False} @@ -3156,8 +3166,12 @@ class ArrayConcat(Func): is_var_len_args = True -class ArrayContains(Func): - arg_types = {"this": True, "expression": True} +class ArrayContains(Binary, Func): + pass + + +class ArrayContained(Binary): + pass class ArrayFilter(Func): @@ -3272,6 +3286,7 @@ class DateSub(Func, TimeUnit): class DateDiff(Func, TimeUnit): + _sql_names = ["DATEDIFF", "DATE_DIFF"] arg_types = {"this": True, "expression": True, "unit": False} @@ -4861,19 +4876,19 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: from sqlglot.dialects.dialect import Dialect - args = tuple(convert(arg) for arg in args) + converted = [convert(arg) for arg in args] kwargs = {key: convert(value) for key, value in kwargs.items()} parser = Dialect.get_or_raise(dialect)().parser() from_args_list = parser.FUNCTIONS.get(name.upper()) if from_args_list: - function = from_args_list(args) if args else from_args_list.__self__(**kwargs) # type: ignore + function = from_args_list(converted) if converted else from_args_list.__self__(**kwargs) # type: ignore else: - kwargs = kwargs or {"expressions": args} + kwargs = kwargs or {"expressions": converted} function = Anonymous(this=name, **kwargs) - for error_message in function.error_messages(args): + for error_message in function.error_messages(converted): raise ValueError(error_message) return function diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 4504e95..5936649 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -59,7 +59,6 @@ class Generator: exp.DateAdd: lambda self, e: self.func( "DATE_ADD", e.this, e.expression, e.args.get("unit") ), - exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), exp.TsOrDsAdd: lambda self, e: self.func( "TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit") ), @@ -109,9 +108,6 @@ class Generator: # Whether or not create function uses an AS before the RETURN CREATE_FUNCTION_RETURN_AS = True - # Whether or not to treat the division operator "/" as integer division - INTEGER_DIVISION = True - # Whether or not MERGE ... 
WHEN MATCHED BY SOURCE is allowed MATCHED_BY_SOURCE = True @@ -1571,7 +1567,7 @@ class Generator: ) else: this = "" - unit = expression.args.get("unit") + unit = self.sql(expression, "unit") unit = f" {unit}" if unit else "" return f"INTERVAL{this}{unit}" @@ -1757,25 +1753,17 @@ class Generator: return f"{self.sql(expression, 'this')} RESPECT NULLS" def intdiv_sql(self, expression: exp.IntDiv) -> str: - div = self.binary(expression, "/") - return self.sql(exp.Cast(this=div, to=exp.DataType.build("INT"))) + return self.sql( + exp.Cast( + this=exp.Div(this=expression.this, expression=expression.expression), + to=exp.DataType(this=exp.DataType.Type.INT), + ) + ) def dpipe_sql(self, expression: exp.DPipe) -> str: return self.binary(expression, "||") def div_sql(self, expression: exp.Div) -> str: - div = self.binary(expression, "/") - - if not self.INTEGER_DIVISION: - return self.sql(exp.Cast(this=div, to=exp.DataType.build("INT"))) - - return div - - def floatdiv_sql(self, expression: exp.FloatDiv) -> str: - if self.INTEGER_DIVISION: - this = exp.Cast(this=expression.this, to=exp.DataType.build("DOUBLE")) - return self.div_sql(exp.Div(this=this, expression=expression.expression)) - return self.binary(expression, "/") def overlaps_sql(self, expression: exp.Overlaps) -> str: @@ -1991,3 +1979,9 @@ class Generator: using = f"USING {self.sql(expression, 'using')}" on = f"ON {self.sql(expression, 'on')}" return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}" + + def tochar_sql(self, expression: exp.ToChar) -> str: + if expression.args.get("format"): + self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function") + + return self.sql(exp.cast(expression.this, "text")) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 894d68e..90fdade 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -31,15 +31,20 @@ def parse_var_map(args): ) +def binary_range_parser( + expr_type: t.Type[exp.Expression], +) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]: + return lambda self, this: self._parse_escape( + self.expression(expr_type, this=this, expression=self._parse_bitwise()) + ) + + class _Parser(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS) klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS) - if not klass.INTEGER_DIVISION: - klass.FACTOR = {**klass.FACTOR, TokenType.SLASH: exp.FloatDiv} - return klass @@ -102,6 +107,7 @@ class Parser(metaclass=_Parser): } TYPE_TOKENS = { + TokenType.BIT, TokenType.BOOLEAN, TokenType.TINYINT, TokenType.SMALLINT, @@ -503,29 +509,15 @@ class Parser(metaclass=_Parser): RANGE_PARSERS = { TokenType.BETWEEN: lambda self, this: self._parse_between(this), - TokenType.GLOB: lambda self, this: self._parse_escape( - self.expression(exp.Glob, this=this, expression=self._parse_bitwise()) - ), - TokenType.OVERLAPS: lambda self, this: self._parse_escape( - self.expression(exp.Overlaps, this=this, expression=self._parse_bitwise()) - ), + TokenType.GLOB: binary_range_parser(exp.Glob), + TokenType.OVERLAPS: binary_range_parser(exp.Overlaps), TokenType.IN: lambda self, this: self._parse_in(this), TokenType.IS: lambda self, this: self._parse_is(this), - TokenType.LIKE: lambda self, this: self._parse_escape( - self.expression(exp.Like, this=this, expression=self._parse_bitwise()) - ), - TokenType.ILIKE: lambda self, this: self._parse_escape( - self.expression(exp.ILike, 
this=this, expression=self._parse_bitwise()) - ), - TokenType.IRLIKE: lambda self, this: self.expression( - exp.RegexpILike, this=this, expression=self._parse_bitwise() - ), - TokenType.RLIKE: lambda self, this: self.expression( - exp.RegexpLike, this=this, expression=self._parse_bitwise() - ), - TokenType.SIMILAR_TO: lambda self, this: self.expression( - exp.SimilarTo, this=this, expression=self._parse_bitwise() - ), + TokenType.LIKE: binary_range_parser(exp.Like), + TokenType.ILIKE: binary_range_parser(exp.ILike), + TokenType.IRLIKE: binary_range_parser(exp.RegexpILike), + TokenType.RLIKE: binary_range_parser(exp.RegexpLike), + TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo), } PROPERTY_PARSERS = { @@ -707,7 +699,7 @@ class Parser(metaclass=_Parser): STRICT_CAST = True - INTEGER_DIVISION = True + CONVERT_TYPE_FIRST = False __slots__ = ( "error_level", @@ -2542,7 +2534,7 @@ class Parser(metaclass=_Parser): def _parse_type(self) -> t.Optional[exp.Expression]: if self._match(TokenType.INTERVAL): - return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var()) + return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_field()) index = self._index type_token = self._parse_types(check_func=True) @@ -3285,15 +3277,19 @@ class Parser(metaclass=_Parser): def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: to: t.Optional[exp.Expression] - this = self._parse_column() + this = self._parse_bitwise() if self._match(TokenType.USING): to = self.expression(exp.CharacterSet, this=self._parse_var()) elif self._match(TokenType.COMMA): - to = self._parse_types() + to = self._parse_bitwise() else: to = None + # Swap the argument order if needed to produce the correct AST + if self.CONVERT_TYPE_FIRST: + this, to = to, this + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) def _parse_position(self, haystack_first: bool = False) -> exp.Expression: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 5f4b77d..053bbdd 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -46,10 +46,13 @@ class TokenType(AutoName): HASH_ARROW = auto() DHASH_ARROW = auto() LR_ARROW = auto() + LT_AT = auto() + AT_GT = auto() DOLLAR = auto() PARAMETER = auto() SESSION_PARAMETER = auto() NATIONAL = auto() + DAMP = auto() BLOCK_START = auto() BLOCK_END = auto() @@ -71,11 +74,16 @@ class TokenType(AutoName): BYTE_STRING = auto() # types + BIT = auto() BOOLEAN = auto() TINYINT = auto() + UTINYINT = auto() SMALLINT = auto() + USMALLINT = auto() INT = auto() + UINT = auto() BIGINT = auto() + UBIGINT = auto() FLOAT = auto() DOUBLE = auto() DECIMAL = auto() @@ -462,6 +470,7 @@ class Tokenizer(metaclass=_Tokenizer): "#>": TokenType.HASH_ARROW, "#>>": TokenType.DHASH_ARROW, "<->": TokenType.LR_ARROW, + "&&": TokenType.DAMP, "ALL": TokenType.ALL, "ALWAYS": TokenType.ALWAYS, "AND": TokenType.AND, @@ -630,6 +639,7 @@ class Tokenizer(metaclass=_Tokenizer): "WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE, "APPLY": TokenType.APPLY, "ARRAY": TokenType.ARRAY, + "BIT": TokenType.BIT, "BOOL": TokenType.BOOLEAN, "BOOLEAN": TokenType.BOOLEAN, "BYTE": TokenType.TINYINT, -- cgit v1.2.3
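A usage note on the TO_CHAR plumbing that recurs throughout this patch: Snowflake's TO_VARCHAR now parses into the new exp.ToChar node; dialects that register the function_fallback_sql transform (Postgres, Oracle, Drill, Teradata, Databricks) keep emitting TO_CHAR, while the base generator's new tochar_sql degrades to a plain cast and reports the optional format argument as unsupported. A rough sketch under the same assumption (sqlglot 11.3.6 with this patch applied; outputs shown are illustrative):

import sqlglot

# Snowflake's TO_VARCHAR(col) parses into exp.ToChar.
expr = sqlglot.parse_one("SELECT TO_VARCHAR(col) FROM t", read="snowflake")

# Dialects with a TO_CHAR fallback transform keep the function call.
print(expr.sql(dialect="postgres"))   # e.g. SELECT TO_CHAR(col) FROM t

# The base generator has no native TO_CHAR, so tochar_sql falls back to a cast.
print(expr.sql())                     # e.g. SELECT CAST(col AS TEXT) FROM t

The same fallback pattern is what lets each dialect opt back into TO_CHAR with a single one-line TRANSFORMS entry, as the per-dialect hunks above show.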